Merge pull request #3594 from rsuderman:main-to-google
PiperOrigin-RevId: 338738884
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index d5d5bb0..7e91cc7 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -50,8 +50,8 @@
uses: actions/checkout@v2
- name: Setting up python
uses: actions/setup-python@v2
- # We have to explicitly fetch the base branch as well
- name: Fetching Base Branch
+ # We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Install yapf
run: python3 -m pip install yapf
@@ -65,6 +65,26 @@
printf "You can fix the lint errors above by running\n"
printf " git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i\n"
+ pytype:
+ runs-on: ubuntu-18.04
+ env:
+ BASE_REF: ${{ github.base_ref }}
+ steps:
+ - name: Checking out repository
+ uses: actions/checkout@v2
+ - name: Setting up python
+ uses: actions/setup-python@v2
+ with:
+ # Pytype does not support python3.9, which this action defaults to.
+ python-version: '3.8'
+ - name: Fetching Base Branch
+ # We have to explicitly fetch the base branch as well
+ run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
+ - name: Install pytype
+ run: python3 -m pip install pytype
+ - name: Run pytype on changed files
+ run: ./build_tools/pytype/check_diff.sh "${BASE_REF?}"
+
clang-format:
runs-on: ubuntu-18.04
env:
@@ -79,8 +99,8 @@
chmod +x /tmp/git-clang-format
- name: Checking out repository
uses: actions/checkout@v2
- # We have to explicitly fetch the base branch as well
- name: Fetching Base Branch
+ # We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Running clang-format on changed source files
run: |
@@ -102,8 +122,8 @@
steps:
- name: Checking out repository
uses: actions/checkout@v2
- # We have to explicitly fetch the base branch as well
- name: Fetching Base Branch
+ # We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Checking tabs
run: ./scripts/check_tabs.sh "${BASE_REF?}"
diff --git a/.gitignore b/.gitignore
index 96a5b66..b0a8cbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
# Python
*.pyc
**/.ipynb_checkpoints/
+.pytype/
# Visual Studio files
.vs/
diff --git a/build_tools/pytype/check_diff.sh b/build_tools/pytype/check_diff.sh
new file mode 100755
index 0000000..2b9851a
--- /dev/null
+++ b/build_tools/pytype/check_diff.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Uses git diff to run pytype on changed files.
+# Example Usage:
+# Defaults to comparing against 'main'.
+# ./build_tools/pytype/check_diff.sh
+# A specific branch can be specified.
+# ./build_tools/pytype/check_diff.sh google
+# Or all python files outside of './third_party/' can be checked.
+# ./build_tools/pytype/check_diff.sh all
+
+DIFF_TARGET="${1:-main}"
+echo "Running pycheck against '${DIFF_TARGET?}'"
+
+if [[ "${DIFF_TARGET?}" = "all" ]]; then
+ FILES=$(find -name "*\.py" -not -path "./third_party/*")
+else
+ FILES=$(git diff --name-only "${DIFF_TARGET?}" | grep '.*\.py')
+fi
+
+
+# We separate the Python files into multiple pytype calls because otherwise
+# Ninja gets confused. See https://github.com/google/pytype/issues/198
+BASE=$(echo "${FILES?}" | grep -vP '^(\./)?integrations/*')
+IREE_TF=$(echo "${FILES?}" | \
+ grep -P '^(\./)?integrations/tensorflow/bindings/python/pyiree/tf/.*')
+IREE_XLA=$(echo "${FILES?}" | \
+ grep -P '^(\./)?integrations/tensorflow/bindings/python/pyiree/xla/.*')
+COMPILER=$(echo "${FILES?}" | \
+ grep -P '^(\./)?integrations/tensorflow/compiler/.*')
+E2E=$(echo "${FILES?}" | grep -P '^(\./)?integrations/tensorflow/e2e/.*')
+
+function check_files() {
+ # $1: previous return code
+ # $2...: files to check
+ if [[ -z "${@:2}" ]]; then
+ echo "No files to check."
+ echo
+ return "${1?}"
+ fi
+
+ # We disable import-error because pytype doesn't have access to bazel.
+  # We disable pyi-error because of the way the bindings' imports work.
+ echo "${@:2}" | \
+ xargs python3 -m pytype --disable=import-error,pyi-error -j $(nproc)
+ EXIT_CODE="$?"
+ echo
+ if [[ "${EXIT_CODE?}" -gt "${1?}" ]]; then
+ return "${EXIT_CODE?}"
+ else
+ return "${1?}"
+ fi
+}
+
+MAX_CODE=0
+
+echo "Checking .py files outside of integrations/"
+check_files "${MAX_CODE?}" "${BASE?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/bindings/python/pyiree/tf/.*"
+check_files "${MAX_CODE?}" "${IREE_TF?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/bindings/python/pyiree/xla/.*"
+check_files "${MAX_CODE?}" "${IREE_XLA?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/compiler/.*"
+check_files "${MAX_CODE?}" "${COMPILER?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/e2e/.*"
+check_files "${MAX_CODE?}" "${E2E?}"
+MAX_CODE="$?"
+
+
+if [[ "${MAX_CODE?}" -ne "0" ]]; then
+ echo "One or more pytype checks failed."
+ echo "You can view these errors locally by running"
+ echo " ./build_tools/pytype/check_diff.sh ${DIFF_TARGET?}"
+ exit "${MAX_CODE?}"
+fi
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index b09f95b..111ad46 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -249,7 +249,7 @@
"""Stores the inputs and outputs of a series of calls to a module."""
def __init__(self,
- module: tf_utils.CompiledModule,
+ module: Union[tf_utils.CompiledModule, None],
function: Union[Callable[["TracedModule"], None], None],
_load_dict: Dict[str, Any] = None):
"""Extracts metadata from module and function and initializes.
@@ -563,7 +563,7 @@
self._module = module
self._trace = trace
- def _trace_call(self, method: Callable[..., Any], method_name: str):
+ def _trace_call(self, method: tf_utils._FunctionWrapper, method_name: str):
"""Decorates a CompiledModule method to capture its inputs and outputs."""
def call(*args, **kwargs):
@@ -611,8 +611,8 @@
def compile_tf_module(
- module_class: Type[tf.Module], exported_names: Sequence[str] = ()
-) -> Callable[[Any], Any]:
+ module_class: Type[tf.Module],
+ exported_names: Sequence[str] = ()) -> Modules:
"""Compiles module_class to each backend that we test.
Args:
@@ -648,11 +648,10 @@
return _global_modules
-def compile_tf_signature_def_saved_model(saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str, exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str]):
+def compile_tf_signature_def_saved_model(
+ saved_model_dir: str, saved_model_tags: Set[str], module_name: str,
+ exported_name: str, input_names: Sequence[str],
+ output_names: Sequence[str]) -> Modules:
"""Compiles a SignatureDef SavedModel to each backend that we test.
Args:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 99573c5..c0f8473 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -87,10 +87,10 @@
def _setup_mlir_crash_reproducer(
- function: Callable[[Any], Any],
+ function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
artifacts_dir: str,
backend_id: str,
-) -> Callable[[Any], Any]:
+) -> Any: # Callable[Any, Any]
"""Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
@@ -253,6 +253,16 @@
exported_name, artifacts_dir)
+class _FunctionWrapper(object):
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+ """Dummy function to match _IreeFunctionWrapper's API."""
+ return ("",), ("",)
+
+
class CompiledModule(object):
"""Base class for the TF and IREE compiled modules."""
@@ -260,7 +270,7 @@
self,
module_name: str,
backend_info: "BackendInfo",
- compiled_paths: Dict[str, str],
+ compiled_paths: Union[Dict[str, str], None],
):
"""Shared base constructor – not useful on its own.
@@ -344,6 +354,9 @@
"""
raise NotImplementedError()
+ def __getattr__(self, attr: str) -> _FunctionWrapper:
+ raise NotImplementedError()
+
def iree_serializable(self):
return False
@@ -351,13 +364,6 @@
return False
-class _FunctionWrapper(object):
-
- def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
- """Dummy function to match _IreeFunctionWrapper's API."""
- return (), ()
-
-
class _IreeFunctionWrapper(_FunctionWrapper):
"""Wraps an IREE function, making it callable."""
@@ -681,7 +687,7 @@
instance = module_class()
functions = []
for name in exported_names:
- functions.append(instance.__getattribute__(name).get_concrete_function())
+ functions.append(getattr(instance, name).get_concrete_function())
return functions, exported_names
@@ -787,7 +793,8 @@
def __init__(self, interpreter: tf.lite.Interpreter):
self._interpreter = interpreter
- def __call__(self, *args, **kwargs) -> Tuple[Any]:
+ def __call__(self, *args,
+ **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
if len(args) and len(kwargs):
raise ValueError("Passing both args and kwargs is not supported by "
"_TfLiteFunctionWrapper")
@@ -823,13 +830,12 @@
outputs.append(value)
# Process them to match the output of the tf.Module.
- if not is_dict:
- outputs = tuple(outputs)
- if len(outputs) == 1:
- outputs = outputs[0]
+ if is_dict:
+ return dict(outputs)
else:
- outputs = dict(outputs)
- return outputs
+ if len(outputs) == 1:
+ return outputs[0]
+ return tuple(outputs)
class TfLiteCompiledModule(CompiledModule):
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index 838e202..fba7d6d 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -21,6 +21,7 @@
cc_library(
name = "tensorflow",
srcs = [
+ "CheckNoTF.cpp",
"LegalizeTF.cpp",
"Passes.cpp",
"PropagateResourceCasts.cpp",
diff --git a/integrations/tensorflow/compiler/CheckNoTF.cpp b/integrations/tensorflow/compiler/CheckNoTF.cpp
new file mode 100644
index 0000000..246f28a
--- /dev/null
+++ b/integrations/tensorflow/compiler/CheckNoTF.cpp
@@ -0,0 +1,93 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+class CheckNoTensorflow : public PassWrapper<CheckNoTensorflow, FunctionPass> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+ shape::ShapeDialect, StandardOpsDialect>();
+ }
+
+ CheckNoTensorflow() = default;
+ CheckNoTensorflow(const CheckNoTensorflow &) {}
+
+  /// Validates that no TensorFlow frontend ops are in the function.
+ void runOnFunction() override {
+ auto op = getFunction();
+ auto context = op.getContext();
+
+ Dialect *dialect = context->getLoadedDialect("tf");
+ DenseSet<Operation *> illegalOps;
+ op.walk([&](Operation *op) {
+ if (op->getDialect() == dialect) {
+ illegalOps.insert(op);
+ }
+ });
+
+ if (!illegalOps.empty()) {
+ emitLegalizationErrors(op, illegalOps);
+ return signalPassFailure();
+ }
+ }
+
+ // Emits debug information which includes the number of ops of each type which
+ // failed to legalize.
+ void emitLegalizationErrors(Operation *op,
+ const DenseSet<Operation *> &nonlegalizedOps) {
+ // Print op errors for each of the TensorFlow ops that still remain.
+ std::map<StringRef, int> opNameCounts;
+ for (Operation *nonlegalizedOp : nonlegalizedOps) {
+ StringRef opName = nonlegalizedOp->getName().getStringRef();
+ opNameCounts[opName]++;
+ nonlegalizedOp->emitOpError()
+ << ": unlegalized TensorFlow op still exists";
+ }
+
+ std::vector<std::string> errorMessages;
+ errorMessages.reserve(opNameCounts.size());
+ for (const auto &opInfo : opNameCounts) {
+ errorMessages.push_back(
+ llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second));
+ }
+ Location loc = op->getLoc();
+ emitError(loc) << "The following Tensorflow operations still remain: \n"
+ << llvm::join(errorMessages, "\n") << "\n";
+ }
+};
+
+static PassRegistration<CheckNoTensorflow> pass(
+ "iree-check-no-tf", "Check that no TensorFlow frontend ops remain");
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF() {
+ return std::make_unique<CheckNoTensorflow>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/integrations/tensorflow/compiler/Passes.cpp b/integrations/tensorflow/compiler/Passes.cpp
index b4b6c47..db2e9d9 100644
--- a/integrations/tensorflow/compiler/Passes.cpp
+++ b/integrations/tensorflow/compiler/Passes.cpp
@@ -64,6 +64,11 @@
// - It removes tf_saved_model.semantics from the module, which we can only
// do at the very end.
pm.addPass(createTFSavedModelLowerExportedFunctions());
+
+ ////////////////////////////////////////////////////////////////////////////
+  // Validate that all TensorFlow ops have been legalized away.
+ ////////////////////////////////////////////////////////////////////////////
+ pm.addPass(createCheckNoTF());
}
static mlir::PassPipelineRegistration<> pipeline(
diff --git a/integrations/tensorflow/compiler/Passes.h b/integrations/tensorflow/compiler/Passes.h
index e3016a3..8b3b219 100644
--- a/integrations/tensorflow/compiler/Passes.h
+++ b/integrations/tensorflow/compiler/Passes.h
@@ -40,6 +40,9 @@
// Push resource casts forward to better propagate resource related shapes.
std::unique_ptr<OperationPass<ModuleOp>> createPropagateResourceCasts();
+// Validates that no TensorFlow operations remain.
+std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF();
+
// Create a single pipeline that will run all the needed IREE-specific TF import
// passes in the right order.
void createIreeTfImportPipeline(OpPassManager &pm);
diff --git a/integrations/tensorflow/compiler/test/check-no-tf.mlir b/integrations/tensorflow/compiler/test/check-no-tf.mlir
new file mode 100644
index 0000000..a7c49e6
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/check-no-tf.mlir
@@ -0,0 +1,28 @@
+// RUN: iree-tf-opt %s -iree-check-no-tf -split-input-file -verify-diagnostics
+
+// CHECK-LABEL: func @f
+func @f() -> (tensor<i32>) {
+ // CHECK: [[VAL0:%.+]] = mhlo.constant dense<3>
+ %0 = mhlo.constant dense<3> : tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+
+// expected-error@+3 {{'tf.Const' op : unlegalized TensorFlow op still exists}}
+// expected-error@below {{The following Tensorflow operations still remain}}
+func @f() -> (tensor<i32>) {
+ %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+
+// expected-error@+4 {{'tf.Const' op : unlegalized TensorFlow op still exists}}
+// expected-error@+4 {{'tf.Add' op : unlegalized TensorFlow op still exists}}
+// expected-error@below {{The following Tensorflow operations still remain}}
+func @f(%arg0 : tensor<i32>) -> (tensor<i32>) {
+ %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+ %1 = "tf.Add"(%arg0, %0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+}
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
index 256ce89..aff5dd8 100644
--- a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -23,6 +23,7 @@
SAVED_MODEL_IMPORT_PASSES = [
"tf-executor-graph-pruning",
"tf-standard-pipeline",
+ "iree-xla-legalize-tf",
"iree-tf-import-pipeline",
"canonicalize",
]
@@ -114,8 +115,8 @@
# CHECK: attributes
# CHECK-SAME: iree.module.export
# CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I1!R1!"}
-# CHECK-DAG: [[CONST_2xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-# CHECK-DAG: [[CONST_3xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
+# CHECK-DAG: [[CONST_2xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]>
+# CHECK-DAG: [[CONST_3xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
# CHECK-DAG: flow.variable.store [[CONST_2xf32]], @v : tensor<2xf32>
# CHECK-DAG: flow.variable.store [[CONST_3xf32]], @v : tensor<3xf32>
# CHECK: FINISH_TEST
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index bc8a998..e919d3e 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -201,31 +201,11 @@
"tflite": ["mobile_bert_squad_test.py"],
"iree_vmla": ["mobile_bert_squad_test.py"],
"iree_llvmjit": ["mobile_bert_squad_test.py"],
- },
- reference_backend = "tf",
- tags = [
- "external",
- "guitar",
- "manual",
- "no-remote",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
-iree_e2e_test_suite(
- name = "mobile_bert_squad_tests_failing",
- size = "enormous",
- backends_to_srcs = {
"iree_vulkan": ["mobile_bert_squad_test.py"],
},
reference_backend = "tf",
tags = [
"external",
- "failing",
"guitar",
"manual",
"no-remote",
diff --git a/iree/base/BUILD b/iree/base/BUILD
index a1b46ba..1cc8c50 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -295,8 +295,10 @@
srcs = ["logging.cc"],
hdrs = ["logging.h"],
deps = [
+ ":tracing",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index 257618b..e6ec69a 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -345,8 +345,10 @@
SRCS
"logging.cc"
DEPS
+ ::tracing
absl::core_headers
absl::flags
+ absl::str_format
PUBLIC
)
diff --git a/iree/base/logging.cc b/iree/base/logging.cc
index 51ce60b..3b31a9a 100644
--- a/iree/base/logging.cc
+++ b/iree/base/logging.cc
@@ -17,6 +17,8 @@
#include <string>
#include "absl/flags/flag.h"
+#include "absl/strings/str_format.h"
+#include "iree/base/tracing.h"
ABSL_FLAG(int, iree_minloglevel, 0,
"Minimum logging level. 0 = INFO and above.");
@@ -88,6 +90,19 @@
// TODO(scotttodd): Include current system time
fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], file_name_, line_,
str().c_str());
+
+#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_LOG_MESSAGES
+ constexpr int kLevelColors[4] = {
+ IREE_TRACING_MESSAGE_LEVEL_INFO, // INFO
+ IREE_TRACING_MESSAGE_LEVEL_WARNING, // WARNING
+ IREE_TRACING_MESSAGE_LEVEL_ERROR, // ERROR
+ IREE_TRACING_MESSAGE_LEVEL_ERROR, // FATAL
+ };
+ std::string message =
+ absl::StrFormat("%s:%d] %s\n", file_name_, line_, str().c_str());
+ IREE_TRACE_MESSAGE_DYNAMIC_COLORED(kLevelColors[severity_], message.c_str(),
+ message.size());
+#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_LOG_MESSAGES
}
LogMessageFatal::LogMessageFatal(const char* file, int line)
diff --git a/iree/base/tracing.h b/iree/base/tracing.h
index a3226ad..98db68e 100644
--- a/iree/base/tracing.h
+++ b/iree/base/tracing.h
@@ -77,6 +77,10 @@
// Tracy UI.
#define IREE_TRACING_FEATURE_SLOW_LOCKS (1 << 5)
+// Forwards log messages to traces, which will be visible under "Messages" in
+// the Tracy UI.
+#define IREE_TRACING_FEATURE_LOG_MESSAGES (1 << 6)
+
#if !defined(IREE_TRACING_MAX_CALLSTACK_DEPTH)
// Tracing functions that capture stack traces will only capture up to N frames.
// The overhead for stack walking scales linearly with the number of frames
@@ -95,7 +99,7 @@
// overridden it with more specific settings.
//
// IREE_TRACING_MODE = 0: tracing disabled
-// IREE_TRACING_MODE = 1: instrumentation and basic statistics
+// IREE_TRACING_MODE = 1: instrumentation, log messages, and basic statistics
// IREE_TRACING_MODE = 2: same as 1 with added allocation tracking
// IREE_TRACING_MODE = 3: same as 2 with callstacks for allocations
// IREE_TRACING_MODE = 4: same as 3 with callstacks for all instrumentation
@@ -105,20 +109,23 @@
#undef IREE_TRACING_MAX_CALLSTACK_DEPTH
#define IREE_TRACING_MAX_CALLSTACK_DEPTH 0
#elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE == 2
-#define IREE_TRACING_FEATURES \
- (IREE_TRACING_FEATURE_INSTRUMENTATION | \
- IREE_TRACING_FEATURE_ALLOCATION_TRACKING)
-#elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE == 3
#define IREE_TRACING_FEATURES \
(IREE_TRACING_FEATURE_INSTRUMENTATION | \
IREE_TRACING_FEATURE_ALLOCATION_TRACKING | \
- IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS)
+ IREE_TRACING_FEATURE_LOG_MESSAGES)
+#elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE == 3
+#define IREE_TRACING_FEATURES \
+ (IREE_TRACING_FEATURE_INSTRUMENTATION | \
+ IREE_TRACING_FEATURE_ALLOCATION_TRACKING | \
+ IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS | \
+ IREE_TRACING_FEATURE_LOG_MESSAGES)
#elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE >= 4
#define IREE_TRACING_FEATURES \
(IREE_TRACING_FEATURE_INSTRUMENTATION | \
IREE_TRACING_FEATURE_INSTRUMENTATION_CALLSTACKS | \
IREE_TRACING_FEATURE_ALLOCATION_TRACKING | \
- IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS)
+ IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS | \
+ IREE_TRACING_FEATURE_LOG_MESSAGES)
#else
#define IREE_TRACING_FEATURES 0
#endif // IREE_TRACING_MODE
@@ -336,11 +343,21 @@
// The message text must be a compile-time string literal.
#define IREE_TRACE_MESSAGE(level, value_literal) \
___tracy_emit_messageLC(value_literal, IREE_TRACING_MESSAGE_LEVEL_##level, 0)
+// Logs a message with the given color to the trace.
+// Standard colors are defined as IREE_TRACING_MESSAGE_LEVEL_* values.
+// The message text must be a compile-time string literal.
+#define IREE_TRACE_MESSAGE_COLORED(color, value_literal) \
+ ___tracy_emit_messageLC(value_literal, color, 0)
// Logs a dynamically-allocated message at the given logging level to the trace.
// The string |value| will be copied into the trace buffer.
#define IREE_TRACE_MESSAGE_DYNAMIC(level, value, value_length) \
___tracy_emit_messageC(value, value_length, \
IREE_TRACING_MESSAGE_LEVEL_##level, 0)
+// Logs a dynamically-allocated message with the given color to the trace.
+// Standard colors are defined as IREE_TRACING_MESSAGE_LEVEL_* values.
+// The string |value| will be copied into the trace buffer.
+#define IREE_TRACE_MESSAGE_DYNAMIC_COLORED(color, value, value_length) \
+ ___tracy_emit_messageC(value, value_length, color, 0)
// Utilities:
#define IREE_TRACE_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
@@ -370,7 +387,9 @@
#define IREE_TRACE_FRAME_MARK_BEGIN_NAMED(name_literal)
#define IREE_TRACE_FRAME_MARK_END_NAMED(name_literal)
#define IREE_TRACE_MESSAGE(level, value_literal)
+#define IREE_TRACE_MESSAGE_COLORED(color, value_literal)
#define IREE_TRACE_MESSAGE_DYNAMIC(level, value, value_length)
+#define IREE_TRACE_MESSAGE_DYNAMIC_COLORED(color, value, value_length)
#endif // IREE_TRACING_FEATURE_INSTRUMENTATION
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 7a788b1..93b8809 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -130,11 +130,8 @@
spirv::GlobalVariableOp getOrInsertResourceVariable(Location loc, Type type,
unsigned set,
unsigned binding,
- StringRef funcName,
Block &block) {
- auto name =
- llvm::formatv("__resource_var_{0}_{1}_{2}__", set, binding, funcName)
- .str();
+ auto name = llvm::formatv("__resource_var_{0}_{1}__", set, binding).str();
for (auto varOp : block.getOps<spirv::GlobalVariableOp>()) {
if (varOp.sym_name() == name) return varOp;
}
@@ -371,10 +368,9 @@
SymbolTable::lookupNearestSymbolFrom(
phOp, phOp.getAttrOfType<SymbolRefAttr>("binding")));
- StringRef funcName = phOp.getParentOfType<spirv::FuncOp>().getName();
- spirv::GlobalVariableOp varOp = getOrInsertResourceVariable(
- phOp.getLoc(), convertedType, bindingOp.set(), bindingOp.binding(),
- funcName, *moduleOp.getBody());
+ spirv::GlobalVariableOp varOp =
+ getOrInsertResourceVariable(phOp.getLoc(), convertedType, bindingOp.set(),
+ bindingOp.binding(), *moduleOp.getBody());
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(phOp, varOp);
return success();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index bb6faf8..d751065 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -63,6 +63,11 @@
llvm::cl::desc(
"Enable use of vectorization in SPIR-V code generation pipeline"),
llvm::cl::init(false)};
+ Option<bool> useVectorPass{
+ *this, "use-vector-pass",
+ llvm::cl::desc("Enable use of Linalg vectorization in SPIR-V code "
+ "generation pipeline"),
+ llvm::cl::init(false)};
};
static void addLinalgToSPIRVPasses(OpPassManager &pm,
@@ -94,7 +99,9 @@
//===--------------------------------------------------------------------===//
pm.addPass(createSplitDispatchFunctionPass());
pm.addPass(createLinalgTileAndFusePass(options));
- pm.addPass(createLoadStoreVectorizationPass());
+ if (options.useVectorPass) {
+ pm.addPass(createLoadStoreVectorizationPass());
+ }
pm.addPass(createCanonicalizerPass());
//===--------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 502db6a..80805e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -29,6 +29,7 @@
SmallVector<int64_t, 3> tileSizes = {};
bool useWorkgroupMemory = false;
bool useVectorization = false;
+ bool useVectorPass = false;
};
/// Pass to initialize the function that computes the number of workgroups for
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index 7721f41..eb2a32e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -64,6 +64,7 @@
private:
void tileAndVectorizeLinalgCopy(FuncOp funcOp, MLIRContext *context);
+ void lowerVectorOps(FuncOp funcOp, MLIRContext *context);
};
// Common class for all vector to GPU patterns.
@@ -149,8 +150,7 @@
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
Value index = rewriter.create<AddIOp>(loc, ThreadIndex, indices.back());
indices.back() = index;
- rewriter.create<StoreOp>(op.getLoc(), operands[0], operands[1], indices);
- rewriter.eraseOp(op);
+ rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1], indices);
return success();
}
};
@@ -204,11 +204,141 @@
applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
}
+// Convert vector transfer_read to a load if possible. This is the case only if
+// the element type of the memref matches the element type we want to load.
+class VectorTransferReadToLoad
+ : public OpRewritePattern<vector::TransferReadOp> {
+ public:
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getVectorType().getNumElements() != 1 ||
+ op.getMemRefType().getElementType() !=
+ op.getVectorType().getElementType()) {
+ return failure();
+ }
+ auto loc = op.getLoc();
+ Value newOp = rewriter.create<LoadOp>(loc, op.memref(), op.indices());
+ newOp =
+ rewriter.create<vector::BroadcastOp>(loc, op.getVectorType(), newOp);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
+// Convert vector transfer_write to a store if possible. This is the case only
+// if the element type of the memref matches the element type we want to store.
+class VectorTransferWriteToStore
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ public:
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getVectorType().getNumElements() != 1 ||
+ op.getMemRefType().getElementType() !=
+ op.getVectorType().getElementType()) {
+ return failure();
+ }
+ auto loc = op.getLoc();
+ SmallVector<int64_t, 2> zero(op.getVectorType().getRank(), 0);
+ Value scalarValue =
+ rewriter.create<vector::ExtractOp>(loc, op.vector(), zero);
+ rewriter.create<StoreOp>(loc, scalarValue, op.memref(), op.indices());
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+// Lower vector contract to a single scalar or vector mulf+addf. Insert casts to
+// convert from 2D vector to 1D vector or scalar.
+class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
+ public:
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+ auto iteratorTypes = op.iterator_types().getValue();
+ if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
+ !isParallelIterator(iteratorTypes[1]) ||
+ !isReductionIterator(iteratorTypes[2]) ||
+ !isRowMajorMatmul(op.indexing_maps())) {
+ return failure();
+ }
+ if (op.getLhsType().getNumElements() != 1) return failure();
+ unsigned vecSize = op.getAccType().cast<VectorType>().getNumElements();
+ if (!(vecSize >= 1 && vecSize <= 4)) return failure();
+ auto loc = op.getLoc();
+ VectorType vecType = VectorType::get(
+ vecSize, op.getResultType().cast<VectorType>().getElementType());
+ std::array<int64_t, 2> zero = {0, 0};
+ Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero);
+ Value rhs, acc;
+ if (vecSize == 1) {
+ rhs = rewriter.create<vector::ExtractOp>(loc, op.rhs(), zero);
+ acc = rewriter.create<vector::ExtractOp>(loc, op.acc(), zero);
+ } else {
+ lhs = rewriter.create<vector::BroadcastOp>(loc, vecType, lhs);
+ rhs = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.rhs());
+ acc = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.acc());
+ }
+ Value newOp = rewriter.create<MulFOp>(loc, lhs, rhs);
+ newOp = rewriter.create<AddFOp>(loc, newOp, acc);
+ if (vecSize == 1)
+ newOp =
+ rewriter.create<vector::BroadcastOp>(loc, op.getResultType(), newOp);
+ else
+ newOp =
+ rewriter.create<vector::ShapeCastOp>(loc, op.getResultType(), newOp);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
+// Lower ExtractStridedSliceOp to an ExtractOp instruction that can be natively
+// converted to SPIR-V. Add a BroadcastOp to keep the type consistent; we expect
+// the Broadcast to be removed by canonicalization.
+class ExtractStridedLowering
+ : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+ public:
+ using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+    // Only handle cases extracting a degenerate vector so that we can generate
+ // an extractOp with scalar destination.
+ if (op.getResult().getType().cast<VectorType>().getNumElements() != 1)
+ return failure();
+ auto loc = op.getLoc();
+ SmallVector<int64_t, 4> offsets = llvm::to_vector<4>(
+ llvm::map_range(op.offsets().getAsRange<IntegerAttr>(),
+ [](IntegerAttr attr) { return attr.getInt(); }));
+ offsets.resize(op.getVectorType().getRank(), 0);
+ Value newOp = rewriter.create<vector::ExtractOp>(loc, op.vector(), offsets);
+ newOp = rewriter.create<vector::BroadcastOp>(loc, op.getResult().getType(),
+ newOp);
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
+// Lower vector ops to instructions that can be later converted to SPIR-V.
+void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp,
+ MLIRContext *context) {
+ OwningRewritePatternList patterns;
+ patterns.insert<VectorContractLowering, VectorTransferReadToLoad,
+ VectorTransferWriteToStore, ExtractStridedLowering>(context);
+ applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
void ConvertVectorToGPUPass::runOnOperation() {
MLIRContext *context = &getContext();
FuncOp funcOp = getOperation();
tileAndVectorizeLinalgCopy(funcOp, context);
+ lowerVectorOps(funcOp, context);
+
auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
OwningRewritePatternList patterns;
patterns.insert<UnaryAndBinaryOpPattern<AddFOp>, VectorTransferReadConversion,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index f6bbc40..ac83d70 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -23,12 +23,12 @@
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- // CHECK: spv.globalVariable @__resource_var_3_4_resource_variable__ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv.globalVariable @__resource_var_1_2_resource_variable__ bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.globalVariable @__resource_var_3_4__ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv.globalVariable @__resource_var_1_2__ bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: spv.func @resource_variable()
func @resource_variable() {
- // CHECK: spv._address_of @__resource_var_1_2_resource_variable__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK: spv._address_of @__resource_var_3_4_resource_variable__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv._address_of @__resource_var_1_2__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: spv._address_of @__resource_var_3_4__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
return
@@ -42,24 +42,6 @@
// -----
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>, vkspv.entry_point_schedule = ["ForwardPass_ex_dispatch_8_dispatch_0", "ForwardPass_ex_dispatch_8_dispatch_1"]} {
- // CHECK: spv.globalVariable @__resource_var_0_0_dispatch_0__ bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<1 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
- // CHECK: spv.globalVariable @__resource_var_0_0_dispatch_1__ bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<4 x f32, stride=4> [0])>, StorageBuffer>
- func @dispatch_1() attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @ForwardPass_ex_dispatch_8_dispatch_1__num_workgroups__} {
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<4xf32>
- return
- }
- func @dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @ForwardPass_ex_dispatch_8_dispatch_0__num_workgroups__} {
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<1xvector<4xf32>>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- }
-}
-
-// -----
-
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
index 30d94c5..99ec424 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
@@ -59,3 +59,75 @@
// CHECK: %[[LOAD:.+]] = vector.transfer_read %[[SVs]][%c0, %c0], %cst {{.*}} : memref<1x4xf32, {{.*}}>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[LOAD]], %[[SVd]][%[[C0]], %[[C0]]] {{.*}} : vector<1x4xf32>, memref<1x4xf32
}
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @transfer_ops(%arg0: memref<32x32xf32>, %arg1 : vector<1x1xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+ %c0 = constant 0 : index
+ %cst = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<32x32xf32>, vector<1x1xf32>
+ vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<1x1xf32>, memref<32x32xf32>
+ return %0 : vector<1x1xf32>
+ }
+ // CHECK-LABEL: func @transfer_ops
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>, %[[ARG1:.*]]: vector<1x1xf32>
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[LOAD:.*]] = load %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+ // CHECK: %[[B:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1x1xf32>
+ // CHECK: %[[EXT:.*]] = vector.extract %[[ARG1]][0, 0] : vector<1x1xf32>
+ // CHECK: store %[[EXT]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+ // CHECK: return %[[B]] : vector<1x1xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @contract_ops(%arg0 : vector<1x1xf32>, %arg1 : vector<1x4xf32>,
+ %arg2 : vector<1x4xf32>, %arg3 : vector<1x1xf32>,
+ %arg4 : vector<1x1xf32>) -> (vector<1x1xf32>, vector<1x4xf32>) attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+ %0 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg3, %arg4
+ : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+ %1 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2
+ : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+ return %0, %1 : vector<1x1xf32>, vector<1x4xf32>
+ }
+ // CHECK-LABEL: func @contract_ops
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<1x1xf32>, %[[ARG1:.*]]: vector<1x4xf32>, %[[ARG2:.*]]: vector<1x4xf32>, %[[ARG3:.*]]: vector<1x1xf32>, %[[ARG4:.*]]: vector<1x1xf32>)
+ // CHECK: %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[B:.*]] = vector.extract %[[ARG3]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[C:.*]] = vector.extract %[[ARG4]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[MUL:.*]] = mulf %[[A]], %[[B]] : f32
+ // CHECK: %[[ADD:.*]] = addf %[[MUL]], %[[C]] : f32
+ // CHECK: %[[R0:.*]] = vector.broadcast %[[ADD]] : f32 to vector<1x1xf32>
+ // CHECK: %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+ // CHECK: %[[VA:.*]] = vector.broadcast %[[A]] : f32 to vector<4xf32>
+ // CHECK: %[[VB:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[VC:.*]] = vector.shape_cast %[[ARG2]] : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[VMUL:.*]] = mulf %[[VA]], %[[VB]] : vector<4xf32>
+ // CHECK: %[[VADD:.*]] = addf %[[VMUL]], %[[VC]] : vector<4xf32>
+ // CHECK: %[[R1:.*]] = vector.shape_cast %[[VADD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[R0]], %[[R1]] : vector<1x1xf32>, vector<1x4xf32>
+}
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @extract(%arg0 : vector<1x4xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+ %0 = vector.extract_strided_slice %arg0
+ {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]}
+ : vector<1x4xf32> to vector<1x1xf32>
+ return %0 : vector<1x1xf32>
+ }
+ // CHECK-LABEL: func @extract
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>
+ // CHECK: %[[A:.*]] = vector.extract %[[ARG0]][0, 2] : vector<1x4xf32>
+ // CHECK: %[[B:.*]] = vector.broadcast %[[A]] : f32 to vector<1x1xf32>
+ // CHECK: return %[[B]] : vector<1x1xf32>
+}
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index ea4c70c..91de450c 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -48,6 +48,12 @@
// llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
// "IREE Vulkan/SPIR-V backend options");
+ static llvm::cl::opt<bool> clUseVectorPass(
+ "iree-spirv-use-vector-pass",
+ llvm::cl::desc(
+ "Enable use of Linalg vectorization in SPIR-V code generation"),
+ llvm::cl::init(false));
+
static llvm::cl::opt<bool> clUseWorkgroupMemory(
"iree-spirv-use-workgroup-memory",
llvm::cl::desc(
@@ -81,6 +87,7 @@
targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
clTileSizes.end());
targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
+ targetOptions.codegenOptions.useVectorPass = clUseVectorPass;
if (!clVulkanTargetEnv.empty()) {
targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
} else {
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 9686098..f4741ac 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -45,3 +45,15 @@
driver = "vulkan",
target_backend = "vulkan-spirv",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv_vulkan_vector",
+ srcs = [
+ "compare.mlir",
+ "log_plus_one.mlir",
+ "pw_add_multiwg.mlir",
+ ],
+ compiler_flags = ["-iree-spirv-use-vector-pass"],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 4260b07..d5bd481 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -41,3 +41,18 @@
COMPILER_FLAGS
"-iree-spirv-use-workgroup-memory"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv_vulkan_vector
+ SRCS
+ "compare.mlir"
+ "log_plus_one.mlir"
+ "pw_add_multiwg.mlir"
+ TARGET_BACKEND
+ vulkan-spirv
+ DRIVER
+ vulkan
+ COMPILER_FLAGS
+ "-iree-spirv-use-vector-pass"
+)
diff --git a/iree/test/e2e/vulkan_specific/compare.mlir b/iree/test/e2e/vulkan_specific/compare.mlir
new file mode 100644
index 0000000..099670a
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/compare.mlir
@@ -0,0 +1,164 @@
+func @compare_tensor() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_scalar() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i8() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i8>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i8>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i8>, tensor<i8>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i16() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i16>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i16>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i16>, tensor<i16>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i32() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_i64() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1> : tensor<i64>
+ %rhs = iree.unfoldable_constant dense<5> : tensor<i64>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_f32() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<f32>
+ %rhs = iree.unfoldable_constant dense<5.0> : tensor<f32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_f64() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<f64>
+ %rhs = iree.unfoldable_constant dense<5.0> : tensor<f64>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f64>, tensor<f64>) -> tensor<i1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+ check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+ return
+}
+
+func @compare_tensor_odd_length() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7]> : tensor<3xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3]> : tensor<3xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<3xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<3xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<3xi1>, tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 0]> : tensor<3xi8>) : tensor<3xi8>
+ return
+}
+
+func @compare_eq() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_ne() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[1, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_lt() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[1, 0, 0, 0]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_le() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[1, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_gt() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
+
+func @compare_ge() attributes { iree.module.export } {
+ %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+ %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+ %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+ %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+ %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+ check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8>
+ return
+}
diff --git a/iree/test/e2e/vulkan_specific/log_plus_one.mlir b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
new file mode 100644
index 0000000..3bba4a9
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
@@ -0,0 +1,6 @@
+func @log_plus_one() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 0.5, 1.0, 5.0]> : tensor<4xf32>
+ %result = "mhlo.log_plus_one"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ check.expect_almost_eq_const(%result, dense<[0.0, 0.4054651, 0.6931472, 1.7917595]> : tensor<4xf32>) : tensor<4xf32>
+ return
+}
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 5cb6f8e..7ed2691 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -15,8 +15,12 @@
# limitations under the License.
"""Runs all E2E TensorFlow tests and extracts their benchmarking artifacts.
-Example usage:
- python3 get_e2e_artifacts.py
+Example usages:
+ # Run all test suites and collect their artifacts:
+ python3 ./scripts/get_e2e_artifacts.py
+
+ # Run the e2e_tests test suite and collect its artifacts:
+ python3 ./scripts/get_e2e_artifacts.py --test_suites=e2e_tests
"""
import fileinput
@@ -39,8 +43,10 @@
'//integrations/tensorflow/e2e:mobile_bert_squad_tests',
'keras_tests':
'//integrations/tensorflow/e2e/keras:keras_tests',
- 'vision_external_tests':
- '//integrations/tensorflow/e2e/keras:vision_external_tests',
+ 'imagenet_external_tests':
+ '//integrations/tensorflow/e2e/keras:imagenet_external_tests',
+ 'slim_vision_tests':
+ '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
}
SUITES_HELP = [f'`{name}`' for name in SUITE_NAME_TO_TARGET]
SUITES_HELP = f'{", ".join(SUITES_HELP[:-1])} and {SUITES_HELP[-1]}'
@@ -118,6 +124,11 @@
paths_to_tests: Dict[str, str]):
"""Unzips all of the benchmarking artifacts for a given test and backend."""
outputs = os.path.join(test_path, 'test.outputs', 'outputs.zip')
+ if FLAGS.dry_run and not os.path.exists(outputs):
+    # The artifacts may or may not be present on disk during a dry run. If they
+    # are, we want to collision-check them, but if they aren't, that's fine.
+ return
+
archive = zipfile.ZipFile(outputs)
# Filter out directory names.
filenames = [name for name in archive.namelist() if name[-1] != os.sep]
@@ -139,17 +150,21 @@
# Convert test suite shorthands to full test suite targets.
test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites]
+ if FLAGS.run_test_suites:
+ # Use bazel test to execute all of the test suites in parallel.
+ command = [
+ 'bazel', 'test', *test_suites, '--color=yes',
+ '--test_arg=--get_saved_model'
+ ]
+ print(f'Running: `{" ".join(command)}`')
+ if not FLAGS.dry_run:
+ subprocess.check_call(command)
+ print()
+
written_paths = set()
paths_to_tests = dict()
for test_suite in test_suites:
- if FLAGS.run_test_suites and not FLAGS.dry_run:
- subprocess.check_call([
- 'bazel', 'test', test_suite, '--color=yes',
- '--test_arg=--get_saved_model'
- ])
- print()
-
# Extract all of the artifacts for this test suite.
test_paths, test_names = get_test_paths_and_names(test_suite)
for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)):