Merge main -> google
* 6c1869bc Add license header to linker scripts (#2724)
* 6d32fe81 Updates all e2e tests to use TracedModuleTestCase (#2736)
* 7fea2dd1 Revert "[android] Ignore adb push status for the executable (#2627)" (#2735)
* da48d03c Adds a class for tracing module calls and example tests. (#2660)
* d66150d8 Enables `conv_test` on vulkan. (#2723)
* a3b73bf6 Add .python-version to .gitignore (#2727)
* 6d2bfd35 Enables passing MobileNet and MobileNetV2 vision tests (#2725)
* 763dc098 Fix ModelBuilder build (#2726)
* 528a86a7 Update the codegen pipeline example snippets to better demonstrate (#2709)
* 91a05ff5 Merge google -> main (#2719)
* f0307a87 Revert "Merge google -> main" (#2718)
* 77d2f8b4 Merge google -> main (#2711)
* aafabbcb Add linker scripts to cmake pyiree builds to hide symbols. (#2707)
* 1918d5c7 Updated gather to run on LLVM and SPIRV (#2701)
* ec38aa00 Add a VM to EmitC conversion and CModule Target (#2536)
* 32fe2c9b Revert "Add a VM to EmitC conversion and CModule Target (#2536)" (#2703)
* e7f90b61 Revert addition of DYLIB-LLVM-AOT as a default backend (#2702)
* 48efada8 Dump input values to a file. (#2683)
* 0fd8a550 Mark shell scripts as executable (#2699)
* 966fe78f Add dep "//iree/base:localfile" to iree-run-mlir (#2687)
* 64e29879 Merge google -> main (#2698)
* ced7477f Update scripts for packaging wheels from bazel.
* 1c05e0ae Disable dylib driver for android build (#2692)
* e00a85f8 Disable a few gcc warnings that are chatty.
* 74af1f01 Add a config for gcc. (#2696)
* eec62a1e Remove dylib-llvm-aot tests (#2694)
* 11bd7a00 Fix stdalign include in Android (#2688)
* b9026bd6 Enable dylib-llvm-aot backend e2e/xla_op tests (#2639)
* 8ed9f73f Added Complex operation lowerings to pre-target conversion (#2662)
* 8224247a Working around duplicate func->llvm.func rewrites. (#2685)
* f2d6649c Bump pybind11 to head. (#2592)
* 9ac5241d [NFC] Remove empty ArrayRef<NamedAttribute> arg in op creation (#2686)
* 36abdfde [ModelBuilder] Revert spurious path change (#2684)
* c96bbb1d Merge pull request #2568 from google/benvanik-flatcc-vm
* f80864f8 [ModelBuilder] Add MemRefUtil helpers for padding. (#2577)
* 1425737d Fix flatcc target name and binary path (#2599)
* 8c18b8c3 Refactoring bytecode_module to use flatcc instead of flatbuffers C++. This imp..
* 3762c132 Linking in the flatcc bytecode_module_def files. This ensures they are correct..
* eb53de3e Building bytecode_module_def with flatcc.
PiperOrigin-RevId: 324296127
diff --git a/.gitignore b/.gitignore
index 7db7ef9..96a5b66 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,6 @@
compile_commands.json
.cache/clangd
.clangd/
+
+# Pyenv files
+.python-version
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 48c4948..b761338 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -95,7 +95,8 @@
set( IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS} )
# For cross compilation towords Android, we don't want LLVM JIT HAL driver.
if(ANDROID)
- list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM)
+ # TODO(ataei): Enable dylib/dylib-llvm-aot for android.
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM DyLib)
endif()
endif()
message(STATUS "Building HAL drivers ${IREE_HAL_DRIVERS_TO_BUILD}")
@@ -115,7 +116,7 @@
# List of all target backends to be built by default:
set(IREE_ALL_TARGET_BACKENDS
- # TODO(scotttodd): LLVMAOT
+ # TODO(#2645): Add DYLIB-LLVM-AOT when it doesn't require an env var
LLVM-IR
Vulkan-SPIRV
VMLA
@@ -356,7 +357,7 @@
# We need flatc to generate some source code. When cross-compiling, we need
# to make sure the flatc binary is configured under host environment.
iree_declare_host_excutable(flatc BUILDONLY)
- iree_declare_host_excutable(flatcc BUILDONLY)
+ iree_declare_host_excutable(flatcc_cli BUILDONLY)
# Set the FLATBUFFERS_FLATC_EXECUTABLE. It controls where to find the flatc
# binary in BuildFlatBuffers().
@@ -371,14 +372,17 @@
DEPENDS iree_host_build_flatc
COMMENT "Installing host flatc..."
)
- add_custom_target(iree_host_flatcc
+ add_custom_target(iree_host_flatcc_cli
COMMAND
"${CMAKE_COMMAND}" -E copy_if_different
- "${IREE_HOST_BINARY_ROOT}/third_party/flatcc/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
- "${IREE_HOST_BINARY_ROOT}/bin"
- DEPENDS iree_host_build_flatcc
+ "${PROJECT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
+ "${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli"
+ DEPENDS iree_host_build_flatcc_cli
COMMENT "Installing host flatcc..."
)
+else()
+ # TODO: unify flatc and flatcc handling to the same mechanism.
+ add_executable(iree_host_flatcc_cli ALIAS flatcc_cli)
endif()
if(${IREE_BUILD_COMPILER})
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 58303cb..e719d5a 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -7,7 +7,7 @@
cd4e8d7f6f5ef108919f9f53db35ac73d1edea3d third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
80885f899e12d55a45561ef758eea47bb340dbf1 third_party/mlir-emitc
-80d452484c5409444b0ec19383faa84bb7a4d351 third_party/pybind11
+d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
9f53ba413e6fc879236dcaa3e008915973d67a4f third_party/ruy
a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
diff --git a/bindings/python/pyiree/compiler/CMakeLists.txt b/bindings/python/pyiree/compiler/CMakeLists.txt
index b7bc2cc..929e621 100644
--- a/bindings/python/pyiree/compiler/CMakeLists.txt
+++ b/bindings/python/pyiree/compiler/CMakeLists.txt
@@ -26,6 +26,8 @@
PyExtCompiler
MODULE_NAME
binding
+ UNIX_LINKER_SCRIPT
+ "unix_version.lds"
SRCS
"initialize_module.cc"
PYEXT_DEPS
diff --git a/bindings/python/pyiree/compiler/unix_version.lds b/bindings/python/pyiree/compiler/unix_version.lds
new file mode 100644
index 0000000..68ef766
--- /dev/null
+++ b/bindings/python/pyiree/compiler/unix_version.lds
@@ -0,0 +1,19 @@
+/* Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+{
+ global: PyInit_binding;
+ local: *;
+};
diff --git a/bindings/python/pyiree/rt/CMakeLists.txt b/bindings/python/pyiree/rt/CMakeLists.txt
index 1ec5eaf..225dccd 100644
--- a/bindings/python/pyiree/rt/CMakeLists.txt
+++ b/bindings/python/pyiree/rt/CMakeLists.txt
@@ -50,6 +50,8 @@
MODULE_NAME binding
SRCS
"initialize_module.cc"
+ UNIX_LINKER_SCRIPT
+ "unix_version.lds"
PYEXT_DEPS
::PyExtRtLib
bindings::python::pyiree::common::PyextCommonLib
diff --git a/bindings/python/pyiree/rt/unix_version.lds b/bindings/python/pyiree/rt/unix_version.lds
new file mode 100644
index 0000000..68ef766
--- /dev/null
+++ b/bindings/python/pyiree/rt/unix_version.lds
@@ -0,0 +1,19 @@
+/* Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+{
+ global: PyInit_binding;
+ local: *;
+};
diff --git a/build_tools/bazel/build_bindings.sh b/build_tools/bazel/build_bindings.sh
index b4a417e..4ca7ddb 100755
--- a/build_tools/bazel/build_bindings.sh
+++ b/build_tools/bazel/build_bindings.sh
@@ -43,7 +43,7 @@
declare -a test_env_args=(
--test_env=IREE_LLVMJIT_DISABLE=$IREE_LLVMJIT_DISABLE
--test_env=IREE_VULKAN_DISABLE=$IREE_VULKAN_DISABLE
- --test_env=IREE_LLVMAOT_LINKER_PATH
+ --action_env=IREE_LLVMAOT_LINKER_PATH=$IREE_LLVMAOT_LINKER_PATH
)
declare -a default_build_tag_filters=("-nokokoro")
diff --git a/build_tools/bazel/build_core.sh b/build_tools/bazel/build_core.sh
index 89d94e2..bd04942 100755
--- a/build_tools/bazel/build_core.sh
+++ b/build_tools/bazel/build_core.sh
@@ -42,6 +42,7 @@
declare -a test_env_args=(
--test_env=IREE_LLVMJIT_DISABLE=$IREE_LLVMJIT_DISABLE
--test_env=IREE_VULKAN_DISABLE=$IREE_VULKAN_DISABLE
+ --action_env=IREE_LLVMAOT_LINKER_PATH=$IREE_LLVMAOT_LINKER_PATH
)
declare -a default_build_tag_filters=("-nokokoro")
diff --git a/build_tools/bazel/build_tensorflow.sh b/build_tools/bazel/build_tensorflow.sh
index be82227..3634c4d 100755
--- a/build_tools/bazel/build_tensorflow.sh
+++ b/build_tools/bazel/build_tensorflow.sh
@@ -42,6 +42,7 @@
declare -a test_env_args=(
--test_env=IREE_LLVMJIT_DISABLE=$IREE_LLVMJIT_DISABLE
--test_env=IREE_VULKAN_DISABLE=$IREE_VULKAN_DISABLE
+ --action_env=IREE_LLVMAOT_LINKER_PATH=$IREE_LLVMAOT_LINKER_PATH
)
# Pass in VK_ICD_FILENAMES if exists so that the Vulkan loader can find the
# Vulkan implementation.
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc
index dba02ed..220b868 100644
--- a/build_tools/bazel/iree.bazelrc
+++ b/build_tools/bazel/iree.bazelrc
@@ -36,6 +36,23 @@
build --define open_source_build=true
###############################################################################
+# Options for "generic_gcc" builds
+###############################################################################
+
+# C++14 standard version is required.
+build:generic_gcc --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
+
+# Default to adding back asserts in optimized builds.
+# This is a good compromise between runtime and debugability.
+build:generic_gcc --copt=-UNDEBUG
+
+# Disable warnings we don't care about or that generally have a low signal/noise
+# ratio.
+build:generic_gcc --copt=-Wno-unused-but-set-parameter
+build:generic_gcc --copt=-Wno-comment
+build:generic_gcc --copt=-Wno-attributes
+
+###############################################################################
# Options for "generic_clang" builds: these options should generally apply to
# either clang or gcc and are curated based on need.
###############################################################################
diff --git a/build_tools/cmake/flatbuffer_c_library.cmake b/build_tools/cmake/flatbuffer_c_library.cmake
index 916cfd9..eb63567 100644
--- a/build_tools/cmake/flatbuffer_c_library.cmake
+++ b/build_tools/cmake/flatbuffer_c_library.cmake
@@ -100,7 +100,7 @@
endforeach()
list(TRANSFORM _OUTS PREPEND "${CMAKE_CURRENT_BINARY_DIR}/")
- iree_get_executable_path(_FLATCC_BIN flatcc)
+ iree_get_executable_path(_FLATCC_BIN flatcc_cli)
add_custom_command(
OUTPUT
${_OUTS}
@@ -115,7 +115,7 @@
MAIN_DEPENDENCY
${_RULE_SRCS}
DEPENDS
- ${_FLATCC_BIN}
+ iree_host_flatcc_cli
${_RULE_SRCS}
COMMAND_EXPAND_LISTS
)
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index a7fd1f1..706f6a6 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -98,7 +98,7 @@
function(iree_pyext_module)
cmake_parse_arguments(ARG
""
- "NAME;MODULE_NAME"
+ "NAME;MODULE_NAME;UNIX_LINKER_SCRIPT"
"SRCS;COPTS;DEPS;PYEXT_DEPS"
${ARGN})
_setup_iree_pyext_names()
@@ -127,6 +127,14 @@
SUFFIX "${IREE_MULTIPY_${V}_SUFFIX}${IREE_MULTIPY_${V}_EXTENSION}"
)
+ # Link flags.
+ if(UNIX AND NOT APPLE) # Apple does not support linker scripts.
+ if(ARG_UNIX_LINKER_SCRIPT)
+ set_target_properties(${VER_NAME} PROPERTIES LINK_FLAGS
+ "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/${ARG_UNIX_LINKER_SCRIPT}")
+ endif()
+ endif()
+
iree_pyext_pybind11_options(${VER_NAME})
target_include_directories(${VER_NAME}
PUBLIC
diff --git a/build_tools/cmake/run_android_test.sh b/build_tools/cmake/run_android_test.sh
index 0c7510d..292890c 100755
--- a/build_tools/cmake/run_android_test.sh
+++ b/build_tools/cmake/run_android_test.sh
@@ -35,7 +35,7 @@
set -x
set -e
-adb push $TEST_EXECUTABLE $TEST_ANDROID_ABS_DIR/$(basename $TEST_EXECUTABLE) 1>/dev/null
+adb push $TEST_EXECUTABLE $TEST_ANDROID_ABS_DIR/$(basename $TEST_EXECUTABLE)
if [ -n "$TEST_DATA" ]; then
adb push $TEST_DATA $TEST_ANDROID_ABS_DIR/$(basename $TEST_DATA)
diff --git a/build_tools/embed_data/generate_cc_embed_data.cc b/build_tools/embed_data/generate_cc_embed_data.cc
index 21275cc..fed4f13 100644
--- a/build_tools/embed_data/generate_cc_embed_data.cc
+++ b/build_tools/embed_data/generate_cc_embed_data.cc
@@ -113,11 +113,8 @@
f << "#include <cstddef>\n";
GenerateTocStruct(f);
GenerateNamespaceOpen(f);
- f << "static const struct ::iree::FileToc toc[] = {\n";
- assert(input_files.size() == toc_files.size());
for (size_t i = 0, e = input_files.size(); i < e; ++i) {
- f << " {";
- f << "\"" << absl::CEscape(toc_files[i]) << "\",\n";
+ f << "alignas(alignof(void*)) static char const file_" << i << "[] = {\n";
std::string contents;
if (!SlurpFile(input_files[i], &contents)) {
std::cerr << "Error reading file " << input_files[i] << "\n";
@@ -130,7 +127,16 @@
f << "\"" << absl::CHexEscape(line) << "\"\n";
remaining_contents = remaining_contents.substr(line.size());
}
- f << "\"\\0\", " << contents.size() << "},\n";
+ f << "};\n";
+ }
+ f << "static const struct ::iree::FileToc toc[] = {\n";
+ assert(input_files.size() == toc_files.size());
+ for (size_t i = 0, e = input_files.size(); i < e; ++i) {
+ f << " {\n";
+ f << " \"" << absl::CEscape(toc_files[i]) << "\",\n";
+ f << " file_" << i << ",\n";
+ f << " sizeof(file_" << i << ") - 1\n";
+ f << " },\n";
}
f << " {nullptr, nullptr, 0},\n";
f << "};\n";
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh
old mode 100644
new mode 100755
diff --git a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
index 0900667..d8ddd33 100644
--- a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
+++ b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
@@ -40,7 +40,9 @@
"lib/Dialect/mhlo/IR/lhlo_ops.cc"
"lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"
"lib/Dialect/mhlo/transforms/legalize_control_flow.cc"
+ "lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"
"lib/Dialect/mhlo/transforms/legalize_to_standard.cc"
+ "lib/Dialect/mhlo/transforms/lower_complex.cc"
"lib/Dialect/mhlo/transforms/lower_general_dot.cc"
"lib/Dialect/mhlo/transforms/materialize_broadcasts.cc"
"lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc"
@@ -89,6 +91,7 @@
tensorflow_mlir_hlo_hlo_ops_pattern_gen
tensorflow_mlir_hlo_infer_fusibility_op_interface_gen
tensorflow_mlir_hlo_legalize_to_standard_patterns_gen
+ tensorflow_mlir_hlo_lower_to_complex_inc_gen
tensorflow_mlir_hlo_lhlo_ops_gen
tensorflow_mlir_hlo_xla_canonicalize_gen
PUBLIC
@@ -222,3 +225,18 @@
OUTS
-gen-rewriters lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc
)
+
+external_tablegen_library(
+ PACKAGE
+ tensorflow
+ NAME
+ mlir_hlo_lower_to_complex_inc_gen
+ TBLGEN
+ MLIR
+ ROOT
+ ${TF_MLIR_HLO_SRC_ROOT}
+ SRCS
+ "lib/Dialect/mhlo/transforms/lower_complex_patterns.td"
+ OUTS
+ -gen-rewriters lib/Dialect/mhlo/transforms/generated_lower_complex.inc
+)
diff --git a/docs/design_docs/codegen_passes.md b/docs/design_docs/codegen_passes.md
index 83e37fc..4a1335c 100644
--- a/docs/design_docs/codegen_passes.md
+++ b/docs/design_docs/codegen_passes.md
@@ -27,13 +27,13 @@
func @main_ex_dispatch() {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<4x5xf32>
+ offset = %c0 : tensor<32x24xf32>
%1 = hal.interface.load.tensor @legacy_io::@arg1,
- offset = %c0 : tensor<5x10xf32>
+ offset = %c0 : tensor<24x16xf32>
%2 = "mhlo.dot"(%0, %1) {precision_config = ["DEFAULT", "DEFAULT"]} :
- (tensor<4x5xf32>, tensor<5x10xf32>) -> tensor<4x10xf32>
+ (tensor<32x24xf32>, tensor<24x16xf32>) -> tensor<32x16xf32>
hal.interface.store.tensor %2, @legacy_io::@ret0,
- offset = %c0 : tensor<4x10xf32>
+ offset = %c0 : tensor<32x16xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -57,17 +57,18 @@
func @main_ex_dispatch() {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<10x5xf32>
+ offset = %c0 : tensor<10x15xf32>
%1 = hal.interface.load.tensor @legacy_io::@arg1,
- offset = %c0 : tensor<10x5xf32>
+ offset = %c0 : tensor<10x15xf32>
%2 = hal.interface.load.tensor @legacy_io::@arg2,
- offset = %c0 : tensor<10x5xf32>
+ offset = %c0 : tensor<15xf32>
%3 = "mhlo.add"(%0, %1) :
- (tensor<10x5xf32>, tensor<10x5xf32>) -> tensor<10x5xf32>
- %4 = "mhlo.multiply"(%3, %2) :
- (tensor<10x5xf32>, tensor<10x5xf32>) -> tensor<10x5xf32>
- hal.interface.store.tensor %4, @legacy_io::@ret0,
- offset = %c0 : tensor<10x5xf32>
+ (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
+ %4 = "mhlo.broadcast"(%2) : (tensor<15xf32>) -> tensor<10x15xf32>
+ %5 = "mhlo.multiply"(%3, %4) :
+ (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
+ hal.interface.store.tensor %5, @legacy_io::@ret0,
+ offset = %c0 : tensor<10x15xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -130,11 +131,13 @@
The first step is to convert MHLO operations to Linalg on tensors. This is done
using the [HLOToLinalgPass][HLOToLinalgPass] from Tensorflow. An example of the
-conversion is shown below, where each of the `mhlo.add` and `mhlo.multiply`
-operations are converted to `linalg.generic` operations on tensors.
+conversion is shown below, where the `mhlo.add`, `mhlo.broadcast` and
+`mhlo.multiply` operations are converted to `linalg.generic` operations on
+tensors.
```mlir
#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
%3 = linalg.generic
{args_in = 2 : i64, args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
@@ -142,15 +145,22 @@
^bb0(%arg0: f32, %arg1: f32): // no predecessors
%5 = addf %arg0, %arg1 : f32
linalg.yield %5 : f32
- } : tensor<10x5xf32>, tensor<10x5xf32> -> tensor<10x5xf32>
+ } : tensor<10x15xf32>, tensor<10x15xf32> -> tensor<10x15xf32>
%4 = linalg.generic
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map1, #map0],
+ iterator_types = ["parallel", "parallel"]} %2 {
+ ^bb0(%arg0: f32): // no predecessors
+ linalg.yield %arg0 : f32
+ }: tensor<15xf32> -> tensor<10x15xf32>
+%5 = linalg.generic
{args_in = 2 : i64, args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"]} %3, %2 {
+ iterator_types = ["parallel", "parallel"]} %3, %4 {
^bb0(%arg0: f32, %arg1: f32): // no predecessors
%5 = mulf %arg0, %arg1 : f32
linalg.yield %5 : f32
- }: tensor<10x5xf32>, tensor<10x5xf32> -> tensor<10x5xf32>
+ }: tensor<10x15xf32>, tensor<10x15xf32> -> tensor<10x15xf32>
```
<a name="snippet3"></a> Snippet 3 : MHLO to Linalg conversion for
@@ -190,15 +200,16 @@
```mlir
#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
%3 = linalg.generic
{args_in = 3 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0, #map0, #map0],
+ indexing_maps = [#map0, #map0, #map1, #map0],
iterator_types = ["parallel", "parallel"]} %0, %1, %2 {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
%4 = addf %arg0, %arg1 : f32
%5 = mulf %4, %arg2 : f32
linalg.yield %5 : f32
- }: tensor<?x5xf32>, tensor<?x5xf32>, tensor<?x5xf32> -> tensor<?x5xf32>
+ }: tensor<10x15xf32>, tensor<10x15xf32>, tensor<15xf32> -> tensor<10x15xf32>
```
<a name="snippet4"></a> Snippet 4: Fusion of Linalg operation on tensors for
@@ -265,15 +276,15 @@
```mlir
func @main_ex_dispatch() {
%0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0} : memref<4x10xf32>
+ {binding = @legacy_io::@ret0} : memref<32x16xf32>
%c0 = constant 0 : index
%1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0} : memref<4x5xf32>
+ {binding = @legacy_io::@arg0} : memref<32x24xf32>
%2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1} : memref<5x10xf32>
+ {binding = @legacy_io::@arg1} : memref<24x16xf32>
%cst = constant 0.000000e+00 : f32
linalg.matmul(%1, %2, %0) :
- memref<4x5xf32>, memref<5x10xf32>, memref<4x10xf32>
+ memref<32x24xf32>, memref<24x16xf32>, memref<32x16xf32>
return
}
```
@@ -283,25 +294,26 @@
```mlir
#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
func @main_ex_dispatch() {
%0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0} : memref<10x5xf32>
+ {binding = @legacy_io::@ret0} : memref<10x15xf32>
%c0 = constant 0 : index
%1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0} : memref<10x5xf32>
+ {binding = @legacy_io::@arg0} : memref<10x15xf32>
%2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1} : memref<10x5xf32>
+ {binding = @legacy_io::@arg1} : memref<10x15xf32>
%3 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg2} : memref<10x5xf32>
+ {binding = @legacy_io::@arg2} : memref<15xf32>
linalg.generic
{args_in = 3 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0, #map0],
+ indexing_maps = [#map0, #map0, #map1, #map0],
iterator_types = ["parallel", "parallel"]} %1, %2, %3, %0 {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
%4 = addf %arg0, %arg1 : f32
%5 = mulf %4, %arg2 : f32
linalg.yield %5 : f32
- }: memref<10x5xf32>, memref<10x5xf32>, memref<10x5xf32>, memref<10x5xf32>
+ }: memref<10x15xf32>, memref<10x15xf32>, memref<15xf32>, memref<10x15xf32>
return
}
```
@@ -350,18 +362,20 @@
attributes {
spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
%cst = constant 0.000000e+00 : f32
+ %c32 = constant 32 : index
+ %c24 = constant 24 : index
+ %c16 = constant 16 : index
%c0 = constant 0 : index
%c4 = constant 4 : index
- %c10 = constant 10 : index
%0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0} : memref<4x10xf32>
+ {binding = @legacy_io::@ret0} : memref<32x16xf32>
%1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0} : memref<4x5xf32>
+ {binding = @legacy_io::@arg0} : memref<32x24xf32>
%2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1} : memref<5x10xf32>
- linalg.fill(%0, %cst) : memref<4x10xf32>, f32
- scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c4, %c10) step (%c8, %c8) {
- scf.for %arg2 = %c0 to %c5 step %c4 {
+ {binding = @legacy_io::@arg1} : memref<24x16xf32>
+ linalg.fill(%0, %cst) : memref<32x16xf32>, f32
+ scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c32, %c16) step (%c8, %c8) {
+      scf.for %arg2 = %c0 to %c24 step %c4 {
...
%5 = subview %1[%arg0, %arg2]...
...
@@ -440,19 +454,21 @@
func @matmul_tile()
attributes {
spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
- %c96 = constant 96 : index
+ %c32 = constant 32 : index
+ %c24 = constant 24 : index
+ %c16 = constant 16 : index
%c4 = constant 4 : index
%c8 = constant 8 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0} : memref<96x96xf32>
+ {binding = @legacy_io::@arg0} : memref<32x24xf32>
%1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1} : memref<96x96xf32>
+ {binding = @legacy_io::@arg1} : memref<24x16xf32>
%2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0} : memref<96x96xf32>
- scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c96, %c96) step (%c8, %c8) {
- scf.for %arg2 = %c0 to %c96 step %c4 {
+ {binding = @legacy_io::@ret0} : memref<32x16xf32>
+ scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c32, %c16) step (%c8, %c8) {
+ scf.for %arg2 = %c0 to %c24 step %c4 {
...
%5 = subview %0[%arg0, %arg2]...
...
@@ -518,38 +534,40 @@
func @main_ex_dispatch_0_dispatch_1()
attributes {
spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
- %c5 = constant 5 : index
+ %c24 = constant 24 : index
%c8 = constant 8 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0} : memref<4x10xf32>
+ {binding = @legacy_io::@ret0} : memref<32x16xf32>
%1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0} : memref<4x5xf32>
+ {binding = @legacy_io::@arg0} : memref<32x24xf32>
%2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1} : memref<5x10xf32>
+ {binding = @legacy_io::@arg1} : memref<24x16xf32>
%3 = "gpu.block_id"() {dimension = "x"} : () -> index
- %4 = muli %3, %c8 : index
- scf.for %arg0 = %c0 to %c5 step %c4 {
+ %4 = "gpu.block_id"() {dimension = "y"} : () -> index
+ %5 = muli %4, %c8 : index
+ %6 = muli %3, %c8 : index
+ scf.for %arg0 = %c0 to %c24 step %c4 {
...
- %9 = subview %1[0, %arg0]
+ %15 = subview %1[%5, %arg0]
...
- %14 = subview %2[%arg0, %4]
- %15 = subview %0[0, %4]
- %16 = "gpu.thread_id"() {dimension = "x"} : () -> index
- %17 = "gpu.thread_id"() {dimension = "y"} : () -> index
- %18 = cmpi "slt", %17, %c4 : index
- %19 = cmpi "slt", %16, %13 : index
- %20 = and %18, %19 : i1
- scf.if %20 {
- scf.for %arg1 = %c0 to %8 step %c1 {
- %21 = load %9[%17, %arg1] : memref<4x?xf32, #map0>
- %22 = load %14[%arg1, %16] : memref<?x?xf32, #map1>
- %23 = load %15[%17, %16] : memref<4x?xf32, #map1>
- %24 = mulf %21, %22 : f32
- %25 = addf %23, %24 : f32
- store %25, %15[%17, %16] : memref<4x?xf32, #map1>
+ %20 = subview %2[%arg0, %6]
+ %21 = subview %0[%5, %6]
+ %22 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %23 = "gpu.thread_id"() {dimension = "y"} : () -> index
+ %24 = cmpi "slt", %23, %10 : index
+ %25 = cmpi "slt", %22, %19 : index
+ %26 = and %24, %25 : i1
+ scf.if %26 {
+ scf.for %arg1 = %c0 to %14 step %c1 {
+ %27 = load %15[%23, %arg1] : memref<?x?xf32, #map0>
+ %28 = load %20[%arg1, %22] : memref<?x?xf32, #map1>
+ %29 = load %21[%23, %22] : memref<?x?xf32, #map1>
+            %30 = mulf %27, %28 : f32
+            %31 = addf %29, %30 : f32
+            store %31, %21[%23, %22] : memref<?x?xf32, #map1>
}
}
}
@@ -574,13 +592,13 @@
%c50 = constant 50 : index
%c5 = constant 5 : index
%0 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@ret0} : memref<10x5xf32>
+ {binding = @legacy_io::@ret0} : memref<10x15xf32>
%1 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg0} : memref<10x5xf32>
+ {binding = @legacy_io::@arg0} : memref<10x15xf32>
%2 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg1} : memref<10x5xf32>
+ {binding = @legacy_io::@arg1} : memref<10x15xf32>
%3 = iree.placeholder for "interface buffer"
- {binding = @legacy_io::@arg2} : memref<10x5xf32>
+ {binding = @legacy_io::@arg2} : memref<15xf32>
%4 = "gpu.block_id"() {dimension = "x"} : () -> index
%5 = "gpu.block_dim"() {dimension = "x"} : () -> index
%6 = "gpu.thread_id"() {dimension = "x"} : () -> index
@@ -590,12 +608,12 @@
scf.if %9 {
%10 = divi_signed %8, %c5 : index
%11 = remi_signed %8, %c5 : index
- %12 = load %1[%10, %11] : memref<10x5xf32>
- %13 = load %2[%10, %11] : memref<10x5xf32>
- %14 = load %3[%10, %11] : memref<10x5xf32>
+ %12 = load %1[%10, %11] : memref<10x15xf32>
+ %13 = load %2[%10, %11] : memref<10x15xf32>
+ %14 = load %3[%11] : memref<15xf32>
%15 = addf %12, %13 : f32
%16 = mulf %15, %14 : f32
- store %16, %0[%10, %11] : memref<10x5xf32>
+ store %16, %0[%10, %11] : memref<10x15xf32>
}
return
}
diff --git a/experimental/ModelBuilder/MemRefUtils.h b/experimental/ModelBuilder/MemRefUtils.h
index ebd6d1b..4d1e67b 100644
--- a/experimental/ModelBuilder/MemRefUtils.h
+++ b/experimental/ModelBuilder/MemRefUtils.h
@@ -74,21 +74,55 @@
// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
// This is an implementation detail that is kept in sync with MLIR codegen
+// conventions. Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+template <typename T, int N>
+typename std::enable_if<(N >= 1), StridedMemRefType<T, N> *>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+ const std::array<int64_t, N> &shape,
+ const std::array<int64_t, N> &shapeAlloc,
+ AllocFunType allocFun = &::malloc) {
+ StridedMemRefType<T, N> *descriptor = static_cast<StridedMemRefType<T, N> *>(
+ allocFun(sizeof(StridedMemRefType<T, N>)));
+ descriptor->basePtr = static_cast<T *>(ptr);
+ descriptor->data = static_cast<T *>(alignedPtr);
+ descriptor->offset = 0;
+ std::copy(shape.begin(), shape.end(), descriptor->sizes);
+ auto strides = makeStrides<N>(shapeAlloc);
+ std::copy(strides.begin(), strides.end(), descriptor->strides);
+ return descriptor;
+}
+
+// Mallocs a StridedMemRefDescriptor<T, 0>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
+// conventions. Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+template <typename T, int N>
+typename std::enable_if<(N == 0), StridedMemRefType<T, 0> *>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+ const std::array<int64_t, N> &shape = {},
+ const std::array<int64_t, N> &shapeAlloc = {},
+ AllocFunType allocFun = &::malloc) {
+ StridedMemRefType<T, 0> *descriptor = static_cast<StridedMemRefType<T, 0> *>(
+ allocFun(sizeof(StridedMemRefType<T, 0>)));
+ descriptor->basePtr = static_cast<T *>(ptr);
+ descriptor->data = static_cast<T *>(alignedPtr);
+ descriptor->offset = 0;
+ return descriptor;
+}
+
+// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
// conventions.
template <typename T, int N>
typename std::enable_if<(N >= 1), StridedMemRefType<T, N> *>::type
makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
const std::array<int64_t, N> &shape,
- AllocFunType alloc = &::malloc) {
- StridedMemRefType<T, N> *descriptor = static_cast<StridedMemRefType<T, N> *>(
- alloc(sizeof(StridedMemRefType<T, N>)));
- descriptor->basePtr = static_cast<T *>(ptr);
- descriptor->data = static_cast<T *>(alignedPtr);
- descriptor->offset = 0;
- std::copy(shape.begin(), shape.end(), descriptor->sizes);
- auto strides = makeStrides<N>(shape);
- std::copy(strides.begin(), strides.end(), descriptor->strides);
- return descriptor;
+ AllocFunType allocFun = &::malloc) {
+ return makeStridedMemRefDescriptor<T, N>(ptr, alignedPtr, shape, shape,
+ allocFun);
}
// Mallocs a StridedMemRefDescriptor<T, 0>* (i.e. a pointer to scalar) that
@@ -98,13 +132,9 @@
typename std::enable_if<(N == 0), StridedMemRefType<T, 0> *>::type
makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
const std::array<int64_t, N> &shape = {},
- AllocFunType alloc = &::malloc) {
- StridedMemRefType<T, 0> *descriptor = static_cast<StridedMemRefType<T, 0> *>(
- alloc(sizeof(StridedMemRefType<T, 0>)));
- descriptor->basePtr = static_cast<T *>(ptr);
- descriptor->data = static_cast<T *>(alignedPtr);
- descriptor->offset = 0;
- return descriptor;
+ AllocFunType allocFun = &::malloc) {
+ return makeStridedMemRefDescriptor<T, N>(ptr, alignedPtr, shape, shape,
+ allocFun);
}
// Mallocs an UnrankedMemRefType<T>* that contains a ranked
@@ -113,9 +143,9 @@
template <typename T, int N>
::UnrankedMemRefType<T> *allocUnrankedDescriptor(
void *data, void *alignedData, const std::array<int64_t, N> &shape,
- AllocFunType alloc = &::malloc) {
+ AllocFunType allocFun = &::malloc) {
::UnrankedMemRefType<T> *res = static_cast<::UnrankedMemRefType<T> *>(
- alloc(sizeof(::UnrankedMemRefType<T>)));
+ allocFun(sizeof(::UnrankedMemRefType<T>)));
res->rank = N;
res->descriptor = makeStridedMemRefDescriptor<T, N>(data, alignedData, shape);
return res;
@@ -157,14 +187,14 @@
// and greater than the size of T. By default the alignment is sizeof(T).
template <typename T>
std::pair<void *, void *> allocAligned(
- size_t nElements, AllocFunType alloc = &::malloc,
+ size_t nElements, AllocFunType allocFun = &::malloc,
llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>()) {
assert(sizeof(T) < (1ul << 32) && "Elemental type overflows");
auto size = nElements * sizeof(T);
auto desiredAlignment = alignment.getValueOr(pow2msb(sizeof(T)));
assert((desiredAlignment & (desiredAlignment - 1)) == 0);
assert(desiredAlignment >= sizeof(T));
- void *data = alloc(size + desiredAlignment);
+ void *data = allocFun(size + desiredAlignment);
uintptr_t addr = reinterpret_cast<uintptr_t>(data);
uintptr_t rem = addr % desiredAlignment;
void *alignedData =
@@ -194,24 +224,48 @@
}
// Entry point to allocate a dense buffer with a given `shape` and initializer
-// of type PointwiseInitializer. Can optionally take specific `alloc` and `free`
+// of type PointwiseInitializer. Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+// Can optionally take specific alloc and free functions.
+//
+// Example:
+// When called with `shape = [128, 127]` and `shapeAlloc = [128, 128]`, this
+// allocates a memref with `128*128*sizeof(T)` bytes, `sizes = [128, 127]` and
+// `strides = [128, 1]`.
+template <typename T, int N, typename FreeFunType = decltype(&::free)>
+std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>
+makeInitializedStridedMemRefDescriptor(
+ const std::array<int64_t, N> &shape,
+ const std::array<int64_t, N> &shapeAlloc, LinearInitializer<T> init,
+ llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
+ AllocFunType allocFun = &::malloc, FreeFunType freeFun = &::free) {
+ for (unsigned i = 0; i < N; ++i)
+ assert(shape[i] <= shapeAlloc[i] &&
+ "shapeAlloc must be greater than or equal to shape");
+ int64_t nElements = 1;
+ for (int64_t s : shapeAlloc) nElements *= s;
+ auto allocated = allocAligned<T>(nElements, allocFun, alignment);
+ auto *data = static_cast<T *>(allocated.first);
+ auto *alignedData = static_cast<T *>(allocated.second);
+ for (unsigned i = 0; i < nElements; ++i) init(i, alignedData);
+ return std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>(
+ detail::makeStridedMemRefDescriptor<T, N>(data, alignedData, shape,
+ shapeAlloc, allocFun),
+ freeFun);
+}
+
+// Entry point to allocate a dense buffer with a given `shape` and initializer
+// of type PointwiseInitializer. Can optionally take specific alloc and free
// functions.
template <typename T, int N, typename FreeFunType = decltype(&::free)>
std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>
makeInitializedStridedMemRefDescriptor(
const std::array<int64_t, N> &shape, LinearInitializer<T> init,
llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
- AllocFunType alloc = &::malloc, FreeFunType freeFun = &::free) {
- int64_t nElements = 1;
- for (int64_t s : shape) nElements *= s;
- auto allocated = allocAligned<T>(nElements, alloc, alignment);
- auto *data = static_cast<T *>(allocated.first);
- auto *alignedData = static_cast<T *>(allocated.second);
- for (unsigned i = 0; i < nElements; ++i) init(i, alignedData);
- return std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>(
- detail::makeStridedMemRefDescriptor<T, N>(data, alignedData, shape,
- alloc),
- freeFun);
+ AllocFunType allocFun = &::malloc, FreeFunType freeFun = &::free) {
+ return makeInitializedStridedMemRefDescriptor<T, N>(
+ shape, shape, init, alignment, allocFun, freeFun);
}
} // namespace mlir
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index a4a9062..706f419 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -235,8 +235,7 @@
func = builder.create<FuncOp>(
module.getLoc(), functionName,
FunctionType::get(SmallVector<Type, 4>(values.getTypes()), returnTypes,
- builder.getContext()),
- ArrayRef<NamedAttribute>{});
+ builder.getContext()));
}
return std_call(builder.getSymbolRefAttr(func), returnTypes, values);
}
diff --git a/experimental/ModelBuilder/VulkanWrapperPass.cpp b/experimental/ModelBuilder/VulkanWrapperPass.cpp
index b20d2fc..c9ba950 100644
--- a/experimental/ModelBuilder/VulkanWrapperPass.cpp
+++ b/experimental/ModelBuilder/VulkanWrapperPass.cpp
@@ -95,10 +95,9 @@
vulkanLaunchTypes.insert(vulkanLaunchTypes.end(), args.begin(), args.end());
// Declare vulkan launch function.
- builder.create<FuncOp>(
- loc, kVulkanLaunch,
- FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{}, loc->getContext()),
- ArrayRef<NamedAttribute>{});
+ builder.create<FuncOp>(loc, kVulkanLaunch,
+ FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{},
+ loc->getContext()));
return success();
}
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index f2c6151..750aaba 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -18,9 +18,15 @@
# pylint: disable=protected-access
# pylint: disable=unsupported-assignment-operation
-import collections
+# This file uses the following abbreviations:
+# ref: reference – for the reference CompiledModule
+# tar: target - for one of the target CompiledModules
+
+import copy
+import inspect
import os
-import re
+import sys
+import tempfile
from absl import flags
from absl import logging
@@ -29,244 +35,329 @@
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
+flags.DEFINE_string("reference_backend", "tf",
+ "The backend to treat as a source of truth.")
flags.DEFINE_string("target_backends", None,
"Explicit comma-delimited list of target backends.")
flags.DEFINE_string(
- "debug_dir", None,
- "Specifies a directory to dump debug artifacts to. Defaults to "
- "--test_tmpdir")
+ "artifacts_dir", None,
+ "Specifies a directory to dump compilation artifacts and traces to. "
+ "Defaults to the OS's tempdir.")
+flags.DEFINE_bool(
+ "summarize", True,
+ "Summarize the inputs and outputs of each module trace logged to disk.")
FLAGS = flags.FLAGS
+NUMPY_LINEWIDTH = 120
-def _setup_test_debug_dir(test_name):
- global global_debug_dir
+def _setup_artifacts_dir(module_name):
+ parent_dir = FLAGS.artifacts_dir
+ if parent_dir is None:
+ parent_dir = os.path.join(tempfile.gettempdir(), "iree", "modules")
+ artifacts_dir = os.path.join(parent_dir, module_name)
+ logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir)
- # Use test_tempdir (which defaults to '/tmp/absl_testing/') if FLAGS.debug_dir
- # is not provided.
- parent = FLAGS.debug_dir if FLAGS.debug_dir is not None else FLAGS.test_tmpdir
- global_debug_dir = os.path.join(parent, test_name)
-
- # Create the directory.
+ # If the artifacts already exist then we overwrite/update them.
try:
- os.makedirs(global_debug_dir)
+ # Use try/except instead of os.path.exists to address a race condition
+    # between multiple test targets.
+ os.makedirs(artifacts_dir)
except IOError:
- logging.exception("Error creating debug dir for: %s", global_debug_dir)
+ pass
+ return artifacts_dir
-class _VirtualModuleInstance(object):
- """Wraps a namedtuple of modules and represents a union of them."""
+def _parse_target_backends(target_backends):
+ """Decodes a comma-delimited string of backends into BackendInfo objects."""
+ backends = []
+ for backend_name in target_backends.split(","):
+ if backend_name not in tf_utils.BackendInfo.ALL.keys():
+ raise ValueError(
+ "Invalid backend specification string '{}', unexpected name '{}';"
+ " valid names are '{}'".format(target_backends, backend_name,
+ tf_utils.BackendInfo.ALL.keys()))
+ backends.append(tf_utils.BackendInfo.ALL[backend_name])
+ return backends
- def __init__(self, named_modules, match_spec):
- self._named_modules = named_modules
- self._match_spec = match_spec
+
+def get_target_backends():
+ """Gets the BackendInfo instances to compare with the reference backend.
+
+ By default all backends in BackendInfo will be used. Specific backends to
+ run on can be specified using the `--target_backends` flag.
+
+ Returns:
+ Sequence of BackendInfo that should be used.
+ """
+ if FLAGS.target_backends is not None:
+ logging.info("Using backends from command line: %s", FLAGS.target_backends)
+ backends = _parse_target_backends(FLAGS.target_backends)
+ else:
+ # If no backends are specified, use them all.
+ backends = list(tf_utils.BackendInfo.ALL.values())
+ return backends
+
+
+def _indent(input_str, indentation=2):
+ """Indents a string by the specified number of spaces, defaulting to 2."""
+ spaces = " " * indentation
+ lines = input_str.split("\n")
+ # Prepend spaces to each non-empty line.
+ lines = [f"{spaces}{line}" if len(line) else line for line in lines]
+ return "\n".join(lines)
+
+
+class ModuleCall:
+
+ def __init__(self, method_name, inputs, outputs, rtol=1e-6, atol=1e-6):
+ """Records the details of a call to a CompiledModule."""
+ self.method = method_name
+
+    # Deepcopy to safeguard against mutation.
+ self.inputs = copy.deepcopy(inputs)
+ if outputs is not None:
+ outputs = copy.deepcopy(outputs)
+ else:
+ outputs = tuple()
+ self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
+
+ self.rtol = rtol
+ self.atol = atol
+
+ def get_tolerances(self):
+ """Gets the floating point tolerances associated with this call."""
+ return self.rtol, self.atol
+
+ def __str__(self):
+ prior_printoptions = np.get_printoptions()
+ np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
+
+ header = f"Method: {self.method}"
+ inputs = "\n".join(_indent(str(value)) for value in self.inputs)
+ outputs = "\n".join(_indent(str(value)) for value in self.outputs)
+ tolerances = _indent(f"rtol={self.rtol}, atol={self.atol}")
+ body = f"Inputs:\n{inputs}\nOutputs:\n{outputs}\nTolerances:\n{tolerances}"
+ result = f"{header}\n{_indent(body)}"
+
+ np.set_printoptions(**prior_printoptions)
+ return result
+
+
+class Trace:
+ """Stores the inputs and outputs of a series of calls to a module."""
+
+ def __init__(self, module, function):
+ """Extracts metadata from module and function and initializes.
+
+ Example usage:
+ def forward_pass(...):
+ ...
+ module = IreeCompiledModule(...)
+ trace = Trace(module, forward_pass)
+ forward_pass(TracedModule(module, trace))
+
+ Args:
+      module: the module whose outputs this trace will record.
+ function: the function that module will be traced on.
+ """
+ # Extract metadata from module and function.
+ self.module_name = module.module_name
+ self.backend = module.backend
+ self.function_name = function.__name__
+ self.function_sourcefile = inspect.getsourcefile(function)
+ source, start_line = inspect.getsourcelines(function)
+ self.function_line_numbers = (start_line, start_line + len(source))
+ self.function_source = "".join(source)
+
+ self.calls = []
+
+ def __str__(self):
+ header = (f"Trace of {self.module_name} compiled to '{self.backend}' "
+ f"on function '{self.function_name}':")
+ # Give each call a number so it's easier to compare between multiple traces.
+ calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
+ calls = _indent("\n".join(calls))
+ return f"{header}\n{calls}"
+
+ def __iter__(self):
+ for call in self.calls:
+ yield call
+
+ @staticmethod
+ def compare_traces(ref_trace, tar_trace):
+ traces_match = True
+
+ # Check that all method invocations match.
+ ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
+ tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
+ if ref_methods != tar_methods:
+ # Raise a ValueError instead of returning False since this is an
+ # unexpected error.
+ raise ValueError(
+ "The reference and target traces have different call structures:\n"
+ f"Reference: {ref_methods}\nTarget: {tar_methods}")
+
+ for ref_call, tar_call in zip(ref_trace, tar_trace):
+ logging.info("Comparing calls to '%s'", ref_call.method)
+ rtol, atol = ref_call.get_tolerances()
+
+ inputs_match = Trace._check_same(ref_call.inputs, tar_call.inputs, rtol,
+ atol)
+ if not inputs_match:
+ logging.error("Inputs did not match.")
+ outputs_match = Trace._check_same(ref_call.outputs, tar_call.outputs,
+ rtol, atol)
+ if not outputs_match:
+ logging.error("Outputs did not match.")
+ calls_match = inputs_match and outputs_match
+
+ if not calls_match:
+ logging.error("Comparision between '%s' and '%s' failed on method '%s'",
+ ref_trace.backend, tar_trace.backend, ref_call.method)
+ logging.error("Reference call '%s':\n%s", ref_trace.backend, ref_call)
+ logging.error("Target call '%s':\n%s", tar_trace.backend, tar_call)
+
+ traces_match = traces_match and calls_match
+ return traces_match
+
+ @staticmethod
+ def _check_same(ref, tar, rtol, atol):
+ """Checks that ref and tar have identical datastructures and values."""
+ # Check for matching types.
+ if not isinstance(tar, type(ref)):
+ logging.error(
+ "Expected ref and tar to have the same type but got '%s' and '%s'",
+ type(ref), type(tar))
+ return False
+
+ if ref is None:
+ # Nothing to compare (e.g. the called method had no outputs).
+ return True
+
+ # Recursive check for dicts.
+ if isinstance(ref, dict):
+ if ref.keys() != tar.keys():
+ logging.error(
+ "Expected ref and tar to have the same keys, but got '%s' and '%s'",
+ ref.keys(), tar.keys())
+ return False
+ # Check that all of the dictionaries' values are the same.
+ for key in ref:
+ if not Trace._check_same(ref[key], tar[key], rtol, atol):
+ return False
+
+ # Recursive check for iterables.
+ elif isinstance(ref, list) or isinstance(ref, tuple):
+ if len(ref) != len(tar):
+ logging.error(
+ "Expected ref and tar to have the same length, but got %s and %s",
+ len(ref), len(tar))
+ return False
+ # Check that all of the iterables' values are the same.
+ for i in range(len(ref)):
+ if not Trace._check_same(ref[i], tar[i], rtol, atol):
+ return False
+
+ # Base check for numpy arrays.
+ elif isinstance(ref, np.ndarray):
+ if ref.dtype != tar.dtype:
+ logging.error(
+ "Expected ref and tar to have the same dtype, but got %s and %s",
+ ref.dtype, tar.dtype)
+ return False
+ if np.issubdtype(ref.dtype, np.floating):
+ return np.allclose(ref, tar, rtol=rtol, atol=atol)
+ else:
+ return np.array_equal(ref, tar)
+
+ # Base check for native number types.
+ elif isinstance(ref, (int, float)):
+ return ref == tar
+
+ # If outputs end up here then an extra branch for that type should be added.
+ else:
+ raise TypeError(f"Encountered results with unexpected type {type(ref)}")
+ return True
+
+ def _get_trace_dir(self, artifacts_dir):
+ trace_dir = os.path.join(artifacts_dir, "traces")
+ if not os.path.exists(trace_dir):
+ os.makedirs(trace_dir)
+ return trace_dir
+
+ def save_plaintext(self, artifacts_dir, summarize=True):
+ """Saves a human-readable string representation of this trace to disk.
+
+ Args:
+ artifacts_dir: the base directory to save the trace in.
+ summarize: a bool controlling whether numpy should summarize the inputs
+ and outputs if they're large. Setting this to False is very slow for
+ large outputs.
+ """
+ prior_printoptions = np.get_printoptions()
+ np.set_printoptions(
+ linewidth=NUMPY_LINEWIDTH,
+ threshold=None if summarize else sys.maxsize,
+ edgeitems=10) # Can show more items since they won't clutter the logs.
+
+ trace_dir = self._get_trace_dir(artifacts_dir)
+ path = os.path.join(trace_dir, f"{self.function_name}__{self.backend}.txt")
+ with open(path, "w") as f:
+ f.write(str(self))
+ f.write("\n")
+
+ np.set_printoptions(**prior_printoptions)
+
+
+class TracedModule:
+
+ def __init__(self, module, trace):
+ """Wraps a CompiledModule so that all inputs and outputs are traced.
+
+ The TracedModule returned will have an API almost identical to that of the
+    passed CompiledModule. The only change is that if the keywords `rtol` or
+ `atol` are passed to one of the CompiledModule's methods, then they will be
+ used to set the tolerance for comparing that call to the same call in
+    another trace. So for example, calling `traced_module.add(a, b, rtol=1e-8)`
+ would be the same as calling `module.add(a, b)`.
+
+ Args:
+ module: the CompiledModule to trace.
+ trace: the Trace to record calls to this module with.
+ """
+ self._module = module
+ self._trace = trace
+
+ def _trace_call(self, method, method_name):
+ """Decorates a CompiledModule method to capture its inputs and outputs."""
+
+ def call(*args, **kwargs):
+ # Pop manually specified tolerances from the kwargs (if any).
+ tolerances = {}
+ tolerances["rtol"] = kwargs.pop("rtol", None)
+ tolerances["atol"] = kwargs.pop("atol", None)
+ # Only pass these to ModuleCall if they were specified by the user.
+ tolerances = {k: v for k, v in tolerances.items() if v is not None}
+
+ # Run the method and record the details of the call.
+ outputs = method(*args, **kwargs)
+ self._trace.calls.append(
+ ModuleCall(method_name, args, outputs, **tolerances))
+ return outputs
+
+ return call
def __getattr__(self, attr):
- match_modules = {
- k: v
- for k, v in self._named_modules.items()
- if re.search(self._match_spec, k)
- }
- if not match_modules:
- raise AttributeError(
- "Module match spec '%s' did not match anything. (Have %r)" %
- (self._match_spec, self._named_modules.keys()))
- # Resolve functions on each.
- match_functions = {}
- for backend, module in match_modules.items():
- try:
- match_functions[backend] = getattr(module, attr)
- except:
- raise AttributeError(
- "Could not resolve function '%s' on backend module '%s'" %
- (attr, backend))
- return _VirtualFunctionWrapper(match_functions)
-
-
-class _VirtualFunctionWrapper(object):
- """Wrapper around a virtual dict of functions."""
-
- def __init__(self, backend_function_dict):
- self._backend_function_dict = backend_function_dict
-
- def __call__(self, *args, **kwargs):
- all_results = {
- backend: f(*args, **kwargs)
- for backend, f in self._backend_function_dict.items()
- }
- # Turn it into a named tuple so we get nice class-like access to it.
- results_tuple_class = collections.namedtuple("Results", all_results.keys())
- return _make_multi_result_class(results_tuple_class)(*all_results.values())
-
-
-def _recursive_check_same(result_ref, result_tgt, rtol=1e-6, atol=1e-6):
- same = True
- if not isinstance(result_tgt, type(result_ref)):
- raise ValueError("Types of the outputs must be the same, but have '{}' and "
- "'{}'".format(type(result_ref), type(result_tgt)))
- if isinstance(result_ref, dict):
- if result_ref.keys() != result_tgt.keys():
- raise ValueError("Outputs must have the same structure, but have '{}' and"
- " '{}'".format(result_ref.keys(), result_tgt.keys()))
- for key in result_ref.keys():
- same = same and _recursive_check_same(result_ref[key], result_tgt[key],
- rtol, atol)
- if not same:
- return False # no need to go further they are different
- elif isinstance(result_ref, list):
- if len(result_ref) != len(result_tgt):
- raise ValueError("Outputs must have the same structure, but have '{}' and"
- " '{}'".format(result_ref, result_tgt))
- for i in range(len(result_ref)):
- same = same and _recursive_check_same(result_ref[i], result_tgt[i], rtol,
- atol)
- if not same:
- return False # no need to go further they are different
- elif isinstance(result_ref, np.ndarray):
- if isinstance(result_ref.flat[0], np.floating):
- return np.allclose(result_ref, result_tgt, rtol=rtol, atol=atol)
+ # Try to resolve it as an attr on self._module.
+ if not hasattr(self._module, attr):
+ raise AttributeError(f"The compiled module does not have attr '{attr}'")
+ module_attr = getattr(self._module, attr)
+ if not hasattr(module_attr, "__call__"):
+ # e.g. traced_module.backend
+ return module_attr
else:
- return np.array_equal(result_ref, result_tgt)
- else:
- # this one need more checks
- return result_ref == result_tgt
- return same
-
-
-def _collect_disagreements_recursively(mr, rtol=1e-6, atol=1e-6):
- """Compare result structs recursively and search for disagreements.
-
- Args:
- mr: A MultiResults namedtuple where each entry corresponds to a backend set
- of results.
- rtol: The relative tolerance parameter.
- atol: The absolute tolerance parameter.
-
- Returns:
- An equivalent MultiResults where each entry is an array of result names
- that disagree.
- """
- has_disagreement = False
- disagreement_list = [list() for _ in mr]
- for i in range(len(mr)):
- result_ref = mr[i]
- for j in range(len(mr)):
- if i < j:
- continue # Don't check self and reverse comparisons
- result_tgt = mr[j]
- if not _recursive_check_same(result_ref, result_tgt, rtol, atol):
- has_disagreement = True
- disagreement_list[i].append(mr._fields[j])
- disagreements_tuple = collections.namedtuple("Disagreements", mr._fields)
- return has_disagreement, disagreements_tuple(*disagreement_list)
-
-
-def _collect_disagreements(mr, predicate):
- """Verifies that result structs.
-
- Args:
- mr: A MultiResults namedtuple where each entry corresponds to a backend set
- of results.
- predicate: A predicate function which takes (a, b) and returns whether they
- should be considered equivalent.
-
- Returns:
- An equivalent MultiResults where each entry is an array of result names
- that disagree.
- """
- has_disagreement = False
- disagreement_list = [list() for _ in mr]
- for i in range(len(mr)):
- result_ref = mr[i]
- for j in range(len(mr)):
- if i == j:
- continue # Don't check self.
- result_tgt = mr[j]
- if not predicate(result_ref, result_tgt):
- has_disagreement = True
- disagreement_list[i].append(mr._fields[j])
- disagreements_tuple = collections.namedtuple("Disagreements", mr._fields)
- return has_disagreement, disagreements_tuple(*disagreement_list)
-
-
-def _make_multi_result_class(named_tuple_class):
- """Makes a class that wraps a mapping of backend results."""
-
- class MultiResults(named_tuple_class):
- """Wraps a mapping of results."""
-
- def assert_all_close(self, rtol=1e-6, atol=1e-6):
- predicate = (lambda a, b: np.allclose(a, b, rtol=rtol, atol=atol))
- has_disagreement, disagreements = _collect_disagreements(self, predicate)
- assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
- (disagreements, self))
- return self
-
- def assert_all_equal(self):
- predicate = np.array_equal
- has_disagreement, disagreements = _collect_disagreements(self, predicate)
- assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
- (disagreements, self))
- return self
-
- def assert_all_close_and_equal(self, rtol=1e-6, atol=1e-6):
- # it is a special case when output can be a nestet map of dict(), list()
- # with different types: int, float, string
- # in this case int and string must be equal and for float we use rtol,atol
- has_disagreement, disagreements = _collect_disagreements_recursively(
- self, rtol, atol)
- assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
- (disagreements, self))
- return self
-
- def print(self):
- print(self)
- return self
-
- def save(self):
- for i in range(len(self)):
- result = self[i] # output generated by a model
- field = self._fields[i] # backend name
- fname = os.path.join(global_debug_dir, "output_{}".format(field))
- with open(fname, "w") as file:
- # content of txt file can be converted to py objects by eval(txt)
- file.write(str(result))
- return self
-
- return MultiResults
-
-
-def _instantiate_backends(compiled_backends):
- """Creates a VirtualBackend namedtuple class for a dict.
-
- Args:
- compiled_backends: Dictionary of backend_name:ModuleInstance.
-
- Returns:
- a VirtualBackendsClass instance. The VirtualBackendsClass is a dynamically
- generated namedtuple mapping backend_name:ModuleInstance, where the
- ModuleInstance allows attribute resolution of public functions on the
- module. The VirtualBackendsClass also contributes some convenience methods
- for selecting all or a subset of matching backend modules.
- """
- tuple_class = collections.namedtuple("VirtualBackendsTuple",
- compiled_backends.keys())
-
- class VirtualBackendsClass(tuple_class):
- """Adds a __call__ method that creates a virtual module."""
-
- def multi(self, match_spec="."):
- """Selects multiple backends that match a regular expression."""
- return _VirtualModuleInstance(self._asdict(), match_spec)
-
- @property
- def all(self):
- """Shorthand for multi() which selects all backends."""
- return self.multi()
-
- reinitialized_modules = [
- module.create_reinitialized() for module in compiled_backends.values()
- ]
- return VirtualBackendsClass(*reinitialized_modules)
+ # e.g. traced_module.simple_mul(a, b)
+ return self._trace_call(module_attr, method_name=attr)
def compile_module(module_class, exported_names=()):
@@ -288,10 +379,10 @@
def decorator(cls):
"""Decorator Function."""
- if not issubclass(cls, CompiledModuleTestCase):
+ if not issubclass(cls, TracedModuleTestCase):
logging.exception(
"The 'compile_module' decorator must be applied to a "
- "CompiledModuleTestCase derived class, which %s is not.", cls)
+ "TracedModuleTestCase derived class, which %s is not.", cls)
cls._module_class = module_class
cls._exported_names = exported_names
return cls
@@ -299,90 +390,108 @@
return decorator
-def _parse_target_backends(target_backends):
- """Decodes a comma-delimited string of backends into BackendInfo objects."""
- backends = []
- for backend_name in target_backends.split(","):
- if backend_name not in tf_utils.BackendInfo.ALL.keys():
- raise ValueError(
- "Invalid backend specification string '{}', unexpected name '{}';"
- " valid names are '{}'".format(target_backends, backend_name,
- tf_utils.BackendInfo.ALL.keys()))
- backends.append(tf_utils.BackendInfo.ALL[backend_name])
- return backends
-
-
-def get_backends():
- """Gets the BackendInfo instances to test.
-
- By default all backends in BackendInfo will be used. Specific backends to
- run on can be specified using the `--target_backends` flag. If only "tf" is
- provided then it will be compared against itself.
-
- Returns:
- Sequence of BackendInfo that should be used.
- """
- if FLAGS.target_backends is not None:
- logging.info("Using backends from command line: %s", FLAGS.target_backends)
- backends = _parse_target_backends(FLAGS.target_backends)
- # If tf is the only backend then we will test it itself by adding tf_also.
- if len(backends) == 1 and "tf" == backends[0].name:
- backends.append(tf_utils.BackendInfo.ALL["tf_also"])
- else:
- # If no backends are specified, use them all.
- backends = list(tf_utils.BackendInfo.ALL.values())
- return backends
-
-
-class CompiledModuleTestCase(tf.test.TestCase):
+class TracedModuleTestCase(tf.test.TestCase):
"""Compiles a tf.Module to multiple backends to test their correctness."""
-
# Will be initialized by the @compile_module decorator.
_module_class = None
_exported_names = ()
- # Will be initialized in setUpClass to a dict of
- # {backend_name: CompiledModule}.
- _compiled_backends_dict = None
+ # Will be initialized in setUpClass.
+ _ref_module = None
+ _tar_modules = None
+
+ @classmethod
+ def _compile(cls, backend_info):
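+    """Compiles cls._module_class to the backend described by backend_info."""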
+ return backend_info.CompiledModule(cls._module_class, backend_info,
+ cls._exported_names, cls._artifacts_dir)
@classmethod
def setUpClass(cls):
+    # Runs before any of the unit tests.
super().setUpClass()
if cls._module_class is None:
raise AttributeError(
"setUpClass was called but no module was specified. Specify a module "
"to compile via the @tf_test_utils.compile_module decorator.")
- # Setup the debug directory for this test. Creates a global variable
- # `global_debug_dir`.
- _setup_test_debug_dir(test_name=cls.__name__)
+ # Setup the directory for saving compilation artifacts and traces.
+ cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
# Setup crash reproducer for the test.
- crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
+ crash_reproducer_path = os.path.join(cls._artifacts_dir, "reproducer.mlir")
compiler.Context.default_crash_reproducer_path = crash_reproducer_path
- # Create a CompiledModule for each backend.
+ # Create a CompiledModule for the reference backend and each target backend.
try:
- backends = get_backends()
- cls._compiled_backends_dict = {}
- for backend_info in backends:
- compiled_backend = backend_info.CompiledModule(cls._module_class,
- backend_info,
- cls._exported_names,
- global_debug_dir)
- cls._compiled_backends_dict[backend_info.name] = compiled_backend
+ ref_backend_info = tf_utils.BackendInfo.ALL[FLAGS.reference_backend]
+ cls._ref_module = cls._compile(ref_backend_info)
+
+ tar_backend_infos = get_target_backends()
+ cls._tar_modules = [
+ cls._compile(backend_info) for backend_info in tar_backend_infos
+ ]
finally:
+      # TODO(meadowlark): Move this into tf_utils.compile_tf_module to prevent
+      # overwriting `reproducer.mlir`.
# Disable crash reproducer (to avoid inadvertently overwriting this
- # path on a subsequent interaction).
+ # path if there are multiple TestCases in the same file).
compiler.Context.default_crash_reproducer_path = None
+ def setUp(self):
+    # Runs before each unit test.
+    super().setUp()
+    self._ref_module = self._ref_module.create_reinitialized()
+ self._tar_modules = [
+ module.create_reinitialized() for module in self._tar_modules
+ ]
+
+ def compare_backends(self, trace_function):
+    """Runs trace_function on the reference and target backends and compares them.
+
+ Random seeds for tensorflow, numpy and python are set before each invocation
+ of trace_function.
+
+ Args:
+ trace_function: a function accepting a TracedModule as its argument.
+ """
+ # Create Traces for each backend.
+ ref_trace = Trace(self._ref_module, trace_function)
+ tar_traces = [Trace(module, trace_function) for module in self._tar_modules]
+
+ # Run the traces through trace_function with their associated modules.
+ tf_utils.set_random_seed()
+ trace_function(TracedModule(self._ref_module, ref_trace))
+ for module, trace in zip(self._tar_modules, tar_traces):
+ tf_utils.set_random_seed()
+ trace_function(TracedModule(module, trace))
+
+ # Compare each target trace of trace_function with the reference trace.
+ failed_backend_indices = []
+ for i, tar_trace in enumerate(tar_traces):
+ logging.info("Comparing the reference backend '%s' with '%s'",
+ ref_trace.backend, tar_trace.backend)
+ traces_match = Trace.compare_traces(ref_trace, tar_trace)
+ if not traces_match:
+ failed_backend_indices.append(i)
+
+ # Save the results to disk before validating.
+ ref_trace.save_plaintext(self._artifacts_dir, FLAGS.summarize)
+ for tar_trace in tar_traces:
+ tar_trace.save_plaintext(self._artifacts_dir, FLAGS.summarize)
+
+ # Validate results.
+ if failed_backend_indices:
+ # Extract info for logging.
+ failed_backends = [tar_traces[i].backend for i in failed_backend_indices]
+ failure_info = (
+          "Comparison between the reference backend and the following targets "
+          f"failed: {failed_backends}. The errors above show the inputs and "
+          "outputs of the non-matching calls.")
+
+      # This assertion always fails here; it surfaces failure_info in the logs.
+ self.assertEmpty(failed_backends, failure_info)
+
@classmethod
def tearDownClass(cls):
+    # Runs after all unit tests are completed.
super().tearDownClass()
-
- def setUp(self):
- super().setUp()
- self.compiled_modules = _instantiate_backends(self._compiled_backends_dict)
-
- def get_module(self):
- return self.compiled_modules.all
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 20ba522..f21521a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -17,9 +17,32 @@
from absl.testing import parameterized
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow as tf
+class StatefulCountingModule(tf.Module):
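+  """A tf.Module with a mutable counter, used to exercise stateful tracing."""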
+
+ def __init__(self):
+ self.count = tf.Variable([0.])
+
+ @tf.function(input_signature=[])
+ def increment(self):
+ self.count.assign_add(tf.constant([1.]))
+
+ @tf.function(input_signature=[])
+ def get_count(self):
+ return self.count
+
+ @tf.function(input_signature=[tf.TensorSpec([1])])
+ def increment_by(self, value):
+ self.count.assign_add(value)
+
+ @tf.function(input_signature=[])
+ def decrement(self):
+ self.count.assign_sub(tf.constant([1.]))
+
+
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters([
@@ -28,31 +51,31 @@
'array_c': np.array([0, 1, 2]),
'array_d': np.array(['0', '1', '2']),
'array_e': np.array([0.0, 0.1, 0.2]),
- 'tgt_same': True,
+ 'tar_same': True,
},
{
'testcase_name': 'wrong int',
'array_c': np.array([1, 1, 2]),
'array_d': np.array(['0', '1', '2']),
'array_e': np.array([0.0, 0.1, 0.2]),
- 'tgt_same': False,
+ 'tar_same': False,
},
{
'testcase_name': 'wrong string',
'array_c': np.array([0, 1, 2]),
'array_d': np.array(['a', '1', '2']),
'array_e': np.array([0.0, 0.1, 0.2]),
- 'tgt_same': False,
+ 'tar_same': False,
},
{
'testcase_name': 'wrong float',
'array_c': np.array([0, 1, 2]),
'array_d': np.array(['0', '1', '2']),
'array_e': np.array([1.0, 0.1, 0.2]),
- 'tgt_same': False,
+ 'tar_same': False,
},
])
- def test_recursive_check_same(self, array_c, array_d, array_e, tgt_same):
+ def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
ref = {
'a':
@@ -65,7 +88,7 @@
'e': np.array([0.0, 0.1, 0.2])
}],
}
- tgt = {
+ tar = {
'a': 1,
'b': [{
'c': array_c
@@ -75,8 +98,74 @@
'e': array_e
}],
}
- same = tf_test_utils._recursive_check_same(ref, tgt)
- self.assertEqual(tgt_same, same)
+ same = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
+ self.assertEqual(tar_same, same)
+
+ def test_trace_inputs_and_outputs(self):
+
+ def trace_function(module):
+      # No inputs or outputs
+ module.increment()
+ # Only inputs
+ module.increment_by(np.array([81.], dtype=np.float32))
+ # Only outputs
+ module.get_count()
+
+ module = tf_utils.TfCompiledModule(StatefulCountingModule,
+ tf_utils.BackendInfo.ALL['tf'])
+ trace = tf_test_utils.Trace(module, trace_function)
+ trace_function(tf_test_utils.TracedModule(module, trace))
+
+ self.assertIsInstance(trace.calls[0].inputs, tuple)
+ self.assertEmpty(trace.calls[0].inputs)
+ self.assertIsInstance(trace.calls[0].outputs, tuple)
+ self.assertEmpty(trace.calls[0].outputs)
+
+ self.assertAllClose(trace.calls[1].inputs[0], [81.])
+ self.assertAllClose(trace.calls[2].outputs[0], [82.])
+
+ def test_nonmatching_methods(self):
+
+ def tf_function(module):
+ module.increment()
+ module.increment()
+
+ def vmla_function(module):
+ module.increment()
+ module.decrement()
+
+ tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
+ tf_utils.BackendInfo.ALL['tf'])
+ tf_trace = tf_test_utils.Trace(tf_module, tf_function)
+ tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
+
+ vmla_module = tf_utils.IreeCompiledModule(
+ StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+ vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
+ vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
+
+ with self.assertRaises(ValueError):
+ tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
+
+ def test_nonmatching_inputs(self):
+
+ def tf_function(module):
+ module.increment_by(np.array([42.], dtype=np.float32))
+
+ def vmla_function(module):
+ module.increment_by(np.array([22.], dtype=np.float32))
+
+ tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
+ tf_utils.BackendInfo.ALL['tf'])
+ tf_trace = tf_test_utils.Trace(tf_module, tf_function)
+ tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
+
+ vmla_module = tf_utils.IreeCompiledModule(
+ StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+ vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
+ vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
+
+ self.assertFalse(tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace))
if __name__ == '__main__':
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 46a3785..f67119e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -29,6 +29,8 @@
from pyiree.tf import compiler
import tensorflow.compat.v2 as tf
+flags.DEFINE_bool("keep_saved_model", False,
+ "Keep the SavedModel used by compile_tf_module on disk.")
FLAGS = flags.FLAGS
@@ -39,6 +41,14 @@
np.random.seed(seed)
+def uniform(shape, dtype=np.float32):
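+  """Returns a random ndarray of the given shape, uniform on [0, 1)."""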
+ return np.random.uniform(size=shape).astype(dtype)
+
+
+def ndarange(shape, dtype=np.float32):
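+  """Returns an ndarray of the given shape, filled with 0, 1, 2, ... in order."""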
+ return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
+
+
def backends_to_str(target_backends):
"""Creates a flattened and normalized string representing target_backends."""
normalized_backends = []
@@ -49,6 +59,36 @@
return "__".join(normalized_backends)
+def to_mlir_type(dtype):
+ """Returns a string that denotes the type `dtype` in MLIR style."""
+ bits = dtype.itemsize * 8
+ if np.issubdtype(dtype, np.integer):
+ return f"i{bits}"
+ elif np.issubdtype(dtype, np.floating):
+ return f"f{bits}"
+ else:
+ raise TypeError(f"Expected integer or floating type, but got {dtype}")
+
+
+def save_input_values(inputs, artifacts_dir=None):
+ """Saves input values with IREE tools format if `artifacts_dir` is set."""
+ result = []
+ for array in inputs:
+ shape = [str(dim) for dim in list(array.shape)]
+ shape.append(to_mlir_type(array.dtype))
+ shape = "x".join(shape)
+ values = " ".join([str(x) for x in array.flatten()])
+ result.append(f"{shape}={values}")
+ result = "\n".join(result)
+ if artifacts_dir is not None:
+ inputs_path = os.path.join(artifacts_dir, "inputs.txt")
+ logging.info("Saving IREE input values to: %s", inputs_path)
+ with open(inputs_path, "w") as f:
+ f.write(result)
+ f.write("\n")
+ return result
+
+
def compile_tf_module(tf_module,
target_backends=(),
exported_names=(),
@@ -120,9 +160,14 @@
return compiled_module
options = tf.saved_model.SaveOptions(save_debug_info=True)
- if artifacts_dir is not None:
+ if artifacts_dir is not None and FLAGS.keep_saved_model:
# Save the saved model alongside the other compilation artifacts.
- sm_path = os.path.join(artifacts_dir, "saved_model")
+
+ # Create a saved model for these target backends to avoid a race condition
+ # when running a test suite.
+ # TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
+ sm_path = os.path.join(artifacts_dir,
+ f"saved_model__{backends_to_str(target_backends)}")
tf.saved_model.save(tf_module, sm_path, options=options)
return _compile_from_path(sm_path)
else:
@@ -142,6 +187,10 @@
self._exported_names = exported_names
self._artifacts_dir = artifacts_dir
+ # Public attributes:
+ self.backend = self._backend_info.name
+ self.module_name = self._module_class.__name__
+
def create_reinitialized(self):
"""Duplicates this module with its initial state without recompiling."""
raise NotImplementedError()
@@ -153,7 +202,7 @@
def __init__(self,
module_class,
backend_info,
- exported_names=[],
+ exported_names=(),
artifacts_dir=None,
_create_reinitialized_args=None):
"""Compile a tf.Module to the target backend in backend_info.
@@ -166,10 +215,12 @@
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
artifacts_dir: an optional path to save compilation artifacts to.
+ _create_reinitialized_args: used internally.
"""
super().__init__(module_class, backend_info, exported_names, artifacts_dir)
if _create_reinitialized_args is None:
+ set_random_seed()
self._module_blob = compile_tf_module(
tf_module=module_class(),
target_backends=backend_info.iree_compiler_targets,
@@ -222,7 +273,7 @@
def __init__(self,
module_class,
backend_info,
- exported_names=[],
+ exported_names=(),
artifacts_dir=None):
"""Wrap a tf.Module in a TFCompiledModule facade.
@@ -236,6 +287,7 @@
effect for this subclass as nothing is compiled.
"""
super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ set_random_seed()
self._tf_module = module_class()
def create_reinitialized(self):
@@ -245,7 +297,7 @@
def __getattr__(self, attr):
# Try to resolve it as a function.
- exported = len(self._exported_names) == 0 or attr in self._exported_names
+ exported = not self._exported_names or attr in self._exported_names
if not hasattr(self._tf_module, attr) or not exported:
raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
f = getattr(self._tf_module, attr)
@@ -261,6 +313,13 @@
def __init__(self, f):
self._f = f
+ def _convert_to_numpy(self, tensor):
+ result = tensor.numpy()
+ if np.isscalar(result):
+ # convert_to_tensor isn't reversible via .numpy()
+ result = np.array(result)
+ return result
+
def __call__(self, *args, **kwargs):
# TensorFlow will auto-convert all inbound args.
results = self._f(*args, **kwargs)
@@ -270,7 +329,7 @@
if not isinstance(results, tuple):
results = (results,)
return tf.nest.map_structure(
- lambda t: t.numpy() if isinstance(t, tf.Tensor) else t,
+ lambda t: self._convert_to_numpy(t) if isinstance(t, tf.Tensor) else t,
*results,
check_types=False)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index b1d9adb..cde0a08 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -21,6 +21,7 @@
from absl.testing import parameterized
from pyiree.tf.support import tf_utils
import tensorflow as tf
+import numpy as np
class ConstantModule(tf.Module):
@@ -65,7 +66,6 @@
artifacts_dir=artifacts_dir)
artifacts_to_check = [
- 'saved_model',
'tf_input.mlir',
'iree_input.mlir',
f'compiled__{tf_utils.backends_to_str(target_backends)}.vmfb',
@@ -100,6 +100,18 @@
# Test independent state.
self.assertEqual([1.], module.get_count())
+ def test_to_mlir_type(self):
+ self.assertEqual('i8', tf_utils.to_mlir_type(np.dtype('int8')))
+ self.assertEqual('i32', tf_utils.to_mlir_type(np.dtype('int32')))
+ self.assertEqual('f32', tf_utils.to_mlir_type(np.dtype('float32')))
+ self.assertEqual('f64', tf_utils.to_mlir_type(np.dtype('float64')))
+
+ def test_save_input_values(self):
+ inputs = [np.array([1, 2], dtype=np.int32)]
+ self.assertEqual('2xi32=1 2', tf_utils.save_input_values(inputs))
+ inputs = [np.array([1, 2], dtype=np.float32)]
+ self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
+
if __name__ == '__main__':
tf.test.main()
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index d63b839..a3346e4 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -23,7 +23,6 @@
"INTREE_TENSORFLOW_PY_DEPS",
"NUMPY_DEPS",
"iree_py_binary",
- "iree_py_test",
)
load(
"//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
@@ -54,7 +53,6 @@
# backends.
# keep sorted
SPECIAL_CASES = [
- "explicit_backend_test.py",
"linspace_test.py",
]
@@ -84,7 +82,6 @@
VULKAN_FAILING = [
"broadcasting_test.py",
"control_flow_test.py",
- "conv_test.py",
"depth_conv_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
@@ -93,7 +90,7 @@
"matrix_ops_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
- "sliding_window_test.py", # TODO(#2659)
+ "sliding_window_test.py", # TODO(#2659) Failing on nvidia, passing on swiftshader.
"strings_test.py",
]
@@ -187,20 +184,3 @@
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
)
-
-# This tests explicitly writing which backends to use in Python,
-# so overriding the backends can cause it to break.
-iree_py_test(
- name = "explicit_backend_test",
- srcs = ["explicit_backend_test.py"],
- main = "explicit_backend_test.py",
- python_version = "PY3",
- tags = [
- "driver=llvmjit",
- "driver=vmla",
- "driver=vulkan",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 04604d8..a213bda 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -21,7 +21,8 @@
`iree_vulkan` from the list of backends to run the tests on).
The test suites can be run excluding Vulkan by specifying
-`--test_tag_filters="-driver=vulkan"` in the `bazel test` invocation.
+`--test_tag_filters="-driver=vulkan"` in the `bazel test` invocation, or by
+adding `test --test_tag_filters="-driver=vulkan"` to your `user.bazelrc`.
## Compiling `tf.Module`s
@@ -44,6 +45,9 @@
vmla_module.predict(...)
```
+By default the TensorFlow SavedModels will not be kept. This can be overridden
+via the `--keep_saved_model` flag.
+
## Running tests
For locally running tests and iterating on backend development, `bazel run` is
@@ -53,11 +57,15 @@
# Run math_test on all backends.
bazel run :math_test_manual
-# Run math_test on the VMLA backend only.
+# Run math_test comparing TensorFlow to itself (e.g. to debug randomization).
+bazel run :math_test_manual -- --target_backends=tf
+
+# Run math_test comparing the VMLA backend and TensorFlow.
bazel run :math_test_manual -- --target_backends=iree_vmla
-# Same as above, but add `tf` backend to cross-check numerical correctness.
-bazel run :math_test_manual -- --target_backends=tf,iree_vmla
+# Run math_test comparing the VMLA backend to itself multiple times.
+bazel run :math_test_manual -- \
+ --reference_backend=iree_vmla --target_backends=iree_vmla,iree_vmla
# Run math_test and output on failure.
bazel test :math_test_manual --test_output=errors
@@ -66,14 +74,43 @@
bazel run :math_test_manual -- --test_output=streamed
```
-If you specify the same backend multiple times, for example
-`--target_backends=iree_vmla,iree_vmla`. The same backends are grouped and in
-this example `iree_vmla` will run once. If you specify `tf,iree_vmla` as
-backends, then we will test both backends and compare them with each other. If
-you specify `tf` backend only, then we will also test `tf` vs `tf` to capture
-any model initialization/randomization issues (it is a special case for debug
-purpose). For reproducibility of the unit tests we set random seed of `tf` and
-`numpy` by calling `tf_utils.set_random_seed()` before model creation.
+For reproducibility of the unit tests `CompiledModule()` sets the random seeds
+of `tf`, `numpy` and `python` by calling `tf_utils.set_random_seed()` before
+model creation.
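+
+A minimal sketch of what such a seeding helper could look like (hypothetical;
+the exact seed value and calls used by `tf_utils.set_random_seed` may differ):
+
+```python
+import random
+
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+
+def set_random_seed(seed=0):
+  """Seeds the Python, numpy and TensorFlow RNGs for reproducibility."""
+  random.seed(seed)
+  np.random.seed(seed)
+  tf.random.set_seed(seed)
+```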
+
+## Writing Tests
+
+Our tests use a class `TracedModule` to capture and store all of the inputs and
+outputs of a `CompiledModule` in a `Trace`. Each unit test on a `TestCase` uses
+the `compare_backends` method. This method runs the function it is passed with a
+`TracedModule` once for the reference backend and once for each target backend.
+The inputs and outputs recorded in these traces are then checked for
+correctness, using the reference backend as a source of truth. For example:
+
+```python
+# Compile a `tf.Module` named `SimpleArithmeticModule` into a `CompiledModule`.
+@tf_test_utils.compile_module(SimpleArithmeticModule)
+# Inherit from `TracedModuleTestCase`.
+class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
+
+ # Unit test.
+ def test_simple_mul(self):
+
+ # Trace function.
+ def simple_mul(module):
+ # A random seed is automatically set before each call to `simple_mul`.
+ a = tf_utils.uniform([4])
+ b = np.array([400., 5., 6., 7.], dtype=np.float32)
+      # The inputs `a` and `b` are recorded along with the output `c`.
+      c = module.simple_mul(a, b)
+      # The inputs `a` and `c` are recorded along with the (unnamed) output
+      # that `module.simple_mul` returns.
+ module.simple_mul(a, c)
+
+ # Calls `simple_mul` once for each backend, recording the inputs and outputs
+ # to `module` and then comparing them.
+ self.compare_backends(simple_mul)
+```
## Test Suites
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index 75de16d..e16436d 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -16,6 +16,7 @@
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -39,18 +40,20 @@
@tf_test_utils.compile_module(BatchNormModule)
-class BatchNormTest(tf_test_utils.CompiledModuleTestCase):
+class BatchNormTest(tf_test_utils.TracedModuleTestCase):
def test_batch_norm_inference(self):
- np.random.seed(12345)
- # Note: scaling by a small value to increase numerical stability.
- x = np.random.random((4, 16)).astype(np.float32) * 1e-3
- mean = np.random.random((16,)).astype(np.float32) * 1e-3
- variance = np.random.random((16,)).astype(np.float32) * 1e-3
- offset = np.random.random((16,)).astype(np.float32) * 1e-3
- scale = np.random.random((16,)).astype(np.float32) * 1e-3
- r = self.get_module().batch_norm_inference(x, mean, variance, offset, scale)
- r.print().assert_all_close()
+
+ def batch_norm_inference(module):
+ # Note: scaling by a small value to increase numerical stability.
+ x = tf_utils.uniform((4, 16)) * 1e-3
+ mean = tf_utils.uniform((16,)) * 1e-3
+ variance = tf_utils.uniform((16,)) * 1e-3
+ offset = tf_utils.uniform((16,)) * 1e-3
+ scale = tf_utils.uniform((16,)) * 1e-3
+ module.batch_norm_inference(x, mean, variance, offset, scale)
+
+ self.compare_backends(batch_norm_inference)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index 74880bd..c4f8e38 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -14,7 +14,9 @@
# limitations under the License.
"""Test broadcasting support."""
+import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -29,24 +31,35 @@
@tf_test_utils.compile_module(BroadcastingModule)
-class BroadcastingTest(tf_test_utils.CompiledModuleTestCase):
+class BroadcastingTest(tf_test_utils.TracedModuleTestCase):
def test_add_same_shape(self):
- m = self.get_module()
- dst = m.add(tf.random.uniform([4]), tf.random.uniform([4]))
- dst.print().assert_all_close()
+ def add_same_shape(module):
+ lhs = tf_utils.uniform([4])
+ rhs = tf_utils.uniform([4])
+ module.add(lhs, rhs)
-# TODO(silvasean): Make these work.
-# def test_add_broadcast_lhs(self):
-# m = self.get_module()
-# dst = m.add(tf.random.uniform([1]), tf.random.uniform([4]))
-# dst.print().assert_all_close()
-#
-# def test_add_broadcast_rhs(self):
-# m = self.get_module()
-# dst = m.add(tf.random.uniform([4]), tf.random.uniform([1]))
-# dst.print().assert_all_close()
+ self.compare_backends(add_same_shape)
+
+ def test_add_broadcast_lhs(self):
+
+ def add_broadcast_lhs(module):
+ lhs = tf_utils.uniform([1])
+ rhs = tf_utils.uniform([4])
+ module.add(lhs, rhs)
+
+ self.compare_backends(add_broadcast_lhs)
+
+ def test_add_broadcast_rhs(self):
+
+ def add_broadcast_rhs(module):
+ lhs = tf_utils.uniform([4])
+ rhs = tf_utils.uniform([1])
+ module.add(lhs, rhs)
+
+ self.compare_backends(add_broadcast_rhs)
+
if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
new file mode 100644
index 0000000..102396b
--- /dev/null
+++ b/integrations/tensorflow/e2e/complex_test.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class ComplexModule(tf.Module):
+
+ def __init__(self):
+ pass
+
+ @tf.function(input_signature=[
+ tf.TensorSpec([2], tf.float32),
+ tf.TensorSpec([2], tf.float32)
+ ])
+ def complex_exp(self, real, imag):
+ tensor = tf.complex(real, imag)
+ exp = tf.exp(tensor)
+ return tf.math.real(exp)
+
+
+@tf_test_utils.compile_module(ComplexModule)
+class ComplexTest(tf_test_utils.TracedModuleTestCase):
+
+ def test_complex(self):
+
+ def complex_exp(module):
+ real = np.array([2., 3.], dtype=np.float32)
+ imag = np.array([-1., 0.4], dtype=np.float32)
+ module.complex_exp(real, imag)
+
+ self.compare_backends(complex_exp)
+
+
+if __name__ == "__main__":
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+ tf.test.main()
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index b7a348c..187c6cb 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test concat op."""
+import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -51,39 +52,43 @@
@tf_test_utils.compile_module(ConcatOpsModule)
-class ConcatOpsTest(tf_test_utils.CompiledModuleTestCase):
+class ConcatOpsTest(tf_test_utils.TracedModuleTestCase):
def test_concat_zero_dim(self):
- tf_utils.set_random_seed()
- m = self.get_module()
- a = tf.random.uniform([1, 5, 0], dtype=tf.float32)
- b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- dst = m.concat_zero_dim(a, b)
- dst.assert_all_close()
- def concat0axis(self):
- tf_utils.set_random_seed()
- m = self.get_module()
- a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- dst = m.concat_zero_dim(a, b)
- dst.assert_all_close()
+ def concat_zero_dim(module):
+ a = tf_utils.uniform([1, 5, 0])
+ b = tf_utils.uniform([1, 5, 1])
+ module.concat_zero_dim(a, b)
- def concat1axis(self):
- tf_utils.set_random_seed()
- m = self.get_module()
- a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- dst = m.concat_zero_dim(a, b)
- dst.assert_all_close()
+ self.compare_backends(concat_zero_dim)
- def concat2axis(self):
- tf_utils.set_random_seed()
- m = self.get_module()
- a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
- dst = m.concat_zero_dim(a, b)
- dst.assert_all_close()
+ def test_concat0axis(self):
+
+ def concat0axis(module):
+ a = tf_utils.uniform([1, 5, 1])
+ b = tf_utils.uniform([1, 5, 1])
+ module.concat0axis(a, b)
+
+ self.compare_backends(concat0axis)
+
+ def test_concat1axis(self):
+
+ def concat1axis(module):
+ a = tf_utils.uniform([1, 5, 1])
+ b = tf_utils.uniform([1, 5, 1])
+ module.concat1axis(a, b)
+
+ self.compare_backends(concat1axis)
+
+ def test_concat2axis(self):
+
+ def concat2axis(module):
+ a = tf_utils.uniform([1, 5, 1])
+ b = tf_utils.uniform([1, 5, 1])
+ module.concat2axis(a, b)
+
+ self.compare_backends(concat2axis)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index 0c25fd6..bc0a328 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import numpy
+import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -35,17 +35,23 @@
@tf_test_utils.compile_module(ControlFlowModule)
-class ControlFlowTest(tf_test_utils.CompiledModuleTestCase):
+class ControlFlowTest(tf_test_utils.TracedModuleTestCase):
def test_short_sequence(self):
- input_array = numpy.array(9., dtype=numpy.float32)
- result = self.get_module().collatz(input_array)
- result.print().assert_all_close()
+
+ def short_sequence(module):
+ input_array = np.array(9., dtype=np.float32)
+ module.collatz(input_array)
+
+ self.compare_backends(short_sequence)
def test_long_sequence(self):
- input_array = numpy.array(178., dtype=numpy.float32)
- result = self.get_module().collatz(input_array)
- result.print().assert_all_close()
+
+ def long_sequence(module):
+ input_array = np.array(178., dtype=np.float32)
+ module.collatz(input_array)
+
+ self.compare_backends(long_sequence)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 61c46cf..9346a52 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -15,6 +15,7 @@
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -99,73 +100,109 @@
@tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.CompiledModuleTestCase):
+class ConvTest(tf_test_utils.TracedModuleTestCase):
def test_id_batch_size_1(self):
- i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
- k = np.ones([1, 1, 1, 1], dtype=np.float32)
- r = self.get_module().conv2d_1451x1111_valid(i, k)
- r.print().assert_all_close()
+
+ def id_batch_size_1(module):
+ i = tf_utils.ndarange([1, 4, 5, 1])
+ k = np.ones([1, 1, 1, 1], dtype=np.float32)
+ module.conv2d_1451x1111_valid(i, k)
+
+ self.compare_backends(id_batch_size_1)
def test_id_batch_size_2(self):
- i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
- k = np.ones([1, 1, 1, 1], dtype=np.float32)
- r = self.get_module().conv2d_2451x1111_valid(i, k)
- r.print().assert_all_close()
- def test_asym_kernel(self):
- i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
- k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
- r = self.get_module().conv2d_1451x2311_valid(i, k)
- r.print().assert_all_close()
+ def id_batch_size_2(module):
+ i = tf_utils.ndarange([2, 4, 5, 1])
+ k = np.ones([1, 1, 1, 1], dtype=np.float32)
+ module.conv2d_2451x1111_valid(i, k)
+
+ self.compare_backends(id_batch_size_2)
+
+ def test_asymmetric_kernel(self):
+
+ def asymmetric_kernel(module):
+ i = tf_utils.ndarange([1, 4, 5, 1])
+ k = np.array([[1, 4, 2], [-2, 0, 1]],
+ dtype=np.float32).reshape(2, 3, 1, 1)
+ module.conv2d_1451x2311_valid(i, k)
+
+ self.compare_backends(asymmetric_kernel)
def test_padding(self):
- i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
- k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
- r = self.get_module().conv2d_1451x2311_same(i, k)
- r.print().assert_all_close()
+
+ def padding(module):
+ i = tf_utils.ndarange([1, 4, 5, 1])
+ k = np.array([[1, 4, 2], [-2, 0, 1]],
+ dtype=np.float32).reshape(2, 3, 1, 1)
+ module.conv2d_1451x2311_same(i, k)
+
+ self.compare_backends(padding)
def test_batched_padding(self):
- i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
- k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
- r = self.get_module().conv2d_2451x2311_same(i, k)
- r.print().assert_all_close()
+
+ def batched_padding(module):
+ i = tf_utils.ndarange([2, 4, 5, 1])
+ k = np.array([[1, 4, 2], [-2, 0, 1]],
+ dtype=np.float32).reshape(2, 3, 1, 1)
+ module.conv2d_2451x2311_same(i, k)
+
+ self.compare_backends(batched_padding)
def test_feature_reduce(self):
- i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
- k = np.ones([3, 2, 2, 1], dtype=np.float32)
- r = self.get_module().conv2d_1452x3221_same(i, k)
- r.print().assert_all_close()
+
+ def feature_reduce(module):
+ i = tf_utils.ndarange([1, 4, 5, 2])
+ k = np.ones([3, 2, 2, 1], dtype=np.float32)
+ module.conv2d_1452x3221_same(i, k)
+
+ self.compare_backends(feature_reduce)
def test_feature_inflate(self):
- i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
- k = np.arange(2, dtype=np.float32).reshape([1, 1, 1, 2])
- r = self.get_module().conv2d_1451x1112_same(i, k)
- r.print().assert_all_close()
+
+ def feature_inflate(module):
+ i = tf_utils.ndarange([1, 4, 5, 1])
+ k = tf_utils.ndarange([1, 1, 1, 2])
+ module.conv2d_1451x1112_same(i, k)
+
+ self.compare_backends(feature_inflate)
def test_feature_mix(self):
- i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
- k = np.arange(4, dtype=np.float32).reshape([1, 1, 2, 2])
- r = self.get_module().conv2d_1452x1122_same(i, k)
- r.print().assert_all_close()
+
+ def feature_mix(module):
+ i = tf_utils.ndarange([1, 4, 5, 2])
+ k = tf_utils.ndarange([1, 1, 2, 2])
+ module.conv2d_1452x1122_same(i, k)
+
+ self.compare_backends(feature_mix)
def test_feature_padded(self):
- i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
- k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
- r = self.get_module().conv2d_1452x2223_same(i, k)
- r.print().assert_all_close()
+
+ def feature_padded(module):
+ i = tf_utils.ndarange([1, 4, 5, 2])
+ k = tf_utils.ndarange([2, 2, 2, 3])
+ module.conv2d_1452x2223_same(i, k)
+
+ self.compare_backends(feature_padded)
def test_feature_unpadded(self):
- i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
- k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
- r = self.get_module().conv2d_1452x2223_valid(i, k)
- r.print().assert_all_close()
+
+ def feature_unpadded(module):
+ i = tf_utils.ndarange([1, 4, 5, 2])
+ k = tf_utils.ndarange([2, 2, 2, 3])
+ module.conv2d_1452x2223_valid(i, k)
+
+ self.compare_backends(feature_unpadded)
def test_batched_feature_unpadded(self):
- i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
- k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
- r = self.get_module().conv2d_2452x2223_valid(i, k)
- r.print().assert_all_close()
+
+ def batched_feature_unpadded(module):
+ i = tf_utils.ndarange([2, 4, 5, 2])
+ k = tf_utils.ndarange([2, 2, 2, 3])
+ module.conv2d_2452x2223_valid(i, k)
+
+ self.compare_backends(batched_feature_unpadded)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 1e8a002..9d88bee 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -15,6 +15,7 @@
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -39,19 +40,25 @@
@tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.CompiledModuleTestCase):
+class ConvTest(tf_test_utils.TracedModuleTestCase):
def test_batched_feature_unpadded(self):
- i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
- k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
- r = self.get_module().conv2d_2452x2223_valid(i, k)
- r.print().assert_all_close()
- def test_batched_feature_unpadded_smae(self):
- i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
- k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
- r = self.get_module().conv2d_2452x2223_same(i, k)
- r.print().assert_all_close()
+ def batched_feature_unpadded(module):
+ i = tf_utils.ndarange([2, 4, 5, 2])
+ k = tf_utils.ndarange([2, 2, 2, 3])
+ module.conv2d_2452x2223_valid(i, k)
+
+ self.compare_backends(batched_feature_unpadded)
+
+ def test_batched_feature_unpadded_same(self):
+
+ def batched_feature_unpadded_same(module):
+ i = tf_utils.ndarange([2, 4, 5, 2])
+ k = tf_utils.ndarange([2, 4, 2, 3])
+ module.conv2d_2452x2223_same(i, k)
+
+ self.compare_backends(batched_feature_unpadded_same)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 64c51e9..eec6815 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -28,7 +28,7 @@
CLASSES = 10
-class Mlp(tf.Module):
+class MlpRelu(tf.Module):
def __init__(self,
hidden_1_dim=256,
@@ -65,14 +65,16 @@
return tf.nn.softmax(self.mlp(x))
-@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
+@tf_test_utils.compile_module(MlpRelu, exported_names=["predict"])
+class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase):
def test_dynamic_batch(self):
- m = self.get_module()
- np.random.seed(12345)
- x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
- m.predict(x).print().assert_all_close()
+
+ def dynamic_batch(module):
+ x = tf_utils.uniform([3, 28 * 28]) * 1e-3
+ module.predict(x)
+
+ self.compare_backends(dynamic_batch)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 72d7f1f..0b70e84 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -62,13 +62,15 @@
@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
+class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
def test_dynamic_batch(self):
- m = self.get_module()
- np.random.seed(12345)
- x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
- m.predict(x).print().assert_all_close()
+
+ def dynamic_batch(module):
+ x = tf_utils.uniform([3, 28 * 28]) * 1e-3
+ module.predict(x)
+
+ self.compare_backends(dynamic_batch)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/explicit_backend_test.py b/integrations/tensorflow/e2e/explicit_backend_test.py
deleted file mode 100644
index bdcdd79..0000000
--- a/integrations/tensorflow/e2e/explicit_backend_test.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests explicitly specifying a backend in Python."""
-
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class SimpleArithmeticModule(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.float32),
- tf.TensorSpec([4], tf.float32)
- ])
- def simple_mul(self, a, b):
- return a * b
-
-
-@tf_test_utils.compile_module(SimpleArithmeticModule)
-class ExplicitBackendTest(tf_test_utils.CompiledModuleTestCase):
-
- def test_explicit(self):
- a = np.array([1., 2., 3., 4.], dtype=np.float32)
- b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
- # Demonstrates simple, one by one invocation of functions against
- # different explicit backends. Individual backends can be accessed off of
- # the module by name ('tf', 'iree_vmla' below).
- tf_c = self.compiled_modules.tf.simple_mul(a, b)
- print("TF Result:", tf_c)
- iree_c = self.compiled_modules.iree_vmla.simple_mul(a, b)
- print("IREE Result:", iree_c)
- self.assertAllClose(tf_c, iree_c)
-
- def test_multi(self):
- a = np.array([1., 2., 3., 4.], dtype=np.float32)
- b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
- # Evaluating against multiple backends can be done with the multi() method,
- # which takes a regex string matching backend names. This also returns a
- # MultiResults tuple with actual results keyed by backend name. These also
- # have convenience methods like print() and assert_all_close().
- vmod = self.compiled_modules.multi("tf|iree")
- r = vmod.simple_mul(a, b)
- r.print().assert_all_close()
-
- def test_get_module(self):
- a = np.array([1., 2., 3., 4.], dtype=np.float32)
- b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
- # Evaluating against all backends can be done with self.get_module(). This
- # also returns a MultiResults tuple with actual results keyed by backend
- # name.
- r = self.get_module().simple_mul(a, b)
- r.print().assert_all_close()
-
-
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
- tf.enable_v2_behavior()
- tf.test.main()
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 8d912a4..050eb47 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -31,14 +31,16 @@
@tf_test_utils.compile_module(FillModule)
-class FillTest(tf_test_utils.CompiledModuleTestCase):
+class FillTest(tf_test_utils.TracedModuleTestCase):
def test_fill(self):
- dims = np.array([2, 3], dtype=np.int32)
- value = np.array(9., dtype=np.float32)
- result = self.get_module().fill(dims, value)
- result.assert_all_close()
+ def fill(module):
+ dims = np.array([2, 3], dtype=np.int32)
+ value = np.array(9., dtype=np.float32)
+ module.fill(dims, value)
+
+ self.compare_backends(fill)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index 67f5acf..7ae7af6 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -14,6 +14,7 @@
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -49,31 +50,43 @@
@tf_test_utils.compile_module(GatherModule)
-class GatherTest(tf_test_utils.CompiledModuleTestCase):
+class GatherTest(tf_test_utils.TracedModuleTestCase):
def test_gather_axis0_scalar(self):
- indices = np.array(2, dtype=np.int32)
- params = np.arange(32, dtype=np.float32).reshape(4, 8)
- result = self.get_module().gather_axis0_scalar(params, indices)
- result.print().assert_all_close()
+
+ def gather_axis0_scalar(module):
+ indices = np.array(2, dtype=np.int32)
+ params = tf_utils.ndarange([4, 8])
+ module.gather_axis0_scalar(params, indices)
+
+ self.compare_backends(gather_axis0_scalar)
def test_gather_axis0_batch0(self):
- indices = np.array([2, 3], dtype=np.int32)
- params = np.arange(32, dtype=np.float32).reshape(4, 8)
- result = self.get_module().gather_axis0_batch0(params, indices)
- result.print().assert_all_close()
+
+ def gather_axis0_batch0(module):
+ indices = np.array([2, 3], dtype=np.int32)
+ params = tf_utils.ndarange([4, 8])
+ module.gather_axis0_batch0(params, indices)
+
+ self.compare_backends(gather_axis0_batch0)
def test_gather_axis1_batch0(self):
- indices = np.array([2, 3], dtype=np.int32)
- params = np.arange(4 * 7 * 8, dtype=np.float32).reshape(4, 7, 8)
- result = self.get_module().gather_axis1_batch0(params, indices)
- result.print().assert_all_close()
+
+ def gather_axis1_batch0(module):
+ indices = np.array([2, 3], dtype=np.int32)
+ params = tf_utils.ndarange([4, 7, 8])
+ module.gather_axis1_batch0(params, indices)
+
+ self.compare_backends(gather_axis1_batch0)
def test_gather_axis2_batch1(self):
- indices = np.array([[2], [3], [0], [1]], dtype=np.int32)
- params = np.arange(4 * 7 * 8 * 2, dtype=np.float32).reshape(4, 7, 8, 2)
- result = self.get_module().gather_axis2_batch1(params, indices)
- result.print().assert_all_close()
+
+ def gather_axis2_batch1(module):
+ indices = np.array([[2], [3], [0], [1]], dtype=np.int32)
+ params = tf_utils.ndarange([4, 7, 8, 2])
+ module.gather_axis2_batch1(params, indices)
+
+ self.compare_backends(gather_axis2_batch1)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index bd80fbc..60bbff4 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -55,7 +55,8 @@
backend,
)
args = [
- "--target_backends={},{}".format(reference_backend, backend),
+ "--reference_backend={}".format(reference_backend),
+ "--target_backends={}".format(backend),
]
# TODO(GH-2175): Simplify this after backend names are standardized.
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 16ddde9..7d1f1d3 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -185,11 +185,21 @@
# interpreter and vulkan backends for these tests.
{
"models": [
- "MobileNet",
"MobileNetV2",
],
"datasets": [
"cifar10",
+ ],
+ "backends": [
+ "iree_vulkan",
+ ],
+ },
+ {
+ "models": [
+ "MobileNet",
+ "MobileNetV2",
+ ],
+ "datasets": [
"imagenet",
],
"backends": [
diff --git a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
index a8d4f29..d286948 100644
--- a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
+++ b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
@@ -107,7 +107,8 @@
"--model={}".format(model),
"--data={}".format(dataset),
"--include_top=1",
- "--target_backends={},{}".format(reference_backend, backend),
+ "--reference_backend={}".format(reference_backend),
+ "--target_backends={}".format(backend),
]
if external_weights:
args.append("--url={}".format(external_weights))
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index fb7a58c..64db229 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -21,9 +21,9 @@
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
-NUM_UNITS = 10
-NUM_TIMESTEPS = 24
NUM_BATCH = 7
+NUM_TIMESTEPS = 24
+NUM_UNITS = 10
INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
@@ -43,17 +43,15 @@
@tf_test_utils.compile_module(LstmStatic, exported_names=["predict"])
-class LstmTest(tf_test_utils.CompiledModuleTestCase):
+class LstmStaticTest(tf_test_utils.TracedModuleTestCase):
def test_lstm(self):
- m = self.get_module()
- m.predict(
- tf.constant(
- np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
- dtype=np.float32).reshape(
- [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]),
- shape=[NUM_BATCH, NUM_TIMESTEPS,
- NUM_UNITS])).print().assert_all_close(1e-5, 1e-5)
+
+ def predict(module):
+ inputs = tf_utils.ndarange(INPUT_SHAPE)
+ module.predict(inputs, rtol=1e-5, atol=1e-5)
+
+ self.compare_backends(predict)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 9409d04..8232d19 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -18,10 +18,11 @@
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
-NUM_UNITS = 10
-NUM_TIMESTEPS = 24
NUM_BATCH = 7
-INPUT_SHAPE = [None, None, NUM_UNITS]
+NUM_TIMESTEPS = 24
+NUM_UNITS = 10
+DYNAMIC_SHAPE = [None, None, NUM_UNITS]
+INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
class Lstm(tf.Module):
@@ -29,28 +30,26 @@
def __init__(self):
super(Lstm, self).__init__()
tf_utils.set_random_seed()
- inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
+ inputs = tf.keras.layers.Input(batch_size=None, shape=DYNAMIC_SHAPE[1:])
outputs = tf.keras.layers.LSTM(
units=NUM_UNITS, return_sequences=True)(
inputs)
self.m = tf.keras.Model(inputs, outputs)
self.predict = tf.function(
- input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+ input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(
self.m.call)
@tf_test_utils.compile_module(Lstm, exported_names=["predict"])
-class LstmTest(tf_test_utils.CompiledModuleTestCase):
+class LstmTest(tf_test_utils.TracedModuleTestCase):
def test_lstm(self):
- m = self.get_module()
- m.predict(
- tf.constant(
- np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
- dtype=np.float32).reshape(
- [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]),
- shape=[NUM_BATCH, NUM_TIMESTEPS,
- NUM_UNITS])).print().assert_all_close()
+
+ def predict(module):
+ inputs = tf_utils.ndarange(INPUT_SHAPE)
+ module.predict(inputs)
+
+ self.compare_backends(predict)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index e30bd57..75286ad 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -65,7 +65,7 @@
tf.TensorSpec(_INPUT_DATA_SHAPE, tf.float32),
tf.TensorSpec(_OUTPUT_DATA_SHAPE, tf.float32)
])
- def TrainStep(self, inputs, targets):
+ def train_step(self, inputs, targets):
with tf.GradientTape() as tape:
predictions = self.model(inputs, training=True)
loss_value = self.loss(predictions, targets)
@@ -76,8 +76,8 @@
@tf_test_utils.compile_module(
- ModelTrain.CreateModule, exported_names=["TrainStep"])
-class ModelTrainTest(tf_test_utils.CompiledModuleTestCase):
+ ModelTrain.CreateModule, exported_names=["train_step"])
+class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
def generate_regression_data(self, size=8):
x = np.arange(size) - size // 2
@@ -86,22 +86,25 @@
def test_model_train(self):
- # generate input and output data for regression problem
+ # Generate input and output data for regression problem.
inputs, targets = self.generate_regression_data()
- # normalize data
+ # Normalize data.
inputs = inputs / max(inputs)
targets = targets / max(targets)
- # generate polynomial features
+ # Generate polynomial features.
inputs = np.expand_dims(inputs, axis=1)
polynomial = PolynomialFeatures(_DEGREE) # returns: [1, a, b, a^2, ab, b^2]
inputs = polynomial.fit_transform(inputs)
targets = np.expand_dims(targets, axis=1)
- # run one iteration of training step
- result = self.get_module().TrainStep(inputs, targets)
- result.print().assert_all_close()
+
+ def train_step(module):
+ # Run one iteration of training step.
+ module.train_step(inputs, targets)
+
+ self.compare_backends(train_step)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 54374ff..7c4b90e 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test all applications models in Keras."""
+
import os
from absl import app
@@ -133,11 +134,14 @@
@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
-class AppTest(tf_test_utils.CompiledModuleTestCase):
+class AppTest(tf_test_utils.TracedModuleTestCase):
def test_application(self):
- input_data = np.random.rand(*get_input_shape()).astype(np.float32)
- self.get_module().predict(input_data).print().assert_all_close(atol=1e-6)
+
+ def predict(module):
+ module.predict(tf_utils.uniform(get_input_shape()))
+
+ self.compare_backends(predict)
def main(argv):
@@ -148,7 +152,7 @@
if FLAGS.model not in APP_MODELS:
raise ValueError(f'Unsupported model: {FLAGS.model}')
# Override VisionModule's __name__ to be more specific.
- VisionModule.__name__ = FLAGS.model
+ VisionModule.__name__ = os.path.join(FLAGS.model, FLAGS.data)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index aa49f5b..b535021 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -34,14 +34,16 @@
@tf_test_utils.compile_module(LinSpaceModule)
-class LinspaceTest(tf_test_utils.CompiledModuleTestCase):
+class LinspaceTest(tf_test_utils.TracedModuleTestCase):
def test_linspace(self):
- start = np.array(10., dtype=np.float32)
- stop = np.array(12., dtype=np.float32)
- result = self.get_module().linspace(start, stop)
- result.assert_all_close()
+ def linspace(module):
+ start = np.array(10., dtype=np.float32)
+ stop = np.array(12., dtype=np.float32)
+ module.linspace(start, stop)
+
+ self.compare_backends(linspace)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 2b3a8d9..b5a3929 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -91,19 +91,17 @@
@tf_test_utils.compile_module(MandelbrotModule)
-class MandelbrotTest(tf_test_utils.CompiledModuleTestCase):
+class MandelbrotTest(tf_test_utils.TracedModuleTestCase):
def test_mandelbrot(self):
- mandelbrot = self.get_module()
- # Basic view of the entire set.
- pixels = mandelbrot.calculate(-0.7, 0.0, 3.0, 400, 100)
- pixels.assert_all_close()
+ def mandelbrot(module):
+ # Basic view of the entire set.
+ module.calculate(-0.7, 0.0, 3.0, 400, 100)
+ # This is a much more detailed view, so more iterations are needed.
+ module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000)
- # This is a much more detailed view, so more iterations are needed.
- pixels = mandelbrot.calculate(-0.7436447860, 0.1318252536, 0.0000029336,
- 400, 3000)
- pixels.assert_all_close()
+ self.compare_backends(mandelbrot)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index a33ac7c..3f1c8be 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -39,27 +39,35 @@
@tf_test_utils.compile_module(MathModule)
-class MathTest(tf_test_utils.CompiledModuleTestCase):
+class MathTest(tf_test_utils.TracedModuleTestCase):
def test_abs(self):
- a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
- r = self.get_module().abs(a)
- r.print().assert_all_close()
+
+ def abs(module):
+ module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
+
+ self.compare_backends(abs)
def test_cos(self):
- a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
- r = self.get_module().cos(a)
- r.print().assert_all_close()
+
+ def cos(module):
+ module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
+
+ self.compare_backends(cos)
def test_log(self):
- a = np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32)
- r = self.get_module().log(a)
- r.print().assert_all_close()
+
+ def log(module):
+ module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
+
+ self.compare_backends(log)
def test_mod(self):
- a = np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)
- r = self.get_module().mod(a)
- r.print().assert_all_close()
+
+ def mod(module):
+ module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
+
+ self.compare_backends(mod)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index d04ce3a..c814397 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -15,6 +15,7 @@
"""Test matrix ops."""
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -71,60 +72,78 @@
@tf_test_utils.compile_module(MatrixOpsModule)
-class MatrixOpsTest(tf_test_utils.CompiledModuleTestCase):
+class MatrixOpsTest(tf_test_utils.TracedModuleTestCase):
def test_basic_matmul(self):
- m = self.get_module()
- dst = m.basic_matmul(tf.random.uniform([4, 2]), tf.random.uniform([2, 4]))
- dst.assert_all_close()
+
+ def basic_matmul(module):
+ module.basic_matmul(tf_utils.uniform([4, 2]), tf_utils.uniform([2, 4]))
+
+ self.compare_backends(basic_matmul)
def test_matmul_lhs_batch(self):
- m = self.get_module()
- dst = m.matmul_lhs_batch(
- tf.random.uniform([3, 4, 2]), tf.random.uniform([2, 4]))
- dst.assert_all_close()
+
+ def matmul_lhs_batch(module):
+ module.matmul_lhs_batch(
+ tf_utils.uniform([3, 4, 2]), tf_utils.uniform([2, 4]))
+
+ self.compare_backends(matmul_lhs_batch)
def test_matmul_rhs_batch(self):
- m = self.get_module()
- dst = m.matmul_rhs_batch(
- tf.random.uniform([4, 2]), tf.random.uniform([3, 2, 4]))
- dst.assert_all_close()
+
+ def matmul_rhs_batch(module):
+ module.matmul_rhs_batch(
+ tf_utils.uniform([4, 2]), tf_utils.uniform([3, 2, 4]))
+
+ self.compare_backends(matmul_rhs_batch)
def test_matmul_broadcast_singleton_dimension(self):
- m = self.get_module()
- dst = m.matmul_broadcast_singleton_dimension(
- tf.random.uniform([1, 4, 2]), tf.random.uniform([3, 2, 4]))
- dst.assert_all_close()
+
+ def matmul_broadcast_singleton_dimension(module):
+ module.matmul_broadcast_singleton_dimension(
+ tf_utils.uniform([1, 4, 2]), tf_utils.uniform([3, 2, 4]))
+
+ self.compare_backends(matmul_broadcast_singleton_dimension)
def test_matmul_high_rank_batch(self):
- m = self.get_module()
- dst = m.matmul_high_rank_batch(
- tf.random.uniform([1, 7, 4, 2]), tf.random.uniform([7, 1, 2, 4]))
- dst.assert_all_close()
+
+ def matmul_high_rank_batch(module):
+ module.matmul_high_rank_batch(
+ tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]))
+
+ self.compare_backends(matmul_high_rank_batch)
def test_matmul_dynamic_matching_batch(self):
- m = self.get_module()
- dst = m.matmul_dynamic(
- tf.random.uniform([2, 2, 3]), tf.random.uniform([2, 3, 4]))
- dst.assert_all_close()
+
+ def matmul_dynamic_matching_batch(module):
+ module.matmul_dynamic(
+ tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]))
+
+ self.compare_backends(matmul_dynamic_matching_batch)
def test_matmul_dynamic_broadcast_lhs(self):
- m = self.get_module()
- dst = m.matmul_dynamic(
- tf.random.uniform([1, 2, 3]), tf.random.uniform([2, 3, 4]))
- dst.assert_all_close()
+
+ def matmul_dynamic_broadcast_lhs(module):
+ module.matmul_dynamic(
+ tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]))
+
+ self.compare_backends(matmul_dynamic_broadcast_lhs)
def test_matmul_dynamic_broadcast_rhs(self):
- m = self.get_module()
- dst = m.matmul_dynamic(
- tf.random.uniform([2, 2, 3]), tf.random.uniform([1, 3, 4]))
- dst.assert_all_close()
+
+ def matmul_dynamic_broadcast_rhs(module):
+ module.matmul_dynamic(
+ tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]))
+
+ self.compare_backends(matmul_dynamic_broadcast_rhs)
def test_matmul_dynamic_rank_broadcasting(self):
- m = self.get_module()
- dst = m.matmul_dynamic_lhs_batch(
- tf.random.uniform([7, 2, 3]), tf.random.uniform([3, 4]))
- dst.assert_all_close()
+
+ def matmul_dynamic_rank_broadcasting(module):
+ module.matmul_dynamic_lhs_batch(
+ tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]))
+
+ self.compare_backends(matmul_dynamic_rank_broadcasting)
if __name__ == "__main__":
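
Several hunks here replace tf.random.uniform with tf_utils.uniform so that test inputs are plain NumPy arrays created outside the compiled module. The helper's definition is not part of this diff; a rough stand-in that matches how it is called would be:

import numpy as np


def uniform(shape, dtype=np.float32):
  # Hypothetical stand-in for pyiree.tf.support.tf_utils.uniform: a random
  # ndarray with the requested shape, usable as input on any backend.
  return np.random.uniform(size=shape).astype(dtype)
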
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index 8daa6cf..dd5ad6d 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -29,11 +29,14 @@
@tf_test_utils.compile_module(ResourcesOpsModule)
-class ResourcesOpsTest(tf_test_utils.CompiledModuleTestCase):
+class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
def test_add_assign(self):
- result = self.get_module().add_assign(np.array(9., dtype=np.float32))
- result.assert_all_close()
+
+ def add_assign(module):
+ module.add_assign(np.array(9., dtype=np.float32))
+
+ self.compare_backends(add_assign)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 3af1502..8e437ea 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -179,27 +179,26 @@
@tf_test_utils.compile_module(
StatefulRingBufferModule, exported_names=["predict"])
-class StatefulRingBufferTest(tf_test_utils.CompiledModuleTestCase):
+class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase):
def test_stateful_ringbuffer(self):
- input1 = np.array([[1.0, 2.0]], dtype=np.float32)
- result1 = self.get_module().predict(input1)
- output1 = np.array([[1.0, 2.0]], dtype=np.float32)
- assert np.allclose(result1, output1)
- # ring buffer is not filled yet,
- # so data from first cycle will be returned
- input2 = np.array([[3.0, 4.0]], dtype=np.float32)
- result2 = self.get_module().predict(input2)
- output2 = np.array([[1.0, 2.0]], dtype=np.float32)
- assert np.allclose(result2, output2)
+ def stateful_ringbuffer(module):
+ input1 = np.array([[1.0, 2.0]], dtype=np.float32)
+ module.predict(input1)
+ # output = np.array([[1.0, 2.0]], dtype=np.float32)
- # on 3rd cycle we overwrite oldest data
- # and return data from 2nd cycle
- input3 = np.array([[5.0, 6.0]], dtype=np.float32)
- result3 = self.get_module().predict(input3)
- output3 = np.array([[3.0, 4.0]], dtype=np.float32)
- assert np.allclose(result3, output3)
+ # ring buffer is not filled yet so data from first cycle will be returned.
+ input2 = np.array([[3.0, 4.0]], dtype=np.float32)
+ module.predict(input2)
+ # output = np.array([[1.0, 2.0]], dtype=np.float32)
+
+ # on 3rd cycle we overwrite oldest data and return data from 2nd cycle.
+ input3 = np.array([[5.0, 6.0]], dtype=np.float32)
+ module.predict(input3)
+ # output = np.array([[3.0, 4.0]], dtype=np.float32)
+
+ self.compare_backends(stateful_ringbuffer)
if __name__ == "__main__":
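
The ring-buffer rewrite drops the hard-coded np.allclose checks and keeps the expected values only as comments: correctness now rests on compare_backends comparing the traced results across backends rather than against literals. The real logic lives in tf_test_utils and takes only the trace function; the free-function sketch below, with an invented _Trace recorder and an explicit modules list, is only meant to illustrate the idea:

import numpy as np


class _Trace:
  """Records every method call made on a wrapped module (sketch only)."""

  def __init__(self, module):
    self._module = module
    self.calls = []

  def __getattr__(self, name):
    method = getattr(self._module, name)

    def _call(*args):
      result = method(*args)
      self.calls.append((name, result))
      return result

    return _call


def compare_backends(trace_function, modules):
  # Replay the same trace on each module (reference first), then check that
  # every recorded result matches the reference result.
  traces = []
  for module in modules:
    trace = _Trace(module)
    trace_function(trace)
    traces.append(trace.calls)
  reference, *others = traces
  for other in others:
    for (name, ref_result), (_, other_result) in zip(reference, other):
      np.testing.assert_allclose(ref_result, other_result, rtol=1e-6,
                                 err_msg="mismatch in call to " + name)
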
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index ab5ab91..8b43e3a 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -49,28 +49,37 @@
@tf_test_utils.compile_module(ScatterUpdateModule)
-class ScatterUpdateTest(tf_test_utils.CompiledModuleTestCase):
+class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase):
def test_scatter_update_1D(self):
- tensor = tf.ones([8], dtype=tf.int32)
- indices = tf.constant([[4], [5], [6]])
- updates = tf.constant([9, 10, 11])
- result = self.get_module().scatter_update_1D(tensor, indices, updates)
- result.assert_all_close()
+
+ def scatter_update_1D(module):
+ tensor = np.ones([8], dtype=np.int32)
+ indices = np.array([[4], [5], [6]], dtype=np.int32)
+ updates = np.array([9, 10, 11], dtype=np.int32)
+ module.scatter_update_1D(tensor, indices, updates)
+
+ self.compare_backends(scatter_update_1D)
def test_scatter_update_2D(self):
- tensor = tf.ones([4, 3], dtype=tf.int32)
- indices = tf.constant([[1, 0], [2, 1], [3, 2]])
- updates = tf.constant([2, 5, 8])
- result = self.get_module().scatter_update_2D(tensor, indices, updates)
- result.assert_all_close()
+
+ def scatter_update_2D(module):
+ tensor = np.ones([4, 3], dtype=np.int32)
+ indices = np.array([[1, 0], [2, 1], [3, 2]], dtype=np.int32)
+ updates = np.array([2, 5, 8], dtype=np.int32)
+ module.scatter_update_2D(tensor, indices, updates)
+
+ self.compare_backends(scatter_update_2D)
def test_scatter_update_2D_slice(self):
- tensor = tf.ones([4, 3], dtype=tf.int32)
- indices = tf.constant([[1]])
- updates = tf.constant([[2, 3, 4]])
- result = self.get_module().scatter_update_2D_slice(tensor, indices, updates)
- result.assert_all_close()
+
+ def scatter_update_2D_slice(module):
+ tensor = np.ones([4, 3], dtype=np.int32)
+ indices = np.array([[1]], dtype=np.int32)
+ updates = np.array([[2, 3, 4]], dtype=np.int32)
+ module.scatter_update_2D_slice(tensor, indices, updates)
+
+ self.compare_backends(scatter_update_2D_slice)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index d3ea327..aaec578 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -16,6 +16,7 @@
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -37,21 +38,27 @@
@tf_test_utils.compile_module(SimpleArithmeticModule)
-class SimpleArithmeticTest(tf_test_utils.CompiledModuleTestCase):
+class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
def test_simple_mul(self):
- a = np.array([1., 2., 3., 4.], dtype=np.float32)
- b = np.array([400., 5., 6., 7.], dtype=np.float32)
- r = self.get_module().simple_mul(a, b)
- r.print().assert_all_close()
+
+ def simple_mul(module):
+ a = np.array([1., 2., 3., 4.], dtype=np.float32)
+ b = np.array([400., 5., 6., 7.], dtype=np.float32)
+ c = module.simple_mul(a, b)
+ module.simple_mul(a, c)
+
+ self.compare_backends(simple_mul)
def test_simple_matmul(self):
- np.random.seed(12345)
- # Note: scaling by a small value to increase numerical stability.
- a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
- b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
- r = self.get_module().simple_matmul(a, b)
- r.print().assert_all_close()
+
+ def simple_matmul(module):
+ # Note: scaling by a small value to increase numerical stability.
+ a = tf_utils.uniform((128, 3072)) * 1e-3
+ b = tf_utils.uniform((3072, 256)) * 1e-3
+ module.simple_matmul(a, b)
+
+ self.compare_backends(simple_matmul)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 24dd23e..eff49a8 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -33,12 +34,15 @@
@tf_test_utils.compile_module(Stateful)
-class StatefulTest(tf_test_utils.CompiledModuleTestCase):
+class StatefulTest(tf_test_utils.TracedModuleTestCase):
def test_stateful(self):
- m = self.get_module()
- m.inc_by(tf.constant(1.))
- m.get_state().print().assert_all_close()
+
+ def get_state(module):
+ module.inc_by(np.array(1., dtype=np.float32))
+ module.get_state()
+
+ self.compare_backends(get_state)
if __name__ == "__main__":
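
The stateful test above shows only the call sites; the Stateful module it compiles (defined earlier in the same file, outside this hunk) holds state that inc_by mutates and get_state reads back, which is why both calls sit inside a single trace. A plausible shape for such a module, written out purely as an assumption:

import tensorflow.compat.v2 as tf


class Stateful(tf.Module):

  def __init__(self):
    super().__init__()
    self.counter = tf.Variable(0.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def inc_by(self, value):
    self.counter.assign_add(value)

  @tf.function(input_signature=[])
  def get_state(self):
    return self.counter
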
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index f206d86..513aa97 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -76,18 +76,20 @@
@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
-class SlidingWindowTest(tf_test_utils.CompiledModuleTestCase):
+class SlidingWindowTest(tf_test_utils.TracedModuleTestCase):
- def test_slidingwindow(self):
- input1 = np.array([[1.0, 2.0]], dtype=np.float32)
- result1 = self.get_module().predict(input1)
- output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32)
- assert np.allclose(result1, output1)
+ def test_sliding_window(self):
- input2 = np.array([[3.0, 4.0]], dtype=np.float32)
- result2 = self.get_module().predict(input2)
- output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
- assert np.allclose(result2, output2)
+ def sliding_window(module):
+ input1 = np.array([[1.0, 2.0]], dtype=np.float32)
+ result1 = module.predict(input1)
+ # output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32)
+
+ input2 = np.array([[3.0, 4.0]], dtype=np.float32)
+ result2 = module.predict(input2)
+ # output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+
+ self.compare_backends(sliding_window)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index ce0787e..206b33c 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -41,20 +41,27 @@
@tf_test_utils.compile_module(StringsModule)
-class StringsTest(tf_test_utils.CompiledModuleTestCase):
+class StringsTest(tf_test_utils.TracedModuleTestCase):
def test_print_ids(self):
- input_ids = np.asarray(
- [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
- [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
- self.get_module().print_ids(input_ids)
+
+ def print_ids(module):
+ input_ids = np.asarray(
+ [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
+ [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
+ module.print_ids(input_ids)
+
+ self.compare_backends(print_ids)
def test_strings_to_ids(self):
- input_ids = np.asarray(
- [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
- [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
- result = self.get_module().strings_to_ids(input_ids)
- result.assert_all_equal()
+
+ def strings_to_ids(module):
+ input_ids = np.asarray(
+ [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
+ [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
+ module.strings_to_ids(input_ids)
+
+ self.compare_backends(strings_to_ids)
if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 9b1330c..440bd43 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
STATIC_SIZE = 20
@@ -65,36 +67,46 @@
@tf_test_utils.compile_module(TensorListModule)
-class TensorListTest(tf_test_utils.CompiledModuleTestCase):
+class TensorListTest(tf_test_utils.TracedModuleTestCase):
def test_identity_through_tensorlist(self):
- m = self.get_module()
- result = m.identity_through_tensorlist(tf.constant(42.))
- result.print().assert_all_close()
+
+ def identity_through_tensorlist(module):
+ module.identity_through_tensorlist(np.array(42., dtype=np.float32))
+
+ self.compare_backends(identity_through_tensorlist)
def test_add_through_tensorlist(self):
- m = self.get_module()
- result = m.add_through_tensorlist(tf.constant(42.), tf.constant(43.))
- result.print().assert_all_close()
+
+ def add_through_tensorlist(module):
+ module.add_through_tensorlist(
+ np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+
+ self.compare_backends(add_through_tensorlist)
def test_slice_first_element_with_from_tensor(self):
- m = self.get_module()
- result = m.slice_first_element_with_from_tensor(
- tf.range(STATIC_SIZE, dtype=tf.float32))
- result.print().assert_all_close()
+
+ def slice_first_element_with_from_tensor(module):
+ module.slice_first_element_with_from_tensor(
+ np.arange(STATIC_SIZE, dtype=np.float32))
+
+ self.compare_backends(slice_first_element_with_from_tensor)
def test_slice_first_element_with_from_tensor_high_rank(self):
- m = self.get_module()
- result = m.slice_first_element_with_from_tensor_high_rank(
- tf.broadcast_to(
- tf.range(STATIC_SIZE, dtype=tf.float32),
- [STATIC_SIZE, STATIC_SIZE]))
- result.print().assert_all_close()
+
+ def slice_first_element_with_from_tensor_high_rank(module):
+ module.slice_first_element_with_from_tensor_high_rank(
+ tf_utils.ndarange([STATIC_SIZE, STATIC_SIZE]))
+
+ self.compare_backends(slice_first_element_with_from_tensor_high_rank)
def test_concat_with_tensorlist_stack(self):
- m = self.get_module()
- result = m.concat_with_tensorlist_stack(tf.constant(42.), tf.constant(43.))
- result.print().assert_all_close()
+
+ def concat_with_tensorlist_stack(module):
+ module.concat_with_tensorlist_stack(
+ np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+
+ self.compare_backends(concat_with_tensorlist_stack)
if __name__ == "__main__":
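
tensorlist_test also starts using tf_utils.ndarange in place of the tf.range/tf.broadcast_to construction. As with uniform, its definition is not in this diff, and the values it produces need not match the old expression exactly; a hypothetical stand-in with the same call shape:

import numpy as np


def ndarange(shape, dtype=np.float32):
  # Hypothetical stand-in for tf_utils.ndarange: sequential values reshaped
  # to the requested shape.
  return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
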
diff --git a/iree/compiler/Conversion/BUILD b/iree/compiler/Conversion/BUILD
index 35f424e..8a297f8 100644
--- a/iree/compiler/Conversion/BUILD
+++ b/iree/compiler/Conversion/BUILD
@@ -24,6 +24,7 @@
"init_conversions.h",
],
deps = [
+ "//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Conversion/LinalgToLLVM",
"//iree/compiler/Conversion/LinalgToSPIRV",
diff --git a/iree/compiler/Conversion/CMakeLists.txt b/iree/compiler/Conversion/CMakeLists.txt
index 6a022f2..5fa1478 100644
--- a/iree/compiler/Conversion/CMakeLists.txt
+++ b/iree/compiler/Conversion/CMakeLists.txt
@@ -20,6 +20,7 @@
HDRS
"init_conversions.h"
DEPS
+ iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Conversion::LinalgToSPIRV
diff --git a/iree/compiler/Conversion/HLOToHLO/BUILD b/iree/compiler/Conversion/HLOToHLO/BUILD
new file mode 100644
index 0000000..685e835
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/BUILD
@@ -0,0 +1,38 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "HLOToHLO",
+ srcs = [
+ "DecomposeHLOClamp.cpp",
+ "Passes.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ deps = [
+ "@llvm-project//mlir:CFGTransforms",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:legalize_gather_to_torch_index_select",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:mhlo_to_mhlo_lowering_patterns",
+ ],
+)
diff --git a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
new file mode 100644
index 0000000..f94bbb5
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ HLOToHLO
+ HDRS
+ "Passes.h"
+ SRCS
+ "DecomposeHLOClamp.cpp"
+ "Passes.cpp"
+ DEPS
+ MLIRIR
+ MLIRPass
+ MLIRSCFToStandard
+ tensorflow::mlir_hlo
+ PUBLIC
+)
diff --git a/iree/compiler/Conversion/HLOToLinalg/DecomposeHLOClamp.cpp b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
similarity index 100%
rename from iree/compiler/Conversion/HLOToLinalg/DecomposeHLOClamp.cpp
rename to iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.cpp b/iree/compiler/Conversion/HLOToHLO/Passes.cpp
new file mode 100644
index 0000000..4bb2a8f
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.cpp
@@ -0,0 +1,45 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/PassManager.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+struct ConvertHLOToCompatibleHLOPass
+ : public PassWrapper<ConvertHLOToCompatibleHLOPass, FunctionPass> {
+ void runOnFunction() override {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList greedyPatterns;
+ mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
+ mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &greedyPatterns);
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), greedyPatterns))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createHLOToCompatibleHLOPass() {
+ return std::make_unique<ConvertHLOToCompatibleHLOPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.h b/iree/compiler/Conversion/HLOToHLO/Passes.h
new file mode 100644
index 0000000..4dc36de
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.h
@@ -0,0 +1,47 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Passes.h - Pass to convert from XLA-HLO to IREE supported XLA-HLO --===//
+//
+// IREE specific passes used to sanitize XLA-HLO to IREE compatible XLA-HLO.
+// Some examples may be raising gather operations to torch index select or
+// decomposing complex operations into real-valued equivalents.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_HLOTOHLO_PASSES_H_
+#define IREE_COMPILER_CONVERSION_HLOTOHLO_PASSES_H_
+
+#include <memory>
+
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Creates a pass to decompose XLA-HLO clamp ops into primitive ops.
+std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
+
+/// Creates XLA-HLO to XLA-HLO transformation pass.
+std::unique_ptr<OperationPass<FuncOp>> createHLOToCompatibleHLOPass();
+
+/// Populates the patterns that convert from XLA-HLO to IREE compatible XLA-HLO.
+/// Imports patterns from the Tensorflow XLA passes.
+void populateHLOToCompatibleHLOPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_HLOTOHLO_PASSES_H_
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index 2e7447b..89e790d 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -21,7 +21,6 @@
cc_library(
name = "HLOToLinalg",
srcs = [
- "DecomposeHLOClamp.cpp",
"HLOToLinalgOnBuffers.cpp",
"HLOToLinalgOnTensors.cpp",
"Passes.cpp",
@@ -32,6 +31,7 @@
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
index bdfd4c9..26da6bd 100644
--- a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
@@ -20,7 +20,6 @@
HDRS
"Passes.h"
SRCS
- "DecomposeHLOClamp.cpp"
"HLOToLinalgOnBuffers.cpp"
"HLOToLinalgOnTensors.cpp"
"Passes.cpp"
@@ -36,6 +35,7 @@
MLIRSupport
MLIRTransforms
iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::HLOToHLO
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.cpp b/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
index 91432c3..e5638b0 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@@ -22,6 +23,7 @@
namespace iree_compiler {
void addHLOToLinalgOnBuffersPasses(OpPassManager &pm) {
+ pm.addPass(createHLOToCompatibleHLOPass());
pm.addPass(createHLOToLinalgOnTensorsPass());
pm.addPass(createLinalgFusionOfTensorOpsPass());
pm.addPass(createLinalgFoldUnitExtentDimsPass());
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.h b/iree/compiler/Conversion/HLOToLinalg/Passes.h
index 437b0a6..445bd51 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.h
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.h
@@ -27,9 +27,6 @@
namespace mlir {
namespace iree_compiler {
-/// Crates a pass to decompose XLA-HLO clamp ops into primitive ops.
-std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
-
/// Creates XLA-HLO to Linalg on buffers transformation pass.
std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnBuffersPass();
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index c6cafc1..48d21e3 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -30,6 +30,7 @@
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index fddc144..09e06fa 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -39,6 +39,7 @@
MLIRVectorToLLVM
MLIRVectorToSCF
iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 777fb02..fc476d0 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -144,8 +144,9 @@
public:
explicit ConvertFuncWithHALInterface(MLIRContext *context,
LLVMTypeConverter &typeconverter)
- : ConvertToLLVMPattern(FuncOp::getOperationName(), context,
- typeconverter) {}
+ : ConvertToLLVMPattern(
+ mlir::FuncOp::getOperationName(), context, typeconverter,
+ LowerToLLVMOptions::getDefaultOptions(), 65535 - 1) {}
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
@@ -345,8 +346,15 @@
RemoveInterfaceOpPattern>(&getContext(), converter);
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
- if (failed(applyPartialConversion(module, target, patterns)))
+ target.addIllegalOp<IREE::PlaceholderOp>();
+ target.addDynamicallyLegalOp<FuncOp>([](FuncOp funcOp) {
+ bool any = false;
+ funcOp.walk([&](IREE::PlaceholderOp placeholderOp) { any = true; });
+ return any ? false : true;
+ });
+ if (failed(applyPartialConversion(module, target, patterns))) {
signalPassFailure();
+ }
}
std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass() {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 8631cdf..c83a5d8 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index fc447e6..7a80343 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -41,6 +41,7 @@
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 69c04d9..7bca4c8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -56,6 +56,7 @@
MLIRTransforms
MLIRVector
iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index ed2ab3a..0991416 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -20,6 +20,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index b1bba5e..198a019 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -171,8 +171,7 @@
separableOp.index()));
StringRef newFnName = splitKernels.back();
builder.setInsertionPointToStart(moduleOp.getBody());
- auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType(),
- /*attrs=*/ArrayRef<NamedAttribute>());
+ auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType());
// Copy over all attributes except type and name.
for (const auto &namedAttr : oldFn.getAttrs()) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index c16fd06..b8d7bfa 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -76,8 +76,8 @@
if (VectorToGPUPattern<StdOp>::cooperativeMatrixAnalysis
.usesCooperativeMatrixType(operation))
return failure();
- Value newOp = rewriter.create<StdOp>(
- operation.getLoc(), ValueRange(operands), ArrayRef<NamedAttribute>{});
+ Value newOp =
+ rewriter.create<StdOp>(operation.getLoc(), ValueRange(operands));
rewriter.replaceOp(operation, ValueRange(newOp));
return success();
}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 259e3d5..7450f85 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -15,6 +15,7 @@
#ifndef IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
#define IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
@@ -26,6 +27,8 @@
// expects all the possible conversions to be made available to the context
// automatically.
+inline void registerHLOToHLOPasses() { createHLOToCompatibleHLOPass(); }
+
inline void registerHLOToLinalgPasses() {
createDecomposeHLOClampPass();
createHLOToLinalgOnBuffersPass();
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
index 4c3f467..40053dd 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -42,8 +42,7 @@
// provided symbols and it seems like it shouldn't be.
auto uniqueName = (Twine("__") + variableOp.getName() + "_initializer").str();
auto initializerFuncOp =
- rewriter.create<FuncOp>(variableOp.getLoc(), uniqueName, initializerType,
- ArrayRef<NamedAttribute>{});
+ rewriter.create<FuncOp>(variableOp.getLoc(), uniqueName, initializerType);
auto *entryBlock = initializerFuncOp.addEntryBlock();
rewriter.setInsertionPointToEnd(entryBlock);
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
index 96bb5ac..128abd4 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
@@ -52,6 +52,9 @@
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
auto llvmModule = mlir::translateModuleToLLVMIR(targetOp.getInnerModule());
+ if (!llvmModule) {
+ return targetOp.emitError("Failed to translate executable to LLVM IR");
+ }
   // Create invocation function and populate entry_points.
iree::LLVMIRExecutableDefT llvmIrExecutableDef;
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index e605418..1c1c9e4 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -123,8 +123,7 @@
descriptorSetLayoutCache_.try_emplace(bindingsAttr, variableOp);
auto initializerOp = moduleBuilder.create<FuncOp>(
- loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}),
- ArrayRef<NamedAttribute>{});
+ loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}));
SymbolTable::setSymbolVisibility(initializerOp,
SymbolTable::Visibility::Private);
auto *block = initializerOp.addEntryBlock();
@@ -168,8 +167,7 @@
executableLayoutCache_.try_emplace(setLayoutsArrayAttr, variableOp);
auto initializerOp = moduleBuilder.create<FuncOp>(
- loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}),
- ArrayRef<NamedAttribute>{});
+ loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}));
SymbolTable::setSymbolVisibility(initializerOp,
SymbolTable::Visibility::Private);
auto *block = initializerOp.addEntryBlock();
@@ -212,8 +210,7 @@
auto initializerOp = moduleBuilder.create<FuncOp>(
loc, initializerName,
- moduleBuilder.getFunctionType({}, {executableCacheType}),
- ArrayRef<NamedAttribute>{});
+ moduleBuilder.getFunctionType({}, {executableCacheType}));
SymbolTable::setSymbolVisibility(initializerOp,
SymbolTable::Visibility::Private);
auto *block = initializerOp.addEntryBlock();
diff --git a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 83b1884..3e343eb 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -72,8 +72,7 @@
"_device_match_id_" + std::to_string(matchKey.index());
auto initializerOp = moduleBuilder.create<FuncOp>(
fusedLoc, variableName + "_initializer",
- moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}),
- ArrayRef<NamedAttribute>{});
+ moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}));
SymbolTable::setSymbolVisibility(initializerOp,
SymbolTable::Visibility::Private);
auto variableOp = moduleBuilder.create<IREE::HAL::VariableOp>(
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index b3ceb17..9c7f0c2 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -1270,11 +1270,11 @@
if (op.getOperation()->getOperand(0).getType().template isa<RefType>()) {
condValue = rewriter.template createOrFold<CmpRefOp>(
op.getLoc(), ArrayRef<Type>{condType},
- op.getOperation()->getOperands(), ArrayRef<NamedAttribute>{});
+ op.getOperation()->getOperands());
} else {
condValue = rewriter.template createOrFold<CmpI32Op>(
op.getLoc(), ArrayRef<Type>{condType},
- op.getOperation()->getOperands(), ArrayRef<NamedAttribute>{});
+ op.getOperation()->getOperands());
}
condValue = rewriter.createOrFold<XorI32Op>(
op.getLoc(), condType, condValue,
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
index e70cb08..d53f037 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ b/iree/compiler/Dialect/VMLA/Transforms/BUILD
@@ -48,5 +48,6 @@
"@llvm-project//mlir:Transforms",
"@org_tensorflow//tensorflow/compiler/mlir/hlo",
"@org_tensorflow//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
+ "@org_tensorflow//tensorflow/compiler/mlir/hlo:mhlo_to_mhlo_lowering_patterns",
],
)
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index 15d91c6..f22c87f 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -34,6 +34,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
@@ -281,6 +282,15 @@
public:
void runOnOperation() {
MLIRContext *context = &getContext();
+
+ // These patterns should be run greedily as they are not dialect
+ // conversions.
+ OwningRewritePatternList greedyPatterns;
+ mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), greedyPatterns))) {
+ return signalPassFailure();
+ }
+
OwningRewritePatternList patterns;
ConversionTarget target(*context);
target.addLegalDialect<StandardOpsDialect>();
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
index d21bfda..3473e44 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -31,3 +31,20 @@
%0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[5, 6]> : tensor<2xi64>} : (tensor<3xf32>) -> tensor<5x6x3xf32>
return %0 : tensor<5x6x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+ // CHECK-NOT: "mhlo.complex"
+ %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
+
+ // CHECK-DAG: [[V1:%.+]] = mhlo.multiply %arg0, %arg0
+ // CHECK-DAG: [[V2:%.+]] = mhlo.multiply %arg1, %arg1
+ // CHECK-DAG: [[V3:%.+]] = mhlo.subtract [[V1]], [[V2]]
+ %1 = "mhlo.multiply"(%0, %0) : (tensor<3xcomplex<f32>>, tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>>
+ %2 = "mhlo.real"(%1) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
+
+ // CHECK: return [[V3]]
+ return %2 : tensor<3xf32>
+}
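
The new FileCheck case exercises the complex-lowering patterns pulled in above: for x = a + b i, squaring x and then taking the real part reduces to two real multiplies and a subtract, which is exactly what the CHECK-DAG lines match:

  \operatorname{Re}\left((a + b i)^2\right) = a^2 - b^2
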
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD
index 4d9d6d1..b1301e2 100644
--- a/iree/schemas/BUILD
+++ b/iree/schemas/BUILD
@@ -13,6 +13,7 @@
# limitations under the License.
load("//iree:build_defs.oss.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library")
+load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
package(
@@ -33,7 +34,16 @@
"--gen-object-api",
]
-# TODO(benvanik): also expose as C using flatcc.
+iree_flatbuffer_c_library(
+ name = "bytecode_module_def_c_fbs",
+ srcs = ["bytecode_module_def.fbs"],
+ flatcc_args = [
+ "--reader",
+ "--builder",
+ "--verifier",
+ ],
+)
+
iree_flatbuffer_cc_library(
name = "bytecode_module_def_cc_fbs",
srcs = ["bytecode_module_def.fbs"],
diff --git a/iree/schemas/CMakeLists.txt b/iree/schemas/CMakeLists.txt
index 751ed40..313ffa1 100644
--- a/iree/schemas/CMakeLists.txt
+++ b/iree/schemas/CMakeLists.txt
@@ -14,6 +14,18 @@
iree_add_all_subdirs()
+flatbuffer_c_library(
+ NAME
+ bytecode_module_def_c_fbs
+ SRCS
+ "bytecode_module_def.fbs"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ PUBLIC
+)
+
flatbuffer_cc_library(
NAME
bytecode_module_def_cc_fbs
diff --git a/iree/schemas/bytecode_module_def.fbs b/iree/schemas/bytecode_module_def.fbs
index 2394afb..f971093 100644
--- a/iree/schemas/bytecode_module_def.fbs
+++ b/iree/schemas/bytecode_module_def.fbs
@@ -93,7 +93,7 @@
compression_type:CompressionTypeDef;
// Contents in a format defined by CompressionTypeDef.
- data:[uint8] (force_align: 16);
+ data:[uint8];
}
// Read-write data segment.
@@ -117,6 +117,7 @@
// Offset and length within the larger bytecode data block.
bytecode_offset:int32;
bytecode_length:int32;
+ // TODO(benvanik): remove counts and embed directly in bytecode.
// Total number of i32 registers used by the function.
i32_register_count:int16;
// Total number of ref_ptr registers used by the function.
@@ -175,7 +176,7 @@
function_descriptors:[FunctionDescriptor];
// Bytecode contents. One large buffer containing all of the function op data.
- bytecode_data:[uint8] (force_align: 4);
+ bytecode_data:[uint8];
}
root_type BytecodeModuleDef;
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 744ff22..9ccb2d9 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -57,11 +57,8 @@
# TODO(#1696): Enable after standard dialect can support floor
# operation. Lowering from XLA -> linalg should be easy fix.
# "floor.mlir",
-
- # TODO(#1694): Enable after mhlo.gather can be lowered to linalg.
- # "gather.mlir",
- # "gather_concat.mlir",
- #
+ "gather.mlir",
+ "gather_concat.mlir",
"iota.mlir",
"log.mlir",
"maximum.mlir",
@@ -89,6 +86,8 @@
target_backend = "vulkan-spirv",
)
+# TODO(ataei): Enable dylib-llvm-aot tests.
+# See: https://github.com/google/iree/issues/2645
iree_check_single_backend_test_suite(
name = "check_llvm-ir_llvm",
srcs = [
@@ -107,6 +106,8 @@
"divide.mlir",
"dot.mlir",
"exponential.mlir",
+ "gather.mlir",
+ "gather_concat.mlir",
"iota.mlir",
"log.mlir",
"maximum.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index e4ae959..51a62cf 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -46,6 +46,8 @@
"divide.mlir"
"dot.mlir"
"exponential.mlir"
+ "gather.mlir"
+ "gather_concat.mlir"
"iota.mlir"
"log.mlir"
"maximum.mlir"
@@ -93,6 +95,8 @@
"divide.mlir"
"dot.mlir"
"exponential.mlir"
+ "gather.mlir"
+ "gather_concat.mlir"
"iota.mlir"
"log.mlir"
"maximum.mlir"
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 46ea8a9..a3b5a9f 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -243,6 +243,7 @@
"@com_google_absl//absl/types:span",
"//iree/base:api",
"//iree/base:api_util",
+ "//iree/base:localfile",
"//iree/base:source_location",
"//iree/base:tracing",
"//iree/compiler/Dialect/Flow/Transforms",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index a277877..b3a4fa9 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -330,6 +330,7 @@
iree::base::api
iree::base::api_util
iree::base::init
+ iree::base::localfile
iree::base::source_location
iree::base::status
iree::base::tracing
diff --git a/iree/vm/BUILD b/iree/vm/BUILD
index 93829d0..ceec083 100644
--- a/iree/vm/BUILD
+++ b/iree/vm/BUILD
@@ -59,11 +59,10 @@
":value",
"//iree/base:alignment",
"//iree/base:api",
- "//iree/base:flatbuffer_util",
+ "//iree/base:logging",
"//iree/base:target_platform",
- "//iree/schemas:bytecode_module_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/strings",
+ "//iree/schemas:bytecode_module_def_c_fbs",
+ "@com_github_dvidelabs_flatcc//:runtime",
],
)
diff --git a/iree/vm/CMakeLists.txt b/iree/vm/CMakeLists.txt
index ae91187..84b8bf5 100644
--- a/iree/vm/CMakeLists.txt
+++ b/iree/vm/CMakeLists.txt
@@ -64,13 +64,12 @@
::stack
::type_def
::value
- absl::strings
- flatbuffers
+ flatcc::runtime
iree::base::alignment
iree::base::api
- iree::base::flatbuffer_util
+ iree::base::logging
iree::base::target_platform
- iree::schemas::bytecode_module_def_cc_fbs
+ iree::schemas::bytecode_module_def_c_fbs
PUBLIC
)
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index eab5136..2565c5e 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -741,7 +741,7 @@
target_function.module = &module->interface;
target_function.linkage = IREE_VM_FUNCTION_LINKAGE_INTERNAL;
target_function.ordinal = function_ordinal;
- const iree_vm_function_descriptor_t* target_descriptor =
+ const iree_vm_FunctionDescriptor_t* target_descriptor =
&module->function_descriptor_table[function_ordinal];
target_function.i32_register_count =
target_descriptor->i32_register_count;
diff --git a/iree/vm/bytecode_module.cc b/iree/vm/bytecode_module.cc
index 9f56a35..cec96fd 100644
--- a/iree/vm/bytecode_module.cc
+++ b/iree/vm/bytecode_module.cc
@@ -14,23 +14,20 @@
#include "iree/vm/bytecode_module.h"
-#include <string.h>
-
-#include "absl/strings/match.h"
#include "iree/base/alignment.h"
#include "iree/base/api.h"
-#include "iree/base/flatbuffer_util.h"
+#include "iree/base/logging.h"
#include "iree/vm/bytecode_module_impl.h"
#include "iree/vm/ref.h"
#include "iree/vm/stack.h"
-// TODO(benvanik): replace with flatcc version so this file can be pure C.
-#include "flatbuffers/flatbuffers.h"
-#include "iree/schemas/bytecode_module_def_generated.h"
-
-#define IREE_VM_GET_MODULE_DEF(module) \
- ::flatbuffers::GetRoot<iree::vm::BytecodeModuleDef>( \
- module->flatbuffer_data.data)
+// Performs a strcmp between a flatbuffers string and an IREE string view.
+static int iree_vm_flatbuffer_strcmp(flatbuffers_string_t lhs,
+ iree_string_view_t rhs) {
+ size_t lhs_size = flatbuffers_string_len(lhs);
+ int x = strncmp(lhs, rhs.data, lhs_size < rhs.size ? lhs_size : rhs.size);
+ return x != 0 ? x : lhs_size < rhs.size ? -1 : lhs_size > rhs.size;
+}
// Returns true if the given |type_def| is valid, meaning that the type it was
// resolved from is registered or known to the system as a builtin.
@@ -41,28 +38,36 @@
// Resolves a type through either builtin rules or the ref registered types.
static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type(
- const iree::vm::TypeDef* type_def) {
- auto full_name = iree::WrapString(type_def->full_name());
+ iree_vm_TypeDef_table_t type_def) {
iree_vm_type_def_t result;
memset(&result, 0, sizeof(result));
- if (full_name == "i8") {
+ flatbuffers_string_t full_name = iree_vm_TypeDef_full_name(type_def);
+ if (!flatbuffers_string_len(full_name)) {
+ return result;
+ } else if (iree_vm_flatbuffer_strcmp(full_name,
+ iree_make_cstring_view("i8")) == 0) {
result.value_type = IREE_VM_VALUE_TYPE_I8;
- } else if (full_name == "i16") {
+ } else if (iree_vm_flatbuffer_strcmp(full_name,
+ iree_make_cstring_view("i16")) == 0) {
result.value_type = IREE_VM_VALUE_TYPE_I16;
- } else if (full_name == "i32") {
+ } else if (iree_vm_flatbuffer_strcmp(full_name,
+ iree_make_cstring_view("i32")) == 0) {
result.value_type = IREE_VM_VALUE_TYPE_I32;
- } else if (full_name == "i64") {
+ } else if (iree_vm_flatbuffer_strcmp(full_name,
+ iree_make_cstring_view("i64")) == 0) {
result.value_type = IREE_VM_VALUE_TYPE_I64;
- } else if (!full_name.empty() && full_name[0] == '!') {
- full_name.remove_prefix(1);
- if (absl::StartsWith(full_name, "vm.list<")) {
+ } else if (full_name[0] == '!') {
+ // Note that we drop the ! prefix:
+ iree_string_view_t type_name = iree_string_view_t{
+ full_name + 1, flatbuffers_string_len(full_name) - 1};
+ if (strncmp(type_name.data, "vm.list<", strlen("vm.list<")) == 0) {
// This is a !vm.list<...> type. We don't actually care about the type as
- // we allow list types to be widened.
- full_name.remove_suffix(full_name.size() - 7);
+ // we allow list types to be widened. Rewrite to just vm.list as that's
+ // all we have registered.
+ type_name = iree_make_cstring_view("vm.list");
}
const iree_vm_ref_type_descriptor_t* type_descriptor =
- iree_vm_ref_lookup_registered_type(
- iree_string_view_t{full_name.data(), full_name.size()});
+ iree_vm_ref_lookup_registered_type(type_name);
if (type_descriptor) {
result.ref_type = type_descriptor->type;
}
@@ -74,12 +79,13 @@
// |type_table| can be omitted to just perform verification that all types are
// registered.
static iree_status_t iree_vm_bytecode_module_resolve_types(
- const iree::vm::BytecodeModuleDef* module_def,
- iree_vm_type_def_t* type_table) {
- for (int i = 0; i < module_def->types()->size(); ++i) {
- type_table[i] =
- iree_vm_bytecode_module_resolve_type(module_def->types()->Get(i));
+ iree_vm_TypeDef_vec_t type_defs, iree_vm_type_def_t* type_table) {
+ for (size_t i = 0; i < iree_vm_TypeDef_vec_len(type_defs); ++i) {
+ iree_vm_TypeDef_table_t type_def = iree_vm_TypeDef_vec_at(type_defs, i);
+ type_table[i] = iree_vm_bytecode_module_resolve_type(type_def);
if (!iree_vm_type_def_is_valid(type_table[i])) {
+ LOG(ERROR) << "no type registered with name '"
+ << iree_vm_TypeDef_full_name(type_def) << "'";
return IREE_STATUS_NOT_FOUND;
}
}
@@ -91,121 +97,136 @@
// names on functions with internal linkage), however we shouldn't need to
// bounds check anything within the flatbuffer after this succeeds.
static iree_status_t iree_vm_bytecode_module_flatbuffer_verify(
- const iree::vm::BytecodeModuleDef* module_def) {
- if (!module_def->name() || module_def->name()->size() == 0) {
- LOG(ERROR) << "All modules must have a name.";
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ LOG(ERROR) << "Flatbuffer data is not present or less than 16 bytes";
return IREE_STATUS_INVALID_ARGUMENT;
}
- if (!module_def->types()) {
- LOG(ERROR) << "Type table is mandatory, though it could be empty (in empty "
- "modules).";
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_vm_BytecodeModuleDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ LOG(ERROR) << flatcc_verify_error_string(verify_ret);
return IREE_STATUS_INVALID_ARGUMENT;
}
- if (!module_def->exported_functions() ||
- module_def->exported_functions()->size() == 0) {
- LOG(ERROR) << "At least one exported function is required.";
+ iree_vm_BytecodeModuleDef_table_t module_def =
+ iree_vm_BytecodeModuleDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_t name = iree_vm_BytecodeModuleDef_name(module_def);
+ if (!flatbuffers_string_len(name)) {
+ LOG(ERROR) << "module name missing";
return IREE_STATUS_INVALID_ARGUMENT;
}
- if (!module_def->internal_functions() ||
- module_def->internal_functions()->size() == 0) {
- LOG(ERROR) << "At least one internal function is required.";
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- if (!module_def->function_descriptors() ||
- module_def->function_descriptors()->size() !=
- module_def->internal_functions()->size()) {
- LOG(ERROR)
- << "All internal functions need a mapping into the bytecode data.";
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- if (!module_def->bytecode_data()) {
- LOG(ERROR) << "Bytecode data is required if we have any functions.";
- return IREE_STATUS_INVALID_ARGUMENT;
- }
-
- for (int i = 0; i < module_def->types()->size(); ++i) {
- const auto* type_def = module_def->types()->Get(i);
+ iree_vm_TypeDef_vec_t types = iree_vm_BytecodeModuleDef_types(module_def);
+ for (size_t i = 0; i < iree_vm_TypeDef_vec_len(types); ++i) {
+ iree_vm_TypeDef_table_t type_def = iree_vm_TypeDef_vec_at(types, i);
if (!type_def) {
- LOG(ERROR) << "All types must be valid.";
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!type_def->full_name() || type_def->full_name()->size() == 0) {
- LOG(ERROR) << "All types require a name.";
+ LOG(ERROR) << "type def missing body";
return IREE_STATUS_INVALID_ARGUMENT;
}
- if (!iree_vm_type_def_is_valid(
- iree_vm_bytecode_module_resolve_type(type_def))) {
- LOG(ERROR) << "No type registered with name '"
- << type_def->full_name()->c_str() << "'.";
+ flatbuffers_string_t full_name = iree_vm_TypeDef_full_name(type_def);
+ if (flatbuffers_string_len(full_name) <= 0) {
+ LOG(ERROR) << "type def missing full_name";
return IREE_STATUS_INVALID_ARGUMENT;
}
}
- if (module_def->imported_functions()) {
- for (int i = 0; i < module_def->imported_functions()->size(); ++i) {
- auto* import_def = module_def->imported_functions()->Get(i);
- if (!import_def) {
- LOG(ERROR) << "All imports must be valid.";
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!import_def->full_name() ||
- import_def->full_name()->size() == 0) {
- LOG(ERROR) << "All imports require a name.";
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!import_def->signature()) {
- LOG(ERROR) << "All imports require a signature.";
- return IREE_STATUS_INVALID_ARGUMENT;
- }
+ iree_vm_ImportFunctionDef_vec_t imported_functions =
+ iree_vm_BytecodeModuleDef_imported_functions(module_def);
+ iree_vm_ExportFunctionDef_vec_t exported_functions =
+ iree_vm_BytecodeModuleDef_exported_functions(module_def);
+ iree_vm_InternalFunctionDef_vec_t internal_functions =
+ iree_vm_BytecodeModuleDef_internal_functions(module_def);
+ iree_vm_FunctionDescriptor_vec_t function_descriptors =
+ iree_vm_BytecodeModuleDef_function_descriptors(module_def);
+
+ if (flatbuffers_vec_len(internal_functions) !=
+ flatbuffers_vec_len(function_descriptors)) {
+ LOG(ERROR)
+ << "mismatched internal_functions and function_descriptors vectors";
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ for (size_t i = 0; i < iree_vm_ImportFunctionDef_vec_len(imported_functions);
+ ++i) {
+ iree_vm_ImportFunctionDef_table_t import_def =
+ iree_vm_ImportFunctionDef_vec_at(imported_functions, i);
+ if (!import_def) {
+ LOG(ERROR) << "import def missing body";
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ flatbuffers_string_t full_name =
+ iree_vm_ImportFunctionDef_full_name(import_def);
+ if (!flatbuffers_string_len(full_name)) {
+ LOG(ERROR) << "import def missing full_name";
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ if (!iree_vm_ImportFunctionDef_signature(import_def)) {
+ LOG(ERROR) << "import def missing a function signature";
+ return IREE_STATUS_INVALID_ARGUMENT;
}
}
- for (int i = 0; i < module_def->exported_functions()->size(); ++i) {
- auto* export_def = module_def->exported_functions()->Get(i);
+ for (size_t i = 0; i < iree_vm_ExportFunctionDef_vec_len(exported_functions);
+ ++i) {
+ iree_vm_ExportFunctionDef_table_t export_def =
+ iree_vm_ExportFunctionDef_vec_at(exported_functions, i);
if (!export_def) {
- LOG(ERROR) << "All exports must be valid.";
+ LOG(ERROR) << "export def missing body";
return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!export_def->local_name() ||
- export_def->local_name()->size() == 0) {
- LOG(ERROR) << "All exports require a name.";
+ }
+ flatbuffers_string_t local_name =
+ iree_vm_ExportFunctionDef_local_name(export_def);
+ if (!flatbuffers_string_len(local_name)) {
+ LOG(ERROR) << "export def missing local_name";
return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!export_def->signature()) {
- LOG(ERROR) << "All exports require a signature.";
+ }
+ if (!iree_vm_ExportFunctionDef_signature(export_def)) {
+ LOG(ERROR) << "export def missing a function signature";
return IREE_STATUS_INVALID_ARGUMENT;
- } else if (export_def->internal_ordinal() < 0 ||
- export_def->internal_ordinal() >=
- module_def->internal_functions()->size()) {
- LOG(ERROR)
- << "Out-of-bounds reference to a function in the internal table.";
+ }
+ int32_t internal_ordinal =
+ iree_vm_ExportFunctionDef_internal_ordinal(export_def);
+ if (internal_ordinal < 0 ||
+ internal_ordinal >=
+ iree_vm_InternalFunctionDef_vec_len(internal_functions)) {
+ LOG(ERROR) << "export def internal_ordinal out of bounds";
return IREE_STATUS_INVALID_ARGUMENT;
}
}
- for (int i = 0; i < module_def->internal_functions()->size(); ++i) {
- auto* function_def = module_def->internal_functions()->Get(i);
+ flatbuffers_uint8_vec_t bytecode_data =
+ iree_vm_BytecodeModuleDef_bytecode_data(module_def);
+ for (size_t i = 0;
+ i < iree_vm_InternalFunctionDef_vec_len(internal_functions); ++i) {
+ iree_vm_InternalFunctionDef_table_t function_def =
+ iree_vm_InternalFunctionDef_vec_at(internal_functions, i);
if (!function_def) {
- LOG(ERROR) << "All functions must be valid.";
+ LOG(ERROR) << "function def missing body";
return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!function_def->signature()) {
- LOG(ERROR) << "All functions require a signature.";
+ }
+ if (!iree_vm_InternalFunctionDef_signature(function_def)) {
+ LOG(ERROR) << "function def missing signature";
return IREE_STATUS_INVALID_ARGUMENT;
}
- const auto* function_descriptor =
- module_def->function_descriptors()->Get(i);
- if (!function_descriptor || function_descriptor->bytecode_offset() < 0 ||
- function_descriptor->bytecode_length() < 0 ||
- function_descriptor->bytecode_offset() +
- function_descriptor->bytecode_length() >
- module_def->bytecode_data()->size()) {
- LOG(ERROR) << "Bytecode span must be a valid range.";
+ iree_vm_FunctionDescriptor_struct_t function_descriptor =
+ iree_vm_FunctionDescriptor_vec_at(function_descriptors, i);
+ if (function_descriptor->bytecode_offset < 0 ||
+ function_descriptor->bytecode_offset +
+ function_descriptor->bytecode_length >
+ flatbuffers_uint8_vec_len(bytecode_data)) {
+ LOG(ERROR) << "function descriptor bytecode span out of range";
return IREE_STATUS_INVALID_ARGUMENT;
}
- if (function_descriptor->i32_register_count() > IREE_I32_REGISTER_COUNT ||
- function_descriptor->ref_register_count() > IREE_REF_REGISTER_COUNT) {
- LOG(ERROR) << "Register counts out of range.";
+ if (function_descriptor->i32_register_count > IREE_I32_REGISTER_COUNT ||
+ function_descriptor->ref_register_count > IREE_REF_REGISTER_COUNT) {
+ LOG(ERROR) << "function descriptor register out of range";
return IREE_STATUS_INVALID_ARGUMENT;
}
@@ -228,22 +249,21 @@
static iree_string_view_t iree_vm_bytecode_module_name(void* self) {
iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
- auto* module_def = IREE_VM_GET_MODULE_DEF(module);
- return iree_string_view_t{module_def->name()->data(),
- module_def->name()->size()};
+ flatbuffers_string_t name = iree_vm_BytecodeModuleDef_name(module->def);
+ return iree_string_view_t{name, flatbuffers_string_len(name)};
}
static iree_vm_module_signature_t iree_vm_bytecode_module_signature(
void* self) {
iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
- auto* module_def = IREE_VM_GET_MODULE_DEF(module);
iree_vm_module_signature_t signature;
- signature.import_function_count =
- module_def->imported_functions()
- ? module_def->imported_functions()->size()
- : 0;
- signature.export_function_count = module_def->exported_functions()->size();
- signature.internal_function_count = module_def->internal_functions()->size();
+ memset(&signature, 0, sizeof(signature));
+ signature.import_function_count = iree_vm_ImportFunctionDef_vec_len(
+ iree_vm_BytecodeModuleDef_imported_functions(module->def));
+ signature.export_function_count = iree_vm_ExportFunctionDef_vec_len(
+ iree_vm_BytecodeModuleDef_exported_functions(module->def));
+ signature.internal_function_count = iree_vm_InternalFunctionDef_vec_len(
+ iree_vm_BytecodeModuleDef_internal_functions(module->def));
return signature;
}
@@ -252,66 +272,75 @@
iree_vm_function_t* out_function, iree_string_view_t* out_name,
iree_vm_function_signature_t* out_signature) {
if (out_function) {
- memset(out_function, 0, sizeof(iree_vm_function_t));
+ memset(out_function, 0, sizeof(*out_function));
}
if (out_name) {
- out_name->data = NULL;
- out_name->size = 0;
+ memset(out_name, 0, sizeof(*out_name));
}
if (out_signature) {
- memset(out_signature, 0, sizeof(iree_vm_function_signature_t));
+ memset(out_signature, 0, sizeof(*out_signature));
}
iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
- auto* module_def = IREE_VM_GET_MODULE_DEF(module);
-
- const ::flatbuffers::String* name = nullptr;
- const iree::vm::FunctionSignatureDef* signature = nullptr;
+ flatbuffers_string_t name = NULL;
+ iree_vm_FunctionSignatureDef_table_t signature = NULL;
if (linkage == IREE_VM_FUNCTION_LINKAGE_IMPORT) {
- if (!module_def->imported_functions() || ordinal < 0 ||
- ordinal >= module_def->imported_functions()->size()) {
+ iree_vm_ImportFunctionDef_vec_t imported_functions =
+ iree_vm_BytecodeModuleDef_imported_functions(module->def);
+ if (ordinal < 0 ||
+ ordinal >= iree_vm_ImportFunctionDef_vec_len(imported_functions)) {
return IREE_STATUS_INVALID_ARGUMENT;
}
- auto* import_def = module_def->imported_functions()->Get(ordinal);
- name = import_def->full_name();
- signature = import_def->signature();
+ iree_vm_ImportFunctionDef_table_t import_def =
+ iree_vm_ImportFunctionDef_vec_at(imported_functions, ordinal);
+ name = iree_vm_ImportFunctionDef_full_name(import_def);
+ signature = iree_vm_ImportFunctionDef_signature(import_def);
if (out_function) {
out_function->module = &module->interface;
out_function->linkage = linkage;
out_function->ordinal = ordinal;
}
} else if (linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT) {
- if (ordinal < 0 || ordinal >= module_def->exported_functions()->size()) {
+ iree_vm_ExportFunctionDef_vec_t exported_functions =
+ iree_vm_BytecodeModuleDef_exported_functions(module->def);
+ if (ordinal < 0 ||
+ ordinal >= iree_vm_ExportFunctionDef_vec_len(exported_functions)) {
return IREE_STATUS_INVALID_ARGUMENT;
}
- auto* export_def = module_def->exported_functions()->Get(ordinal);
- name = export_def->local_name();
- signature = export_def->signature();
+ iree_vm_ExportFunctionDef_table_t export_def =
+ iree_vm_ExportFunctionDef_vec_at(exported_functions, ordinal);
+ name = iree_vm_ExportFunctionDef_local_name(export_def);
+ signature = iree_vm_ExportFunctionDef_signature(export_def);
if (out_function) {
out_function->module = &module->interface;
out_function->linkage = IREE_VM_FUNCTION_LINKAGE_INTERNAL;
- out_function->ordinal = export_def->internal_ordinal();
+ out_function->ordinal =
+ iree_vm_ExportFunctionDef_internal_ordinal(export_def);
- const iree_vm_function_descriptor_t* function_descriptor =
- &module->function_descriptor_table[export_def->internal_ordinal()];
+ const iree_vm_FunctionDescriptor_t* function_descriptor =
+ &module->function_descriptor_table[out_function->ordinal];
out_function->i32_register_count =
function_descriptor->i32_register_count;
out_function->ref_register_count =
function_descriptor->ref_register_count;
}
} else {
- if (ordinal < 0 || ordinal >= module_def->internal_functions()->size()) {
+ iree_vm_InternalFunctionDef_vec_t internal_functions =
+ iree_vm_BytecodeModuleDef_internal_functions(module->def);
+ if (ordinal < 0 ||
+ ordinal >= iree_vm_InternalFunctionDef_vec_len(internal_functions)) {
return IREE_STATUS_INVALID_ARGUMENT;
}
- auto* function_def = module_def->internal_functions()->Get(ordinal);
- name = function_def->local_name();
- signature = function_def->signature();
+ iree_vm_InternalFunctionDef_table_t function_def =
+ iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
+ name = iree_vm_InternalFunctionDef_local_name(function_def);
+ signature = iree_vm_InternalFunctionDef_signature(function_def);
if (out_function) {
out_function->module = &module->interface;
out_function->linkage = IREE_VM_FUNCTION_LINKAGE_INTERNAL;
out_function->ordinal = ordinal;
- const iree_vm_function_descriptor_t* function_descriptor =
+ const iree_vm_FunctionDescriptor_t* function_descriptor =
&module->function_descriptor_table[ordinal];
out_function->i32_register_count =
function_descriptor->i32_register_count;
@@ -321,14 +350,14 @@
}
if (out_name && name) {
- out_name->data = name->c_str();
- out_name->size = name->size();
+ out_name->data = name;
+ out_name->size = flatbuffers_string_len(name);
}
if (out_signature && signature) {
- out_signature->argument_count =
- signature->argument_types() ? signature->argument_types()->size() : 0;
- out_signature->result_count =
- signature->result_types() ? signature->result_types()->size() : 0;
+ out_signature->argument_count = flatbuffers_int32_vec_len(
+ iree_vm_FunctionSignatureDef_argument_types(signature));
+ out_signature->result_count = flatbuffers_int32_vec_len(
+ iree_vm_FunctionSignatureDef_result_types(signature));
}
return IREE_STATUS_OK;
@@ -337,9 +366,6 @@
static iree_status_t iree_vm_bytecode_module_get_function_reflection_attr(
void* self, iree_vm_function_linkage_t linkage, int32_t ordinal,
int32_t index, iree_string_view_t* key, iree_string_view_t* value) {
- iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
- auto* module_def = IREE_VM_GET_MODULE_DEF(module);
-
if (linkage != IREE_VM_FUNCTION_LINKAGE_INTERNAL) {
iree_vm_function_t internal_function;
iree_vm_bytecode_module_get_function(self, linkage, ordinal,
@@ -348,83 +374,98 @@
ordinal = internal_function.ordinal;
}
- if (ordinal < 0 || ordinal >= module_def->internal_functions()->size()) {
+ iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
+ iree_vm_InternalFunctionDef_vec_t internal_functions =
+ iree_vm_BytecodeModuleDef_internal_functions(module->def);
+
+ if (ordinal < 0 ||
+ ordinal >= iree_vm_InternalFunctionDef_vec_len(internal_functions)) {
return IREE_STATUS_INVALID_ARGUMENT;
}
- auto* export_def = module_def->internal_functions()->Get(ordinal);
- const iree::vm::FunctionSignatureDef* signature = export_def->signature();
- if (!signature->reflection_attrs() || index < 0 ||
- index >= signature->reflection_attrs()->size()) {
+ iree_vm_InternalFunctionDef_table_t function_def =
+ iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
+ iree_vm_FunctionSignatureDef_table_t signature =
+ iree_vm_InternalFunctionDef_signature(function_def);
+ iree_vm_ReflectionAttrDef_vec_t reflection_attrs =
+ iree_vm_FunctionSignatureDef_reflection_attrs(signature);
+ if (index < 0 ||
+ index >= iree_vm_ReflectionAttrDef_vec_len(reflection_attrs)) {
return IREE_STATUS_NOT_FOUND;
}
- const ::iree::vm::ReflectionAttrDef* attr =
- signature->reflection_attrs()->Get(index);
- const ::flatbuffers::String* attr_key = attr->key();
- const ::flatbuffers::String* attr_value = attr->value();
- if (!attr_key || !attr_value) {
+ iree_vm_ReflectionAttrDef_table_t attr =
+ iree_vm_ReflectionAttrDef_vec_at(reflection_attrs, index);
+ flatbuffers_string_t attr_key = iree_vm_ReflectionAttrDef_key(attr);
+ flatbuffers_string_t attr_value = iree_vm_ReflectionAttrDef_value(attr);
+ if (!flatbuffers_string_len(attr_key) ||
+ !flatbuffers_string_len(attr_value)) {
// Because reflection metadata should not impose any overhead for the
    // non-reflection case, we do not eagerly validate it on load -- instead
// verify it structurally as needed.
return IREE_STATUS_FAILED_PRECONDITION;
}
- key->data = attr_key->c_str();
- key->size = attr_key->size();
- value->data = attr_value->c_str();
- value->size = attr_value->size();
+ key->data = attr_key;
+ key->size = flatbuffers_string_len(attr_key);
+ value->data = attr_value;
+ value->size = flatbuffers_string_len(attr_value);
return IREE_STATUS_OK;
}
-static bool iree_vm_bytecode_module_compare_str(const flatbuffers::String* lhs,
- iree_string_view_t rhs) {
- if (!lhs || lhs->size() != rhs.size) return false;
- return strncmp(lhs->c_str(), rhs.data, rhs.size) == 0;
-}
-
static iree_status_t iree_vm_bytecode_module_lookup_function(
void* self, iree_vm_function_linkage_t linkage, iree_string_view_t name,
iree_vm_function_t* out_function) {
if (!out_function) return IREE_STATUS_INVALID_ARGUMENT;
memset(out_function, 0, sizeof(iree_vm_function_t));
- if (!name.data || !name.size) return IREE_STATUS_INVALID_ARGUMENT;
-
- iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
- auto* module_def = IREE_VM_GET_MODULE_DEF(module);
+ if (iree_string_view_is_empty(name)) return IREE_STATUS_INVALID_ARGUMENT;
// NOTE: we could organize imports/exports alphabetically so we could bsearch.
+ iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
if (linkage == IREE_VM_FUNCTION_LINKAGE_IMPORT) {
- if (!module_def->imported_functions()) {
- return IREE_STATUS_NOT_FOUND;
- }
- for (int ordinal = 0; ordinal < module_def->imported_functions()->size();
+ iree_vm_ImportFunctionDef_vec_t imported_functions =
+ iree_vm_BytecodeModuleDef_imported_functions(module->def);
+ for (size_t ordinal = 0;
+ ordinal < iree_vm_ImportFunctionDef_vec_len(imported_functions);
++ordinal) {
- auto* import_def = module_def->imported_functions()->Get(ordinal);
- if (iree_vm_bytecode_module_compare_str(import_def->full_name(), name)) {
+ iree_vm_ImportFunctionDef_table_t import_def =
+ iree_vm_ImportFunctionDef_vec_at(imported_functions, ordinal);
+ if (iree_vm_flatbuffer_strcmp(
+ iree_vm_ImportFunctionDef_full_name(import_def), name) == 0) {
return iree_vm_bytecode_module_get_function(self, linkage, ordinal,
out_function, NULL, NULL);
}
}
return IREE_STATUS_NOT_FOUND;
} else if (linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT) {
- for (int ordinal = 0; ordinal < module_def->exported_functions()->size();
+ iree_vm_ExportFunctionDef_vec_t exported_functions =
+ iree_vm_BytecodeModuleDef_exported_functions(module->def);
+ for (size_t ordinal = 0;
+         ordinal < iree_vm_ExportFunctionDef_vec_len(exported_functions);
++ordinal) {
- auto* export_def = module_def->exported_functions()->Get(ordinal);
- if (iree_vm_bytecode_module_compare_str(export_def->local_name(), name)) {
+ iree_vm_ExportFunctionDef_table_t export_def =
+ iree_vm_ExportFunctionDef_vec_at(exported_functions, ordinal);
+ if (iree_vm_flatbuffer_strcmp(
+ iree_vm_ExportFunctionDef_local_name(export_def), name) == 0) {
return iree_vm_bytecode_module_get_function(
self, IREE_VM_FUNCTION_LINKAGE_INTERNAL,
- export_def->internal_ordinal(), out_function, NULL, NULL);
+ iree_vm_ExportFunctionDef_internal_ordinal(export_def),
+ out_function, NULL, NULL);
}
}
return IREE_STATUS_NOT_FOUND;
} else {
- for (int ordinal = 0; ordinal < module_def->internal_functions()->size();
+ iree_vm_InternalFunctionDef_vec_t internal_functions =
+ iree_vm_BytecodeModuleDef_internal_functions(module->def);
+ for (size_t ordinal = 0;
+ ordinal < iree_vm_InternalFunctionDef_vec_len(internal_functions);
++ordinal) {
- auto* function_def = module_def->internal_functions()->Get(ordinal);
- if (iree_vm_bytecode_module_compare_str(function_def->local_name(),
- name)) {
+ iree_vm_InternalFunctionDef_table_t function_def =
+ iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
+ if (iree_vm_flatbuffer_strcmp(
+ iree_vm_InternalFunctionDef_local_name(function_def), name) ==
+ 0) {
return iree_vm_bytecode_module_get_function(
self, IREE_VM_FUNCTION_LINKAGE_INTERNAL, ordinal, out_function,
NULL, NULL);
@@ -438,20 +479,21 @@
// Returns the total size of the structure and all tables with padding applied.
// |state| may be null if only the structure size is required for allocation.
static iree_host_size_t iree_vm_bytecode_module_layout_state(
- const iree::vm::BytecodeModuleDef* module_def,
+ iree_vm_BytecodeModuleDef_table_t module_def,
iree_vm_bytecode_module_state_t* state) {
- int rwdata_storage_capacity =
- module_def->module_state()
- ? module_def->module_state()->global_bytes_capacity()
- : 0;
- int global_ref_count = module_def->module_state()
- ? module_def->module_state()->global_ref_count()
- : 0;
- int rodata_ref_count =
- module_def->rodata_segments() ? module_def->rodata_segments()->size() : 0;
- int import_function_count = module_def->imported_functions()
- ? module_def->imported_functions()->size()
- : 0;
+ iree_vm_ModuleStateDef_table_t module_state =
+ iree_vm_BytecodeModuleDef_module_state(module_def);
+ int rwdata_storage_capacity = 0;
+ int global_ref_count = 0;
+ if (module_state) {
+ rwdata_storage_capacity =
+ iree_vm_ModuleStateDef_global_bytes_capacity(module_state);
+ global_ref_count = iree_vm_ModuleStateDef_global_ref_count(module_state);
+ }
+ int rodata_ref_count = iree_vm_RodataSegmentDef_vec_len(
+ iree_vm_BytecodeModuleDef_rodata_segments(module_def));
+ int import_function_count = iree_vm_ImportFunctionDef_vec_len(
+ iree_vm_BytecodeModuleDef_imported_functions(module_def));
uint8_t* base_ptr = (uint8_t*)state;
iree_host_size_t offset =
@@ -491,7 +533,7 @@
*out_module_state = NULL;
iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
- auto* module_def = IREE_VM_GET_MODULE_DEF(module);
+ iree_vm_BytecodeModuleDef_table_t module_def = module->def;
// Compute the total size required (with padding) for the state structure.
iree_host_size_t total_state_struct_size =
@@ -507,13 +549,16 @@
iree_vm_bytecode_module_layout_state(module_def, state);
// Setup rodata segments to point directly at the flatbuffer memory.
+ iree_vm_RodataSegmentDef_vec_t rodata_segments =
+ iree_vm_BytecodeModuleDef_rodata_segments(module_def);
for (int i = 0; i < state->rodata_ref_count; ++i) {
- const iree::vm::RodataSegmentDef* segment =
- module_def->rodata_segments()->Get(i);
+ iree_vm_RodataSegmentDef_table_t segment =
+ iree_vm_RodataSegmentDef_vec_at(rodata_segments, i);
iree_vm_ro_byte_buffer_t* ref = &state->rodata_ref_table[i];
iree_atomic_store(&ref->ref_object.counter, 1);
- ref->data.data = segment->data()->Data();
- ref->data.data_length = segment->data()->size();
+ ref->data.data = iree_vm_RodataSegmentDef_data(segment);
+ ref->data.data_length =
+ flatbuffers_uint8_vec_len(iree_vm_RodataSegmentDef_data(segment));
}
*out_module_state = (iree_vm_module_state_t*)state;
@@ -594,31 +639,24 @@
iree_const_byte_span_t flatbuffer_data,
iree_allocator_t flatbuffer_allocator, iree_allocator_t allocator,
iree_vm_module_t** out_module) {
- if (!out_module) {
- LOG(ERROR) << "Output module argument not set";
- return IREE_STATUS_INVALID_ARGUMENT;
- }
+ if (!out_module) return IREE_STATUS_INVALID_ARGUMENT;
*out_module = NULL;
- if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
- LOG(ERROR) << "Flatbuffer data is not present or less than 16 bytes";
- return IREE_STATUS_INVALID_ARGUMENT;
- } else if (!iree::vm::BytecodeModuleDefBufferHasIdentifier(
- flatbuffer_data.data)) {
- LOG(ERROR) << "Flatbuffer data does not have bytecode module identifier";
- return IREE_STATUS_INVALID_ARGUMENT;
- }
+ IREE_RETURN_IF_ERROR(
+ iree_vm_bytecode_module_flatbuffer_verify(flatbuffer_data));
- const iree::vm::BytecodeModuleDef* module_def =
- ::flatbuffers::GetRoot<iree::vm::BytecodeModuleDef>(flatbuffer_data.data);
+ iree_vm_BytecodeModuleDef_table_t module_def =
+ iree_vm_BytecodeModuleDef_as_root(flatbuffer_data.data);
if (!module_def) {
- LOG(ERROR) << "Failed getting root from flatbuffer data";
+    LOG(ERROR) << "failed getting root from flatbuffer data; expected "
+                  "identifier " iree_vm_BytecodeModuleDef_file_identifier
+                  " was not found";
return IREE_STATUS_INVALID_ARGUMENT;
}
- IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_flatbuffer_verify(module_def));
+ iree_vm_TypeDef_vec_t type_defs = iree_vm_BytecodeModuleDef_types(module_def);
size_t type_table_size =
- module_def->types()->size() * sizeof(iree_vm_type_def_t);
+ iree_vm_TypeDef_vec_len(type_defs) * sizeof(iree_vm_type_def_t);
iree_vm_bytecode_module_t* module = NULL;
IREE_RETURN_IF_ERROR(iree_allocator_malloc(
@@ -626,21 +664,30 @@
(void**)&module));
module->allocator = allocator;
+ iree_vm_FunctionDescriptor_vec_t function_descriptors =
+ iree_vm_BytecodeModuleDef_function_descriptors(module_def);
module->function_descriptor_count =
- module_def->function_descriptors()->size();
- module->function_descriptor_table =
- (const iree_vm_function_descriptor_t*)module_def->function_descriptors()
- ->data();
+ iree_vm_FunctionDescriptor_vec_len(function_descriptors);
+ module->function_descriptor_table = function_descriptors;
+
+ flatbuffers_uint8_vec_t bytecode_data =
+ iree_vm_BytecodeModuleDef_bytecode_data(module_def);
module->bytecode_data = iree_const_byte_span_t{
- module_def->bytecode_data()->Data(), module_def->bytecode_data()->size()};
+ bytecode_data, flatbuffers_uint8_vec_len(bytecode_data)};
module->flatbuffer_data = flatbuffer_data;
module->flatbuffer_allocator = flatbuffer_allocator;
+ module->def = module_def;
- module->type_count = module_def->types()->size();
+ module->type_count = iree_vm_TypeDef_vec_len(type_defs);
module->type_table = (iree_vm_type_def_t*)((uint8_t*)module +
sizeof(iree_vm_bytecode_module_t));
- iree_vm_bytecode_module_resolve_types(module_def, module->type_table);
+ iree_status_t resolve_status =
+ iree_vm_bytecode_module_resolve_types(type_defs, module->type_table);
+ if (!iree_status_is_ok(resolve_status)) {
+ iree_allocator_free(allocator, module);
+ return resolve_status;
+ }
iree_vm_module_initialize(&module->interface, module);
module->interface.destroy = iree_vm_bytecode_module_destroy;
diff --git a/iree/vm/bytecode_module_impl.h b/iree/vm/bytecode_module_impl.h
index da53486..0f4e818 100644
--- a/iree/vm/bytecode_module_impl.h
+++ b/iree/vm/bytecode_module_impl.h
@@ -16,6 +16,16 @@
#define IREE_VM_BYTECODE_MODULE_IMPL_H_
#include <stdint.h>
+#include <string.h>
+
+#ifdef IREE_PLATFORM_ANDROID
+#include <stdalign.h>
+#else
+// TODO(benvanik): figure out how to make MSVC happy with C11 stdalign.h.
+#ifdef __cplusplus
+#include <cstdalign>
+#endif // __cplusplus
+#endif
#include "iree/base/api.h"
#include "iree/vm/builtin_types.h"
@@ -25,18 +35,15 @@
#include "iree/vm/type_def.h"
#include "iree/vm/value.h"
+// NOTE: include order matters:
+#include "flatcc/reflection/flatbuffers_common_reader.h"
+#include "iree/schemas/bytecode_module_def_reader.h"
+#include "iree/schemas/bytecode_module_def_verifier.h"
+
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
-// Matches the FunctionDescriptor struct in the flatbuffer.
-typedef struct {
- int32_t bytecode_offset;
- int32_t bytecode_length;
- uint16_t i32_register_count;
- uint16_t ref_register_count;
-} iree_vm_function_descriptor_t;
-
// A loaded bytecode module.
typedef struct {
// Interface routing to the bytecode module functions.
@@ -48,7 +55,8 @@
// Mapped 1:1 with internal functions. Each defined bytecode span represents a
// range of bytes in |bytecode_data|.
int32_t function_descriptor_count;
- const iree_vm_function_descriptor_t* function_descriptor_table;
+ const iree_vm_FunctionDescriptor_t* function_descriptor_table;
+
// A pointer to the bytecode data embedded within the module.
iree_const_byte_span_t bytecode_data;
@@ -58,6 +66,7 @@
// Underlying FlatBuffer data and allocator (which may be null).
iree_const_byte_span_t flatbuffer_data;
iree_allocator_t flatbuffer_allocator;
+ iree_vm_BytecodeModuleDef_table_t def;
// Type table mapping module type IDs to registered VM types.
int32_t type_count;
diff --git a/packaging/python/README.md b/packaging/python/README.md
index 6fcae98..3d387a2 100644
--- a/packaging/python/README.md
+++ b/packaging/python/README.md
@@ -55,6 +55,11 @@
functional) version without TensorFlow kernels. This should not be done for
released binaries but can help while developing.
+Note that bazel does not always produce properly named artifacts. See the tool
+`hack_python_package_from_runfiles.py`, which extracts and fixes up artifacts
+from a bazel-bin directory. When using this mechanism, set the environment
+variable `PYIREE_PYTHON_ROOT` to a suitable temp directory.
+
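For example, a possible flow from `$IREE_SRC` once the required bazel targets have
been built (a minimal sketch, not part of the patch: the temp directory is
illustrative, the explicit `bazel-bin` argument is optional, and the script expects
the runfiles of `//packaging/python:all_pyiree_packages` to exist under `bazel-bin`):

```shell
# Extract the bazel-built packages into a scratch directory and point the
# packaging scripts at it via PYIREE_PYTHON_ROOT.
export PYIREE_PYTHON_ROOT="$(mktemp -d)"
python3 packaging/python/hack_python_package_from_runfiles.py \
    "$PYIREE_PYTHON_ROOT" bazel-bin
```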
```shell
cd $IREE_SRC
bazel build -c opt \
diff --git a/packaging/python/common_setup.py b/packaging/python/common_setup.py
index 149dbdf..89d2639 100644
--- a/packaging/python/common_setup.py
+++ b/packaging/python/common_setup.py
@@ -28,6 +28,11 @@
def get_package_dir(prefix=("bindings", "python")):
+ explicit_root = os.environ.get("PYIREE_PYTHON_ROOT")
+ if explicit_root:
+ return explicit_root
+
+ # Use env variables based on build system type.
cmake_build_root = os.environ.get("PYIREE_CMAKE_BUILD_ROOT")
bazel_build_root = os.environ.get("PYIREE_BAZEL_BUILD_ROOT")
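The `PYIREE_PYTHON_ROOT` override above takes precedence over the build-root
detection that follows it. A minimal shell sketch of that behavior, not part of
the patch, assuming it is run from `packaging/python` and that `common_setup`
imports cleanly outside of a packaging build (the temp path is an example):

```shell
# With the override set, get_package_dir() returns it directly, ignoring
# PYIREE_CMAKE_BUILD_ROOT / PYIREE_BAZEL_BUILD_ROOT.
cd packaging/python
PYIREE_PYTHON_ROOT=/tmp/pyiree_pkg python3 -c \
    "import common_setup; print(common_setup.get_package_dir())"
# Prints: /tmp/pyiree_pkg
```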
diff --git a/packaging/python/hack_python_package_from_runfiles.py b/packaging/python/hack_python_package_from_runfiles.py
new file mode 100644
index 0000000..61c2d6e
--- /dev/null
+++ b/packaging/python/hack_python_package_from_runfiles.py
@@ -0,0 +1,102 @@
+#!/usr/bin/python
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Given a runfiles directory from a bazel build, does surgery to extract
+# a usable python package directory. In addition to the bazel directory
+# structure being unnecessarily obtuse, it is also really hard to actually
+# name files correctly. This affects python extension modules which must be
+# named with a specific extension suffix. Bazel is extremely inflexible and
+# we patch around it with this script. For the record, there are various ways
+# to write custom rules to do this more natively, but they all add needless
+# complexity. We opt for a script that is at least readable by mere mortals
+# and kept in one place.
+# Usage:
+# ./this_script <dest_dir> <path to bazel-bin>
+
+import os
+import platform
+import shutil
+import sys
+import sysconfig
+
+FILE_NAME_MAP = {
+ "binding.so": "binding{}".format(sysconfig.get_config_var("EXT_SUFFIX")),
+ "binding.pyd": False,
+ "binding.dylib": False,
+}
+
+
+def get_exe_suffix():
+ if platform.system() == "Windows":
+ return ".exe"
+ else:
+ return ""
+
+
+def copy_prefix(dest_dir, runfiles_dir, prefix):
+ # And finally seek into the corresponding path in the runfiles dir.
+ # Aren't bazel paths fun???
+ # Note that the "iree_core" path segment corresponds to the workspace name.
+ pkg_dir = os.path.join(runfiles_dir, "iree_core", *prefix)
+ if not os.path.exists(pkg_dir):
+ return
+ dest_dir = os.path.join(dest_dir)
+ for root, dirs, files in os.walk(pkg_dir):
+ assert root.startswith(pkg_dir)
+ dest_prefix = root[len(pkg_dir):]
+ if dest_prefix.startswith(os.path.sep):
+ dest_prefix = dest_prefix[1:]
+ local_dest_dir = os.path.join(dest_dir, dest_prefix)
+ os.makedirs(local_dest_dir, exist_ok=True)
+ for file in files:
+ copy_file(os.path.join(root, file), local_dest_dir)
+
+
+def copy_file(src_file, dst_dir):
+ basename = os.path.basename(src_file)
+ dst_file = os.path.join(dst_dir, basename)
+ mapped_name = FILE_NAME_MAP.get(basename)
+ if mapped_name is False:
+ # Skip.
+ return
+ elif mapped_name is not None:
+ dst_file = os.path.join(dst_dir, mapped_name)
+ shutil.copyfile(src_file, dst_file, follow_symlinks=True)
+
+
+def main():
+ # Parse args.
+ dest_dir = sys.argv[1]
+ bazel_bin = sys.argv[2] if len(sys.argv) > 2 else os.path.join(
+ os.path.dirname(__file__), "..", "..", "bazel-bin")
+
+ # Find the path to the runfiles of the built target:
+  #   //packaging/python:all_pyiree_packages
+ runfiles_dir = os.path.join(
+ bazel_bin, "packaging", "python",
+ "all_pyiree_packages%s.runfiles" % (get_exe_suffix(),))
+ if not os.path.isdir(runfiles_dir):
+ print("ERROR: Could not find build target 'all_pyiree_packages':",
+ runfiles_dir)
+ print("Make sure to build target", "//packaging/python:all_pyiree_packages")
+ sys.exit(1)
+
+ copy_prefix(dest_dir, runfiles_dir, ("bindings", "python"))
+ copy_prefix(dest_dir, runfiles_dir,
+ ("integrations", "tensorflow", "bindings", "python"))
+
+
+if __name__ == "__main__":
+ main()