Merge pull request #3594 from rsuderman:main-to-google

PiperOrigin-RevId: 338738884
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index d5d5bb0..7e91cc7 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -50,8 +50,8 @@
         uses: actions/checkout@v2
       - name: Setting up python
         uses: actions/setup-python@v2
-        # We have to explicitly fetch the base branch as well
       - name: Fetching Base Branch
+        # We have to explicitly fetch the base branch as well
         run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
       - name: Install yapf
         run: python3 -m pip install yapf
@@ -65,6 +65,26 @@
           printf "You can fix the lint errors above by running\n"
           printf "  git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i\n"
 
+  pytype:
+    runs-on: ubuntu-18.04
+    env:
+      BASE_REF: ${{ github.base_ref }}
+    steps:
+      - name: Checking out repository
+        uses: actions/checkout@v2
+      - name: Setting up python
+        uses: actions/setup-python@v2
+        with:
+          # Pytype does not support python3.9, which this action defaults to.
+          python-version: '3.8'
+      - name: Fetching Base Branch
+        # We have to explicitly fetch the base branch as well
+        run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
+      - name: Install pytype
+        run: python3 -m pip install pytype
+      - name: Run pytype on changed files
+        run: ./build_tools/pytype/check_diff.sh "${BASE_REF?}"
+
   clang-format:
     runs-on: ubuntu-18.04
     env:
@@ -79,8 +99,8 @@
           chmod +x /tmp/git-clang-format
       - name: Checking out repository
         uses: actions/checkout@v2
-        # We have to explicitly fetch the base branch as well
       - name: Fetching Base Branch
+        # We have to explicitly fetch the base branch as well
         run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
       - name: Running clang-format on changed source files
         run: |
@@ -102,8 +122,8 @@
     steps:
       - name: Checking out repository
         uses: actions/checkout@v2
-      # We have to explicitly fetch the base branch as well
       - name: Fetching Base Branch
+        # We have to explicitly fetch the base branch as well
         run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
       - name: Checking tabs
         run: ./scripts/check_tabs.sh "${BASE_REF?}"
diff --git a/.gitignore b/.gitignore
index 96a5b66..b0a8cbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 # Python
 *.pyc
 **/.ipynb_checkpoints/
+.pytype/
 
 # Visual Studio files
 .vs/
diff --git a/build_tools/pytype/check_diff.sh b/build_tools/pytype/check_diff.sh
new file mode 100755
index 0000000..2b9851a
--- /dev/null
+++ b/build_tools/pytype/check_diff.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Uses git diff to run pytype on changed files.
+# Example Usage:
+#   Defaults to comparing against 'main'.
+#     ./build_tools/pytype/check_diff.sh
+#   A specific branch can be specified.
+#     ./build_tools/pytype/check_diff.sh google
+#   Or all python files outside of './third_party/' can be checked.
+#     ./build_tools/pytype/check_diff.sh all
+
+DIFF_TARGET="${1:-main}"
+echo "Running pycheck against '${DIFF_TARGET?}'"
+
+if [[ "${DIFF_TARGET?}" = "all" ]]; then
+  FILES=$(find . -name "*.py" -not -path "./third_party/*")
+else
+  FILES=$(git diff --name-only "${DIFF_TARGET?}" | grep '.*\.py')
+fi
+
+
+# We separate the python files into multiple pytype calls because otherwise
+# Ninja gets confused. See https://github.com/google/pytype/issues/198
+BASE=$(echo "${FILES?}" | grep -vP '^(\./)?integrations/.*')
+IREE_TF=$(echo "${FILES?}" | \
+          grep -P '^(\./)?integrations/tensorflow/bindings/python/pyiree/tf/.*')
+IREE_XLA=$(echo "${FILES?}" | \
+           grep -P '^(\./)?integrations/tensorflow/bindings/python/pyiree/xla/.*')
+COMPILER=$(echo "${FILES?}" | \
+           grep -P '^(\./)?integrations/tensorflow/compiler/.*')
+E2E=$(echo "${FILES?}" | grep -P '^(\./)?integrations/tensorflow/e2e/.*')
+
+function check_files() {
+  # $1: previous return code
+  # $2...: files to check
+  if [[ -z "${@:2}" ]]; then
+    echo "No files to check."
+    echo
+    return "${1?}"
+  fi
+
+  # We disable import-error because pytype doesn't have access to bazel.
+  # We disable pyi-error because of the way the bindings imports work.
+  echo "${@:2}" | \
+    xargs python3 -m pytype --disable=import-error,pyi-error -j $(nproc)
+  EXIT_CODE="$?"
+  echo
+  if [[ "${EXIT_CODE?}" -gt "${1?}" ]]; then
+    return "${EXIT_CODE?}"
+  else
+    return "${1?}"
+  fi
+}
+
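+# Each check_files call below returns max(previous exit code, pytype's exit
+# code); the repeated pattern
+#   check_files "${MAX_CODE?}" "${BASE?}"; MAX_CODE="$?"
+# means MAX_CODE only ever grows, and any failing group fails the whole script.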
+MAX_CODE=0
+
+echo "Checking .py files outside of integrations/"
+check_files "${MAX_CODE?}" "${BASE?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/bindings/python/pyiree/tf/.*"
+check_files "${MAX_CODE?}" "${IREE_TF?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/bindings/python/pyiree/xla/.*"
+check_files "${MAX_CODE?}" "${IREE_XLA?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/compiler/.*"
+check_files "${MAX_CODE?}" "${COMPILER?}"
+MAX_CODE="$?"
+
+echo "Checking .py files in integrations/tensorflow/e2e/.*"
+check_files "${MAX_CODE?}" "${E2E?}"
+MAX_CODE="$?"
+
+
+if [[ "${MAX_CODE?}" -ne "0" ]]; then
+  echo "One or more pytype checks failed."
+  echo "You can view these errors locally by running"
+  echo "  ./build_tools/pytype/check_diff.sh ${DIFF_TARGET?}"
+  exit "${MAX_CODE?}"
+fi
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index b09f95b..111ad46 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -249,7 +249,7 @@
   """Stores the inputs and outputs of a series of calls to a module."""
 
   def __init__(self,
-               module: tf_utils.CompiledModule,
+               module: Union[tf_utils.CompiledModule, None],
                function: Union[Callable[["TracedModule"], None], None],
                _load_dict: Dict[str, Any] = None):
     """Extracts metadata from module and function and initializes.
@@ -563,7 +563,7 @@
     self._module = module
     self._trace = trace
 
-  def _trace_call(self, method: Callable[..., Any], method_name: str):
+  def _trace_call(self, method: tf_utils._FunctionWrapper, method_name: str):
     """Decorates a CompiledModule method to capture its inputs and outputs."""
 
     def call(*args, **kwargs):
@@ -611,8 +611,8 @@
 
 
 def compile_tf_module(
-    module_class: Type[tf.Module], exported_names: Sequence[str] = ()
-) -> Callable[[Any], Any]:
+    module_class: Type[tf.Module],
+    exported_names: Sequence[str] = ()) -> Modules:
   """Compiles module_class to each backend that we test.
 
   Args:
@@ -648,11 +648,10 @@
   return _global_modules
 
 
-def compile_tf_signature_def_saved_model(saved_model_dir: str,
-                                         saved_model_tags: Set[str],
-                                         module_name: str, exported_name: str,
-                                         input_names: Sequence[str],
-                                         output_names: Sequence[str]):
+def compile_tf_signature_def_saved_model(
+    saved_model_dir: str, saved_model_tags: Set[str], module_name: str,
+    exported_name: str, input_names: Sequence[str],
+    output_names: Sequence[str]) -> Modules:
   """Compiles a SignatureDef SavedModel to each backend that we test.
 
   Args:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 99573c5..c0f8473 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -87,10 +87,10 @@
 
 
 def _setup_mlir_crash_reproducer(
-    function: Callable[[Any], Any],
+    function: Any,  # pytype doesn't support arbitrary Callable[*args, **kwargs]
     artifacts_dir: str,
     backend_id: str,
-) -> Callable[[Any], Any]:
+) -> Any:  # Callable[Any, Any]
   """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
 
   Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
@@ -253,6 +253,16 @@
                          exported_name, artifacts_dir)
 
 
+class _FunctionWrapper(object):
+
+  def __call__(self, *args, **kwargs):
+    raise NotImplementedError()
+
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    """Dummy function to match _IreeFunctionWrapper's API."""
+    return ("",), ("",)
+
+
 class CompiledModule(object):
   """Base class for the TF and IREE compiled modules."""
 
@@ -260,7 +270,7 @@
       self,
       module_name: str,
       backend_info: "BackendInfo",
-      compiled_paths: Dict[str, str],
+      compiled_paths: Union[Dict[str, str], None],
   ):
     """Shared base constructor – not useful on its own.
 
@@ -344,6 +354,9 @@
     """
     raise NotImplementedError()
 
+  def __getattr__(self, attr: str) -> _FunctionWrapper:
+    raise NotImplementedError()
+
   def iree_serializable(self):
     return False
 
@@ -351,13 +364,6 @@
     return False
 
 
-class _FunctionWrapper(object):
-
-  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
-    """Dummy function to match _IreeFunctionWrapper's API."""
-    return (), ()
-
-
 class _IreeFunctionWrapper(_FunctionWrapper):
   """Wraps an IREE function, making it callable."""
 
@@ -681,7 +687,7 @@
   instance = module_class()
   functions = []
   for name in exported_names:
-    functions.append(instance.__getattribute__(name).get_concrete_function())
+    functions.append(getattr(instance, name).get_concrete_function())
   return functions, exported_names
 
 
@@ -787,7 +793,8 @@
   def __init__(self, interpreter: tf.lite.Interpreter):
     self._interpreter = interpreter
 
-  def __call__(self, *args, **kwargs) -> Tuple[Any]:
+  def __call__(self, *args,
+               **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
     if len(args) and len(kwargs):
       raise ValueError("Passing both args and kwargs is not supported by "
                        "_TfLiteFunctionWrapper")
@@ -823,13 +830,12 @@
         outputs.append(value)
 
     # Process them to match the output of the tf.Module.
-    if not is_dict:
-      outputs = tuple(outputs)
-      if len(outputs) == 1:
-        outputs = outputs[0]
+    if is_dict:
+      return dict(outputs)
     else:
-      outputs = dict(outputs)
-    return outputs
+      if len(outputs) == 1:
+        return outputs[0]
+      return tuple(outputs)
 
 
 class TfLiteCompiledModule(CompiledModule):
diff --git a/integrations/tensorflow/compiler/BUILD b/integrations/tensorflow/compiler/BUILD
index 838e202..fba7d6d 100644
--- a/integrations/tensorflow/compiler/BUILD
+++ b/integrations/tensorflow/compiler/BUILD
@@ -21,6 +21,7 @@
 cc_library(
     name = "tensorflow",
     srcs = [
+        "CheckNoTF.cpp",
         "LegalizeTF.cpp",
         "Passes.cpp",
         "PropagateResourceCasts.cpp",
diff --git a/integrations/tensorflow/compiler/CheckNoTF.cpp b/integrations/tensorflow/compiler/CheckNoTF.cpp
new file mode 100644
index 0000000..246f28a
--- /dev/null
+++ b/integrations/tensorflow/compiler/CheckNoTF.cpp
@@ -0,0 +1,93 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+class CheckNoTensorflow : public PassWrapper<CheckNoTensorflow, FunctionPass> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+                    shape::ShapeDialect, StandardOpsDialect>();
+  }
+
+  CheckNoTensorflow() = default;
+  CheckNoTensorflow(const CheckNoTensorflow &) {}
+
+  /// Validates that no TensorFlow frontend ops are in the function.
+  void runOnFunction() override {
+    auto op = getFunction();
+    auto context = op.getContext();
+
+    Dialect *dialect = context->getLoadedDialect("tf");
+    DenseSet<Operation *> illegalOps;
+    op.walk([&](Operation *op) {
+      if (op->getDialect() == dialect) {
+        illegalOps.insert(op);
+      }
+    });
+
+    if (!illegalOps.empty()) {
+      emitLegalizationErrors(op, illegalOps);
+      return signalPassFailure();
+    }
+  }
+
+  // Emits debug information that includes the number of ops of each type
+  // that failed to legalize.
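+  // For example, a function with two remaining tf.Const ops and one tf.Add
+  // would get per-op errors plus a summary roughly of the form:
+  //   The following Tensorflow operations still remain:
+  //     tf.Add (count: 1)
+  //     tf.Const (count: 2)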
+  void emitLegalizationErrors(Operation *op,
+                              const DenseSet<Operation *> &nonlegalizedOps) {
+    // Print op errors for each of the TensorFlow ops that still remain.
+    std::map<StringRef, int> opNameCounts;
+    for (Operation *nonlegalizedOp : nonlegalizedOps) {
+      StringRef opName = nonlegalizedOp->getName().getStringRef();
+      opNameCounts[opName]++;
+      nonlegalizedOp->emitOpError()
+          << ": unlegalized TensorFlow op still exists";
+    }
+
+    std::vector<std::string> errorMessages;
+    errorMessages.reserve(opNameCounts.size());
+    for (const auto &opInfo : opNameCounts) {
+      errorMessages.push_back(
+          llvm::formatv("\t{0} (count: {1})", opInfo.first, opInfo.second));
+    }
+    Location loc = op->getLoc();
+    emitError(loc) << "The following Tensorflow operations still remain: \n"
+                   << llvm::join(errorMessages, "\n") << "\n";
+  }
+};
+
+static PassRegistration<CheckNoTensorflow> pass(
+    "iree-check-no-tf", "Check that no TensorFlow frontend ops remain");
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF() {
+  return std::make_unique<CheckNoTensorflow>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/integrations/tensorflow/compiler/Passes.cpp b/integrations/tensorflow/compiler/Passes.cpp
index b4b6c47..db2e9d9 100644
--- a/integrations/tensorflow/compiler/Passes.cpp
+++ b/integrations/tensorflow/compiler/Passes.cpp
@@ -64,6 +64,11 @@
   // - It removes tf_saved_model.semantics from the module, which we can only
   //   do at the very end.
   pm.addPass(createTFSavedModelLowerExportedFunctions());
+
+  ////////////////////////////////////////////////////////////////////////////
+  // Validate that all TensorFlow ops have been legalized away.
+  ////////////////////////////////////////////////////////////////////////////
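+  // createCheckNoTF() walks each function, emits an error for every op still
+  // in the "tf" dialect, and signals pass failure if any are found.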
+  pm.addPass(createCheckNoTF());
 }
 
 static mlir::PassPipelineRegistration<> pipeline(
diff --git a/integrations/tensorflow/compiler/Passes.h b/integrations/tensorflow/compiler/Passes.h
index e3016a3..8b3b219 100644
--- a/integrations/tensorflow/compiler/Passes.h
+++ b/integrations/tensorflow/compiler/Passes.h
@@ -40,6 +40,9 @@
 // Push resource casts forward to better propagate resource related shapes.
 std::unique_ptr<OperationPass<ModuleOp>> createPropagateResourceCasts();
 
+// Validates that no TensorFlow operations remain.
+std::unique_ptr<OperationPass<FuncOp>> createCheckNoTF();
+
 // Create a single pipeline that will run all the needed IREE-specific TF import
 // passes in the right order.
 void createIreeTfImportPipeline(OpPassManager &pm);
diff --git a/integrations/tensorflow/compiler/test/check-no-tf.mlir b/integrations/tensorflow/compiler/test/check-no-tf.mlir
new file mode 100644
index 0000000..a7c49e6
--- /dev/null
+++ b/integrations/tensorflow/compiler/test/check-no-tf.mlir
@@ -0,0 +1,28 @@
+// RUN: iree-tf-opt %s -iree-check-no-tf -split-input-file -verify-diagnostics
+
+// CHECK-LABEL: func @f
+func @f() -> (tensor<i32>) {
+  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<3>
+  %0 = mhlo.constant dense<3> : tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// expected-error@+3 {{'tf.Const' op : unlegalized TensorFlow op still exists}}
+// expected-error@below {{The following Tensorflow operations still remain}}
+func @f() -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// expected-error@+4 {{'tf.Const' op : unlegalized TensorFlow op still exists}}
+// expected-error@+4 {{'tf.Add' op : unlegalized TensorFlow op still exists}}
+// expected-error@below {{The following Tensorflow operations still remain}}
+func @f(%arg0 : tensor<i32>) -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Add"(%arg0, %0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
index 256ce89..aff5dd8 100644
--- a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -23,6 +23,7 @@
 SAVED_MODEL_IMPORT_PASSES = [
     "tf-executor-graph-pruning",
     "tf-standard-pipeline",
+    "iree-xla-legalize-tf",
     "iree-tf-import-pipeline",
     "canonicalize",
 ]
@@ -114,8 +115,8 @@
 # CHECK: attributes
 # CHECK-SAME: iree.module.export
 # CHECK-SAME: iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I1!R1!"}
-# CHECK-DAG:   [[CONST_2xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
-# CHECK-DAG:   [[CONST_3xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
+# CHECK-DAG:   [[CONST_2xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]>
+# CHECK-DAG:   [[CONST_3xf32:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
 # CHECK-DAG:   flow.variable.store [[CONST_2xf32]], @v : tensor<2xf32>
 # CHECK-DAG:   flow.variable.store [[CONST_3xf32]], @v : tensor<3xf32>
 # CHECK: FINISH_TEST
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index bc8a998..e919d3e 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -201,31 +201,11 @@
         "tflite": ["mobile_bert_squad_test.py"],
         "iree_vmla": ["mobile_bert_squad_test.py"],
         "iree_llvmjit": ["mobile_bert_squad_test.py"],
-    },
-    reference_backend = "tf",
-    tags = [
-        "external",
-        "guitar",
-        "manual",
-        "no-remote",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_test_suite(
-    name = "mobile_bert_squad_tests_failing",
-    size = "enormous",
-    backends_to_srcs = {
         "iree_vulkan": ["mobile_bert_squad_test.py"],
     },
     reference_backend = "tf",
     tags = [
         "external",
-        "failing",
         "guitar",
         "manual",
         "no-remote",
diff --git a/iree/base/BUILD b/iree/base/BUILD
index a1b46ba..1cc8c50 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -295,8 +295,10 @@
     srcs = ["logging.cc"],
     hdrs = ["logging.h"],
     deps = [
+        ":tracing",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/strings:str_format",
     ],
 )
 
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index 257618b..e6ec69a 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -345,8 +345,10 @@
   SRCS
     "logging.cc"
   DEPS
+    ::tracing
     absl::core_headers
     absl::flags
+    absl::str_format
   PUBLIC
 )
 
diff --git a/iree/base/logging.cc b/iree/base/logging.cc
index 51ce60b..3b31a9a 100644
--- a/iree/base/logging.cc
+++ b/iree/base/logging.cc
@@ -17,6 +17,8 @@
 #include <string>
 
 #include "absl/flags/flag.h"
+#include "absl/strings/str_format.h"
+#include "iree/base/tracing.h"
 
 ABSL_FLAG(int, iree_minloglevel, 0,
           "Minimum logging level. 0 = INFO and above.");
@@ -88,6 +90,19 @@
   // TODO(scotttodd): Include current system time
   fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], file_name_, line_,
           str().c_str());
+
+#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_LOG_MESSAGES
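+  // severity_ indexes this table: 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL,
+  // matching the "IWEF" prefix characters written to stderr above.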
+  constexpr int kLevelColors[4] = {
+      IREE_TRACING_MESSAGE_LEVEL_INFO,     // INFO
+      IREE_TRACING_MESSAGE_LEVEL_WARNING,  // WARNING
+      IREE_TRACING_MESSAGE_LEVEL_ERROR,    // ERROR
+      IREE_TRACING_MESSAGE_LEVEL_ERROR,    // FATAL
+  };
+  std::string message =
+      absl::StrFormat("%s:%d] %s\n", file_name_, line_, str().c_str());
+  IREE_TRACE_MESSAGE_DYNAMIC_COLORED(kLevelColors[severity_], message.c_str(),
+                                     message.size());
+#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_LOG_MESSAGES
 }
 
 LogMessageFatal::LogMessageFatal(const char* file, int line)
diff --git a/iree/base/tracing.h b/iree/base/tracing.h
index a3226ad..98db68e 100644
--- a/iree/base/tracing.h
+++ b/iree/base/tracing.h
@@ -77,6 +77,10 @@
 // Tracy UI.
 #define IREE_TRACING_FEATURE_SLOW_LOCKS (1 << 5)
 
+// Forwards log messages to traces, which will be visible under "Messages" in
+// the Tracy UI.
+#define IREE_TRACING_FEATURE_LOG_MESSAGES (1 << 6)
+
 #if !defined(IREE_TRACING_MAX_CALLSTACK_DEPTH)
 // Tracing functions that capture stack traces will only capture up to N frames.
 // The overhead for stack walking scales linearly with the number of frames
@@ -95,7 +99,7 @@
 // overridden it with more specific settings.
 //
 // IREE_TRACING_MODE = 0: tracing disabled
-// IREE_TRACING_MODE = 1: instrumentation and basic statistics
+// IREE_TRACING_MODE = 1: instrumentation, log messages, and basic statistics
 // IREE_TRACING_MODE = 2: same as 1 with added allocation tracking
 // IREE_TRACING_MODE = 3: same as 2 with callstacks for allocations
 // IREE_TRACING_MODE = 4: same as 3 with callstacks for all instrumentation
@@ -105,20 +109,23 @@
 #undef IREE_TRACING_MAX_CALLSTACK_DEPTH
 #define IREE_TRACING_MAX_CALLSTACK_DEPTH 0
 #elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE == 2
-#define IREE_TRACING_FEATURES             \
-  (IREE_TRACING_FEATURE_INSTRUMENTATION | \
-   IREE_TRACING_FEATURE_ALLOCATION_TRACKING)
-#elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE == 3
 #define IREE_TRACING_FEATURES                 \
   (IREE_TRACING_FEATURE_INSTRUMENTATION |     \
    IREE_TRACING_FEATURE_ALLOCATION_TRACKING | \
-   IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS)
+   IREE_TRACING_FEATURE_LOG_MESSAGES)
+#elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE == 3
+#define IREE_TRACING_FEATURES                   \
+  (IREE_TRACING_FEATURE_INSTRUMENTATION |       \
+   IREE_TRACING_FEATURE_ALLOCATION_TRACKING |   \
+   IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS | \
+   IREE_TRACING_FEATURE_LOG_MESSAGES)
 #elif defined(IREE_TRACING_MODE) && IREE_TRACING_MODE >= 4
 #define IREE_TRACING_FEATURES                        \
   (IREE_TRACING_FEATURE_INSTRUMENTATION |            \
    IREE_TRACING_FEATURE_INSTRUMENTATION_CALLSTACKS | \
    IREE_TRACING_FEATURE_ALLOCATION_TRACKING |        \
-   IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS)
+   IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS |      \
+   IREE_TRACING_FEATURE_LOG_MESSAGES)
 #else
 #define IREE_TRACING_FEATURES 0
 #endif  // IREE_TRACING_MODE
@@ -336,11 +343,21 @@
 // The message text must be a compile-time string literal.
 #define IREE_TRACE_MESSAGE(level, value_literal) \
   ___tracy_emit_messageLC(value_literal, IREE_TRACING_MESSAGE_LEVEL_##level, 0)
+// Logs a message with the given color to the trace.
+// Standard colors are defined as IREE_TRACING_MESSAGE_LEVEL_* values.
+// The message text must be a compile-time string literal.
+#define IREE_TRACE_MESSAGE_COLORED(color, value_literal) \
+  ___tracy_emit_messageLC(value_literal, color, 0)
 // Logs a dynamically-allocated message at the given logging level to the trace.
 // The string |value| will be copied into the trace buffer.
 #define IREE_TRACE_MESSAGE_DYNAMIC(level, value, value_length) \
   ___tracy_emit_messageC(value, value_length,                  \
                          IREE_TRACING_MESSAGE_LEVEL_##level, 0)
+// Logs a dynamically-allocated message with the given color to the trace.
+// Standard colors are defined as IREE_TRACING_MESSAGE_LEVEL_* values.
+// The string |value| will be copied into the trace buffer.
+#define IREE_TRACE_MESSAGE_DYNAMIC_COLORED(color, value, value_length) \
+  ___tracy_emit_messageC(value, value_length, color, 0)
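+// Illustrative usage of the colored variants (assuming a std::string |msg|;
+// the IREE_TRACING_MESSAGE_LEVEL_* constants double as standard colors):
+//   IREE_TRACE_MESSAGE_COLORED(IREE_TRACING_MESSAGE_LEVEL_WARNING, "slow path");
+//   IREE_TRACE_MESSAGE_DYNAMIC_COLORED(IREE_TRACING_MESSAGE_LEVEL_ERROR,
+//                                      msg.c_str(), msg.size());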
 
 // Utilities:
 #define IREE_TRACE_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
@@ -370,7 +387,9 @@
 #define IREE_TRACE_FRAME_MARK_BEGIN_NAMED(name_literal)
 #define IREE_TRACE_FRAME_MARK_END_NAMED(name_literal)
 #define IREE_TRACE_MESSAGE(level, value_literal)
+#define IREE_TRACE_MESSAGE_COLORED(color, value_literal)
 #define IREE_TRACE_MESSAGE_DYNAMIC(level, value, value_length)
+#define IREE_TRACE_MESSAGE_DYNAMIC_COLORED(color, value, value_length)
 #endif  // IREE_TRACING_FEATURE_INSTRUMENTATION
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 7a788b1..93b8809 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -130,11 +130,8 @@
 spirv::GlobalVariableOp getOrInsertResourceVariable(Location loc, Type type,
                                                     unsigned set,
                                                     unsigned binding,
-                                                    StringRef funcName,
                                                     Block &block) {
-  auto name =
-      llvm::formatv("__resource_var_{0}_{1}_{2}__", set, binding, funcName)
-          .str();
+  auto name = llvm::formatv("__resource_var_{0}_{1}__", set, binding).str();
   for (auto varOp : block.getOps<spirv::GlobalVariableOp>()) {
     if (varOp.sym_name() == name) return varOp;
   }
@@ -371,10 +368,9 @@
       SymbolTable::lookupNearestSymbolFrom(
           phOp, phOp.getAttrOfType<SymbolRefAttr>("binding")));
 
-  StringRef funcName = phOp.getParentOfType<spirv::FuncOp>().getName();
-  spirv::GlobalVariableOp varOp = getOrInsertResourceVariable(
-      phOp.getLoc(), convertedType, bindingOp.set(), bindingOp.binding(),
-      funcName, *moduleOp.getBody());
+  spirv::GlobalVariableOp varOp =
+      getOrInsertResourceVariable(phOp.getLoc(), convertedType, bindingOp.set(),
+                                  bindingOp.binding(), *moduleOp.getBody());
 
   rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(phOp, varOp);
   return success();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index bb6faf8..d751065 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -63,6 +63,11 @@
       llvm::cl::desc(
           "Enable use of vectorization in SPIR-V code generation pipeline"),
       llvm::cl::init(false)};
+  Option<bool> useVectorPass{
+      *this, "use-vector-pass",
+      llvm::cl::desc("Enable use of Linalg vectorization in SPIR-V code "
+                     "generation pipeline"),
+      llvm::cl::init(false)};
 };
 
 static void addLinalgToSPIRVPasses(OpPassManager &pm,
@@ -94,7 +99,9 @@
   //===--------------------------------------------------------------------===//
   pm.addPass(createSplitDispatchFunctionPass());
   pm.addPass(createLinalgTileAndFusePass(options));
-  pm.addPass(createLoadStoreVectorizationPass());
+  if (options.useVectorPass) {
+    pm.addPass(createLoadStoreVectorizationPass());
+  }
   pm.addPass(createCanonicalizerPass());
 
   //===--------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 502db6a..80805e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -29,6 +29,7 @@
   SmallVector<int64_t, 3> tileSizes = {};
   bool useWorkgroupMemory = false;
   bool useVectorization = false;
+  bool useVectorPass = false;
 };
 
 /// Pass to initialize the function that computes the number of workgroups for
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index 7721f41..eb2a32e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -64,6 +64,7 @@
 
  private:
   void tileAndVectorizeLinalgCopy(FuncOp funcOp, MLIRContext *context);
+  void lowerVectorOps(FuncOp funcOp, MLIRContext *context);
 };
 
 // Common class for all vector to GPU patterns.
@@ -149,8 +150,7 @@
         loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
     Value index = rewriter.create<AddIOp>(loc, ThreadIndex, indices.back());
     indices.back() = index;
-    rewriter.create<StoreOp>(op.getLoc(), operands[0], operands[1], indices);
-    rewriter.eraseOp(op);
+    rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1], indices);
     return success();
   }
 };
@@ -204,11 +204,141 @@
   applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
 }
 
+// Convert vector transfer_read to a load if possible. This is the case only if
+// the element type of the memref matches the element type we want to load.
+class VectorTransferReadToLoad
+    : public OpRewritePattern<vector::TransferReadOp> {
+ public:
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getVectorType().getNumElements() != 1 ||
+        op.getMemRefType().getElementType() !=
+            op.getVectorType().getElementType()) {
+      return failure();
+    }
+    auto loc = op.getLoc();
+    Value newOp = rewriter.create<LoadOp>(loc, op.memref(), op.indices());
+    newOp =
+        rewriter.create<vector::BroadcastOp>(loc, op.getVectorType(), newOp);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+// Convert vector transfer_write to a store if possible. This is the case only
+// if the element type of the memref matches the element type we want to store.
+class VectorTransferWriteToStore
+    : public OpRewritePattern<vector::TransferWriteOp> {
+ public:
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getVectorType().getNumElements() != 1 ||
+        op.getMemRefType().getElementType() !=
+            op.getVectorType().getElementType()) {
+      return failure();
+    }
+    auto loc = op.getLoc();
+    SmallVector<int64_t, 2> zero(op.getVectorType().getRank(), 0);
+    Value scalarValue =
+        rewriter.create<vector::ExtractOp>(loc, op.vector(), zero);
+    rewriter.create<StoreOp>(loc, scalarValue, op.memref(), op.indices());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Lower vector contract to a single scalar or vector mulf+addf. Insert casts to
+// convert from 2D vector to 1D vector or scalar.
+class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
+ public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    auto iteratorTypes = op.iterator_types().getValue();
+    if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
+        !isParallelIterator(iteratorTypes[1]) ||
+        !isReductionIterator(iteratorTypes[2]) ||
+        !isRowMajorMatmul(op.indexing_maps())) {
+      return failure();
+    }
+    if (op.getLhsType().getNumElements() != 1) return failure();
+    unsigned vecSize = op.getAccType().cast<VectorType>().getNumElements();
+    if (!(vecSize >= 1 && vecSize <= 4)) return failure();
+    auto loc = op.getLoc();
+    VectorType vecType = VectorType::get(
+        vecSize, op.getResultType().cast<VectorType>().getElementType());
+    std::array<int64_t, 2> zero = {0, 0};
+    Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero);
+    Value rhs, acc;
+    if (vecSize == 1) {
+      rhs = rewriter.create<vector::ExtractOp>(loc, op.rhs(), zero);
+      acc = rewriter.create<vector::ExtractOp>(loc, op.acc(), zero);
+    } else {
+      lhs = rewriter.create<vector::BroadcastOp>(loc, vecType, lhs);
+      rhs = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.rhs());
+      acc = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.acc());
+    }
+    Value newOp = rewriter.create<MulFOp>(loc, lhs, rhs);
+    newOp = rewriter.create<AddFOp>(loc, newOp, acc);
+    if (vecSize == 1)
+      newOp =
+          rewriter.create<vector::BroadcastOp>(loc, op.getResultType(), newOp);
+    else
+      newOp =
+          rewriter.create<vector::ShapeCastOp>(loc, op.getResultType(), newOp);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+// Lower ExtractStridedSliceOp to an ExtractOp instruction that can be natively
+// converted to SPIR-V. Add a BroadcastOp to keep the type consistent; we expect
+// the broadcast to be removed by canonicalization.
+class ExtractStridedLowering
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+ public:
+  using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only handle cases extracting a degenerate (single-element) vector so
+    // that we can generate an ExtractOp with a scalar destination.
+    if (op.getResult().getType().cast<VectorType>().getNumElements() != 1)
+      return failure();
+    auto loc = op.getLoc();
+    SmallVector<int64_t, 4> offsets = llvm::to_vector<4>(
+        llvm::map_range(op.offsets().getAsRange<IntegerAttr>(),
+                        [](IntegerAttr attr) { return attr.getInt(); }));
+    offsets.resize(op.getVectorType().getRank(), 0);
+    Value newOp = rewriter.create<vector::ExtractOp>(loc, op.vector(), offsets);
+    newOp = rewriter.create<vector::BroadcastOp>(loc, op.getResult().getType(),
+                                                 newOp);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+// Lower vector ops to instructions that can be later converted to SPIR-V.
+void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp,
+                                            MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<VectorContractLowering, VectorTransferReadToLoad,
+                  VectorTransferWriteToStore, ExtractStridedLowering>(context);
+  applyPatternsAndFoldGreedily(funcOp, patterns);
+}
+
 void ConvertVectorToGPUPass::runOnOperation() {
   MLIRContext *context = &getContext();
   FuncOp funcOp = getOperation();
   tileAndVectorizeLinalgCopy(funcOp, context);
 
+  lowerVectorOps(funcOp, context);
+
   auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
   OwningRewritePatternList patterns;
   patterns.insert<UnaryAndBinaryOpPattern<AddFOp>, VectorTransferReadConversion,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index f6bbc40..ac83d70 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -23,12 +23,12 @@
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
 
-  // CHECK: spv.globalVariable @__resource_var_3_4_resource_variable__ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-  // CHECK: spv.globalVariable @__resource_var_1_2_resource_variable__ bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.globalVariable @__resource_var_3_4__ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.globalVariable @__resource_var_1_2__ bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
   // CHECK: spv.func @resource_variable()
   func @resource_variable() {
-    // CHECK: spv._address_of @__resource_var_1_2_resource_variable__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK: spv._address_of @__resource_var_3_4_resource_variable__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK: spv._address_of @__resource_var_1_2__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK: spv._address_of @__resource_var_3_4__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
     %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
     return
@@ -42,24 +42,6 @@
 
 // -----
 
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>, vkspv.entry_point_schedule = ["ForwardPass_ex_dispatch_8_dispatch_0", "ForwardPass_ex_dispatch_8_dispatch_1"]} {
-  // CHECK: spv.globalVariable @__resource_var_0_0_dispatch_0__ bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<1 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
-  // CHECK: spv.globalVariable @__resource_var_0_0_dispatch_1__ bind(0, 0) : !spv.ptr<!spv.struct<(!spv.array<4 x f32, stride=4> [0])>, StorageBuffer>
-  func @dispatch_1() attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @ForwardPass_ex_dispatch_8_dispatch_1__num_workgroups__} {
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<4xf32>
-    return
-  }
-  func @dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @ForwardPass_ex_dispatch_8_dispatch_0__num_workgroups__} {
-    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<1xvector<4xf32>>
-    return
-  }
-  hal.interface @legacy_io attributes {sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-  }
-}
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
index 30d94c5..99ec424 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vector_to_gpu.mlir
@@ -59,3 +59,75 @@
     // CHECK:   %[[LOAD:.+]] = vector.transfer_read %[[SVs]][%c0, %c0], %cst {{.*}} : memref<1x4xf32, {{.*}}>, vector<1x4xf32>
     // CHECK:   vector.transfer_write %[[LOAD]], %[[SVd]][%[[C0]], %[[C0]]] {{.*}} : vector<1x4xf32>, memref<1x4xf32
 }
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @transfer_ops(%arg0: memref<32x32xf32>, %arg1 : vector<1x1xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<32x32xf32>, vector<1x1xf32>
+  vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<1x1xf32>, memref<32x32xf32>
+  return %0 : vector<1x1xf32>
+  }
+  // CHECK-LABEL: func @transfer_ops
+  //  CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>, %[[ARG1:.*]]: vector<1x1xf32>
+  //       CHECK:   %[[C0:.*]] = constant 0 : index
+  //       CHECK:   %[[LOAD:.*]] = load %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+  //       CHECK:   %[[B:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1x1xf32>
+  //       CHECK:   %[[EXT:.*]] = vector.extract %[[ARG1]][0, 0] : vector<1x1xf32>
+  //       CHECK:   store %[[EXT]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<32x32xf32>
+  //       CHECK:   return %[[B]] : vector<1x1xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @contract_ops(%arg0 : vector<1x1xf32>, %arg1 : vector<1x4xf32>,
+                    %arg2 : vector<1x4xf32>, %arg3 : vector<1x1xf32>,
+                    %arg4 : vector<1x1xf32>) -> (vector<1x1xf32>, vector<1x4xf32>) attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+  %0 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg3, %arg4
+                      : vector<1x1xf32>, vector<1x1xf32> into vector<1x1xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2
+                      : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+  return %0, %1 : vector<1x1xf32>, vector<1x4xf32>
+  }
+  // CHECK-LABEL: func @contract_ops
+  //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1xf32>, %[[ARG1:.*]]: vector<1x4xf32>, %[[ARG2:.*]]: vector<1x4xf32>, %[[ARG3:.*]]: vector<1x1xf32>, %[[ARG4:.*]]: vector<1x1xf32>)
+  //       CHECK:   %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[B:.*]] = vector.extract %[[ARG3]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[C:.*]] = vector.extract %[[ARG4]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[MUL:.*]] = mulf %[[A]], %[[B]] : f32
+  //       CHECK:   %[[ADD:.*]] = addf %[[MUL]], %[[C]] : f32
+  //       CHECK:   %[[R0:.*]] = vector.broadcast %[[ADD]] : f32 to vector<1x1xf32>
+  //       CHECK:   %[[A:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1xf32>
+  //       CHECK:   %[[VA:.*]] = vector.broadcast %[[A]] : f32 to vector<4xf32>
+  //       CHECK:   %[[VB:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+  //       CHECK:   %[[VC:.*]] = vector.shape_cast %[[ARG2]] : vector<1x4xf32> to vector<4xf32>
+  //       CHECK:   %[[VMUL:.*]] = mulf %[[VA]], %[[VB]] : vector<4xf32>
+  //       CHECK:   %[[VADD:.*]] = addf %[[VMUL]], %[[VC]] : vector<4xf32>
+  //       CHECK:   %[[R1:.*]] = vector.shape_cast %[[VADD]] : vector<4xf32> to vector<1x4xf32>
+  //       CHECK:   return %[[R0]], %[[R1]] : vector<1x1xf32>, vector<1x4xf32>
+}
+
+// -----
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @extract(%arg0 : vector<1x4xf32>) -> vector<1x1xf32> attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}} {
+    %0 = vector.extract_strided_slice %arg0
+      {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]}
+        : vector<1x4xf32> to vector<1x1xf32>
+    return %0 : vector<1x1xf32>
+  }
+  // CHECK-LABEL: func @extract
+  //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>
+  //       CHECK:   %[[A:.*]] = vector.extract %[[ARG0]][0, 2] : vector<1x4xf32>
+  //       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : f32 to vector<1x1xf32>
+  //       CHECK:   return %[[B]] : vector<1x1xf32>
+}
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index ea4c70c..91de450c 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -48,6 +48,12 @@
   // llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
   //     "IREE Vulkan/SPIR-V backend options");
 
+  static llvm::cl::opt<bool> clUseVectorPass(
+      "iree-spirv-use-vector-pass",
+      llvm::cl::desc(
+          "Enable use of Linalg vectorization in SPIR-V code generation"),
+      llvm::cl::init(false));
+
   static llvm::cl::opt<bool> clUseWorkgroupMemory(
       "iree-spirv-use-workgroup-memory",
       llvm::cl::desc(
@@ -81,6 +87,7 @@
   targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
                                                 clTileSizes.end());
   targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
+  targetOptions.codegenOptions.useVectorPass = clUseVectorPass;
   if (!clVulkanTargetEnv.empty()) {
     targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
   } else {
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 9686098..f4741ac 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -45,3 +45,15 @@
     driver = "vulkan",
     target_backend = "vulkan-spirv",
 )
+
+iree_check_single_backend_test_suite(
+    name = "check_vulkan-spirv_vulkan_vector",
+    srcs = [
+        "compare.mlir",
+        "log_plus_one.mlir",
+        "pw_add_multiwg.mlir",
+    ],
+    compiler_flags = ["-iree-spirv-use-vector-pass"],
+    driver = "vulkan",
+    target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 4260b07..d5bd481 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -41,3 +41,18 @@
   COMPILER_FLAGS
     "-iree-spirv-use-workgroup-memory"
 )
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_vulkan-spirv_vulkan_vector
+  SRCS
+    "compare.mlir"
+    "log_plus_one.mlir"
+    "pw_add_multiwg.mlir"
+  TARGET_BACKEND
+    vulkan-spirv
+  DRIVER
+    vulkan
+  COMPILER_FLAGS
+    "-iree-spirv-use-vector-pass"
+)
diff --git a/iree/test/e2e/vulkan_specific/compare.mlir b/iree/test/e2e/vulkan_specific/compare.mlir
new file mode 100644
index 0000000..099670a
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/compare.mlir
@@ -0,0 +1,164 @@
+func @compare_tensor() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_scalar() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i8() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i8>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i8>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i8>, tensor<i8>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i16() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i16>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i16>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i16>, tensor<i16>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i32() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i64() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i64>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i64>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_f32() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1.0> : tensor<f32>
+  %rhs = iree.unfoldable_constant dense<5.0> : tensor<f32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_f64() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1.0> : tensor<f64>
+  %rhs = iree.unfoldable_constant dense<5.0> : tensor<f64>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f64>, tensor<f64>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_tensor_odd_length() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7]> : tensor<3xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3]> : tensor<3xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<3xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<3xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<3xi1>, tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 0]> : tensor<3xi8>) : tensor<3xi8>
+  return
+}
+
+func @compare_eq() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_ne() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[1, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_lt() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[1, 0, 0, 0]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_le() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[1, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_gt() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_ge() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
diff --git a/iree/test/e2e/vulkan_specific/log_plus_one.mlir b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
new file mode 100644
index 0000000..3bba4a9
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
@@ -0,0 +1,6 @@
+func @log_plus_one() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[0.0, 0.5, 1.0, 5.0]> : tensor<4xf32>
+  %result = "mhlo.log_plus_one"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+  check.expect_almost_eq_const(%result, dense<[0.0, 0.4054651, 0.6931472, 1.7917595]> : tensor<4xf32>) : tensor<4xf32>
+  return
+}
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 5cb6f8e..7ed2691 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -15,8 +15,12 @@
 # limitations under the License.
 """Runs all E2E TensorFlow tests and extracts their benchmarking artifacts.
 
-Example usage:
-  python3 get_e2e_artifacts.py
+Example usages:
+  # Run all test suites and collect their artifacts:
+  python3 ./scripts/get_e2e_artifacts.py
+
+  # Run the e2e_tests test suite and collect its artifacts:
+  python3 ./scripts/get_e2e_artifacts.py --test_suites=e2e_tests
 """
 
 import fileinput
@@ -39,8 +43,10 @@
         '//integrations/tensorflow/e2e:mobile_bert_squad_tests',
     'keras_tests':
         '//integrations/tensorflow/e2e/keras:keras_tests',
-    'vision_external_tests':
-        '//integrations/tensorflow/e2e/keras:vision_external_tests',
+    'imagenet_external_tests':
+        '//integrations/tensorflow/e2e/keras:imagenet_external_tests',
+    'slim_vision_tests':
+        '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
 }
 SUITES_HELP = [f'`{name}`' for name in SUITE_NAME_TO_TARGET]
 SUITES_HELP = f'{", ".join(SUITES_HELP[:-1])} and {SUITES_HELP[-1]}'
@@ -118,6 +124,11 @@
                       paths_to_tests: Dict[str, str]):
   """Unzips all of the benchmarking artifacts for a given test and backend."""
   outputs = os.path.join(test_path, 'test.outputs', 'outputs.zip')
+  if FLAGS.dry_run and not os.path.exists(outputs):
+    # The artifacts may or may not be present on disk during a dry run. If they
+    # are, we want to collision-check them, but if they aren't, that's fine.
+    return
+
   archive = zipfile.ZipFile(outputs)
   # Filter out directory names.
   filenames = [name for name in archive.namelist() if name[-1] != os.sep]
@@ -139,17 +150,21 @@
   # Convert test suite shorthands to full test suite targets.
   test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites]
 
+  if FLAGS.run_test_suites:
+    # Use bazel test to execute all of the test suites in parallel.
+    command = [
+        'bazel', 'test', *test_suites, '--color=yes',
+        '--test_arg=--get_saved_model'
+    ]
+    print(f'Running: `{" ".join(command)}`')
+    if not FLAGS.dry_run:
+      subprocess.check_call(command)
+    print()
+
   written_paths = set()
   paths_to_tests = dict()
 
   for test_suite in test_suites:
-    if FLAGS.run_test_suites and not FLAGS.dry_run:
-      subprocess.check_call([
-          'bazel', 'test', test_suite, '--color=yes',
-          '--test_arg=--get_saved_model'
-      ])
-      print()
-
     # Extract all of the artifacts for this test suite.
     test_paths, test_names = get_test_paths_and_names(test_suite)
     for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)):