Merge pull request #2504 from ScottTodd:main-to-google

PiperOrigin-RevId: 321003958
diff --git a/.bazelrc b/.bazelrc
index dfafd82..b53fe0a 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -43,6 +43,11 @@
 # either clang or gcc and are curated based on need.
 ###############################################################################
 
+# Treat warnings in-workspace (every file not under external/) as errors.
+build:generic_clang --per_file_copt=-external/.*@-Werror
+# ...and silence them outside of the workspace.
+build:generic_clang --per_file_copt=external/.*@-w
+
 # Disable warnings we don't care about.
 build:generic_clang --copt=-Wno-unused-local-typedef
 build:generic_clang --copt=-Wno-unused-private-field
diff --git a/.github/workflows/google_to_main.yml b/.github/workflows/google_to_main.yml
new file mode 100644
index 0000000..dd85178
--- /dev/null
+++ b/.github/workflows/google_to_main.yml
@@ -0,0 +1,62 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Creates a PR to merge the `google` branch into the `main` branch.
+
+name: google -> main
+
+on:
+  schedule:
+    # Every 6 hours at 45 minutes past the hour (to be a bit behind the TF submodule update)
+    # 00:45, 06:45, 12:45, 18:45 UTC (16:45, 22:45, 04:45, 10:45 PST)
+    - cron: '45 */6 * * *'
+
+jobs:
+  google_to_main:
+    # Don't run this in everyone's forks.
+    if: github.repository == 'google/iree'
+    runs-on: ubuntu-18.04
+    steps:
+      - name: Checking out repository
+        uses: actions/checkout@v2
+        with:
+          ref: "google"
+      # We have to explicitly fetch the main branch as well, as origin/main.
+      - name: Fetching Base Branch
+        run: git fetch --no-tags --prune --depth=1 origin +refs/heads/main:refs/remotes/origin/main
+      - name: Checking for a diff
+        run: |
+          echo "::set-env name=has_diff::false"
+          git diff origin/main --exit-code || echo "::set-env name=has_diff::true"
+      - name: Calculating PR body
+        if: env.has_diff == 'true'
+        run: |
+          echo "::set-env name=pr_body::$(git log origin/main.. --decorate=no --pretty='format:* %h %<(80,trunc)%s' | perl -pe 's/%/%25/g; s/\r/%0D/g; s/\n/%0A/g')"
+      - name: Initializing submodules
+        if: env.has_diff == 'true'
+        run: ./scripts/git/submodule_versions.py init
+      - name: Creating Pull Request
+        if: env.has_diff == 'true'
+        uses: peter-evans/create-pull-request@v2
+        with:
+          # Personal token is required to trigger additional automation (e.g. presubmits).
+          token: ${{ secrets.GITHUB_WRITE_ACCESS_TOKEN }}
+          commit-message: "Merge google -> main"
+          title: "Merge google -> main"
+          body: "${{ env.pr_body }}"
+          committer: "Integrate Action <iree-github-actions-bot@google.com>"
+          # TODO(gcmn): Figure out a way to assign this to someone dynamically.
+          reviewers: gmngeoffrey
+          branch: "google-to-main"
+          base: "main"
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 43c6a69..45cf1bd 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -65,3 +65,17 @@
         uses: ibiqlik/action-yamllint@v1
         with:
           strict: true
+
+  buildifier:
+    runs-on: ubuntu-18.04
+    steps:
+      - name: Checking out repository
+        uses: actions/checkout@v2
+      - name: Running Buildifier
+        # TODO(gcmn): Look into only running on changed files.
+        uses: thompsonja/bazel-buildifier@v0.2.1
+        with:
+          excludes: "*/third_party/*"
+          # For compatibility with Google's internal source repository, we still
+          # use the native cc rules, so disable the native-cc warning.
+          warnings: "-native-cc"
diff --git a/.github/workflows/update_tf.yml b/.github/workflows/update_tf.yml
index a3ad225..eb3b836 100644
--- a/.github/workflows/update_tf.yml
+++ b/.github/workflows/update_tf.yml
@@ -36,14 +36,19 @@
         run: ./scripts/git/submodule_versions.py init
       - name: Updating submodules
         run: ./scripts/git/update_tf_llvm_submodules.py --llvm_commit=KEEP --update_build_files=true
+      - name: Calculating TF SHA
+        run: echo "::set-env name=TF_SHA::$(git -C third_party/tensorflow rev-parse HEAD | cut -c -12)"
       - name: Creating Pull Request
         uses: peter-evans/create-pull-request@v2
         with:
           # Personal token is required to trigger additional automation (e.g. presubmits).
           token: ${{ secrets.GITHUB_WRITE_ACCESS_TOKEN }}
           commit-message: "Update TF submodule and LLVM BUILD files"
-          title: "Update TF submodule and LLVM BUILD files"
-          body: "Automated submodule bump from .github/workflows/update_tf.yml"
+          title: "Integrate TF at https://github.com/tensorflow/tensorflow/commit/${{ env.TF_SHA }}"
+          body: |
+            Updates TF to current HEAD and copies latest version of LLVM BUILD files.
+
+            Automated submodule bump from .github/workflows/update_tf.yml
           committer: "Submodule Update Action <iree-github-actions-bot@google.com>"
           # TODO(gcmn): Figure out a way to assign this to someone dynamically.
           reviewers: gmngeoffrey
diff --git a/README.md b/README.md
index 06f88b8..e6de0e2 100644
--- a/README.md
+++ b/README.md
@@ -110,12 +110,13 @@
 
 ## Build Status
 
-CI System | Platform | Build System | Component       | Status
-:-------: | :------: | :----------: | :-------------: | :----:
-Kokoro    | Linux    | Bazel        | Core            | [![kokoro-status-linux-bazel-core](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/core/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/core/main_result.html)
-Kokoro    | Linux    | Bazel        | Bindings        | [![kokoro-status-linux-bazel-bindings](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/bindings/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/bindings/main_result.html)
-Kokoro    | Linux    | Bazel        | Integrations    | [![kokoro-status-linux-bazel-integrations](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/integrations/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/integrations/main_result.html)
-Kokoro    | Linux    | CMake        | Core + Bindings | [![kokoro-status-linux-cmake](https://storage.googleapis.com/iree-oss-build-badges/linux/cmake/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/cmake/main_result.html)
+CI System | Platform | Build System | Component            | Status
+:-------: | :------: | :----------: | :------------------: | :----:
+Kokoro    | Linux    | Bazel        | Core                 | [![kokoro-status-linux-bazel-core](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/core/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/core/main_result.html)
+Kokoro    | Linux    | Bazel        | Bindings             | [![kokoro-status-linux-bazel-bindings](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/bindings/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/bindings/main_result.html)
+Kokoro    | Linux    | Bazel        | Integrations         | [![kokoro-status-linux-bazel-integrations](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/integrations/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/integrations/main_result.html)
+Kokoro    | Linux    | CMake        | Core + Bindings      | [![kokoro-status-linux-cmake](https://storage.googleapis.com/iree-oss-build-badges/linux/cmake/main_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/linux/cmake/main_result.html)
+Kokoro    | Android  | CMake        | Runtime (build only) | [![kokoro-status-android-cmake](https://storage.googleapis.com/iree-oss-build-badges/cmake/android/arm64-v8a/google_status.svg)](https://storage.googleapis.com/iree-oss-build-badges/cmake/android/arm64-v8a/google_result.html)
 
 ## License
 
diff --git a/WORKSPACE b/WORKSPACE
index 9663fd4..d744421 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -27,17 +27,19 @@
     sha256 = "cf3b76a90c86c0554c5b10f4b160f05af71d252026b71362c4674e2fb9936cf9",
     strip_prefix = "rules_cc-01d4a48911d5e7591ecb1c06d3b8af47fe872371",
     urls = [
-      "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_cc/archive/01d4a48911d5e7591ecb1c06d3b8af47fe872371.zip",
-      "https://github.com/bazelbuild/rules_cc/archive/01d4a48911d5e7591ecb1c06d3b8af47fe872371.zip",
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_cc/archive/01d4a48911d5e7591ecb1c06d3b8af47fe872371.zip",
+        "https://github.com/bazelbuild/rules_cc/archive/01d4a48911d5e7591ecb1c06d3b8af47fe872371.zip",
     ],
 )
 
 http_archive(
     name = "rules_python",
-    url = "https://github.com/bazelbuild/rules_python/releases/download/0.0.1/rules_python-0.0.1.tar.gz",
     sha256 = "aa96a691d3a8177f3215b14b0edc9641787abaaa30363a080165d06ab65e1161",
+    url = "https://github.com/bazelbuild/rules_python/releases/download/0.0.1/rules_python-0.0.1.tar.gz",
 )
+
 load("@rules_python//python:repositories.bzl", "py_repositories")
+
 py_repositories()
 
 ###############################################################################
@@ -46,24 +48,24 @@
 # bazel toolchains rules for remote execution (https://releases.bazel.build/bazel-toolchains.html).
 http_archive(
     name = "bazel_toolchains",
+    # Workaround for b/150158570. This patch needs to be updated if the bazel_toolchains version is updated.
+    patches = ["//build_tools/bazel:bazel_toolchains.patch"],
     sha256 = "4d348abfaddbcee0c077fc51bb1177065c3663191588ab3d958f027cbfe1818b",
     strip_prefix = "bazel-toolchains-2.1.0",
     urls = [
         "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.1.0/bazel-toolchains-2.1.0.tar.gz",
         "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.1.0.tar.gz",
     ],
-    # Workaround for b/150158570. This patch needs to be updated if the bazel_toolchains version is updated.
-    patches = ["//build_tools/bazel:bazel_toolchains.patch"],
 )
 
 load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
 
 rbe_autoconfig(
     name = "rbe_default",
-    base_container_digest = 'sha256:ac36d37616b044ee77813fc7cd36607a6dc43c65357f3e2ca39f3ad723e426f6',
+    base_container_digest = "sha256:ac36d37616b044ee77813fc7cd36607a6dc43c65357f3e2ca39f3ad723e426f6",
     digest = "sha256:61fd698572dc8b5fc9db11cb4ba4f138a915517b80617259bcaef8e1e4ffd3fb",
     registry = "gcr.io",
-    repository = "iree-oss/rbe-toolchain"
+    repository = "iree-oss/rbe-toolchain",
 )
 
 ###############################################################################
@@ -90,23 +92,27 @@
 #   TensorFlow
 http_archive(
     name = "bazel_skylib",
+    sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
     urls = [
         "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
         "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
     ],
-    sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
 )
+
 load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
+
 bazel_skylib_workspace()
 ###############################################################################
 
 ###############################################################################
 # llvm-project
 load("@iree_core//build_tools/bazel/third_party_import/llvm-project:configure.bzl", "llvm_configure")
-maybe(llvm_configure,
+
+maybe(
+    llvm_configure,
     name = "llvm-project",
-    workspace = "@iree_core//:WORKSPACE",
     path = "third_party/llvm-project",
+    workspace = "@iree_core//:WORKSPACE",
 )
 ###############################################################################
 
@@ -118,7 +124,8 @@
 # TODO(laurenzo): Come up with a way to make this optional. Also, see if we can
 # get the TensorFlow tf_repositories() rule to use maybe() so we can provide
 # local overrides safely.
-maybe(local_repository,
+maybe(
+    local_repository,
     name = "org_tensorflow",
     path = "third_party/tensorflow",
 )
@@ -133,7 +140,8 @@
 
 # TODO(laurenzo): Scoping to "iree" to avoid conflicts with other things that
 # take an opinion until we can isolate.
-maybe(python_configure,
+maybe(
+    python_configure,
     name = "iree_native_python",
 )
 ###############################################################################
@@ -141,100 +149,118 @@
 ###############################################################################
 # Find and configure the Vulkan SDK, if installed.
 load("//build_tools/third_party/vulkan_sdk:repo.bzl", "vulkan_sdk_setup")
-maybe(vulkan_sdk_setup,
+
+maybe(
+    vulkan_sdk_setup,
     name = "vulkan_sdk",
 )
 ###############################################################################
 
-maybe(local_repository,
-     name = "com_google_absl",
-     path = "third_party/abseil-cpp",
+maybe(
+    local_repository,
+    name = "com_google_absl",
+    path = "third_party/abseil-cpp",
 )
 
-maybe(local_repository,
-     name = "com_google_ruy",
-     path = "third_party/ruy",
+maybe(
+    local_repository,
+    name = "com_google_ruy",
+    path = "third_party/ruy",
 )
 
-maybe(local_repository,
-     name = "com_google_googletest",
-     path = "third_party/googletest",
+maybe(
+    local_repository,
+    name = "com_google_googletest",
+    path = "third_party/googletest",
 )
 
 # Note that TensorFlow provides this as "flatbuffers" which is wrong.
 # It is only used for TFLite and may cause ODR issues if not fixed.
-maybe(local_repository,
+maybe(
+    local_repository,
     name = "com_github_google_flatbuffers",
     path = "third_party/flatbuffers",
 )
 
 # TODO(scotttodd): TensorFlow is squatting on the vulkan_headers repo name, so
 # we use a temporary one until resolved. Theirs is set to an outdated version.
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
     name = "iree_vulkan_headers",
-    path = "third_party/vulkan_headers",
     build_file = "build_tools/third_party/vulkan_headers/BUILD.overlay",
+    path = "third_party/vulkan_headers",
 )
 
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
     name = "vulkan_memory_allocator",
-    path = "third_party/vulkan_memory_allocator",
     build_file = "build_tools/third_party/vulkan_memory_allocator/BUILD.overlay",
+    path = "third_party/vulkan_memory_allocator",
 )
 
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
     name = "glslang",
-    path = "third_party/glslang",
     build_file = "build_tools/third_party/glslang/BUILD.overlay",
+    path = "third_party/glslang",
 )
 
-maybe(local_repository,
+maybe(
+    local_repository,
     name = "spirv_tools",
     path = "third_party/spirv_tools",
 )
 
-maybe(local_repository,
+maybe(
+    local_repository,
     name = "spirv_headers",
     path = "third_party/spirv_headers",
 )
 
 # TODO(laurenzo): TensorFlow is squatting on the pybind11 repo name, so
 # we use a temporary one until resolved. Theirs pulls in a bunch of stuff.
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
     name = "iree_pybind11",
-    path = "third_party/pybind11",
     build_file = "build_tools/third_party/pybind11/BUILD.overlay",
+    path = "third_party/pybind11",
 )
 
-maybe(local_repository,
+maybe(
+    local_repository,
     name = "com_google_benchmark",
-    path = "third_party/benchmark")
-
-maybe(new_local_repository,
-    name = "sdl2",
-    path = "third_party/sdl2",
-    build_file = "build_tools/third_party/sdl2/BUILD.overlay",
+    path = "third_party/benchmark",
 )
 
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
+    name = "sdl2",
+    build_file = "build_tools/third_party/sdl2/BUILD.overlay",
+    path = "third_party/sdl2",
+)
+
+maybe(
+    new_local_repository,
     name = "sdl2_config",
-    path = "build_tools/third_party/sdl2",
     build_file_content = """
 package(default_visibility = ["//visibility:public"])
 cc_library(name = "headers", srcs = glob(["*.h"]))
 """,
+    path = "build_tools/third_party/sdl2",
 )
 
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
     name = "dear_imgui",
-    path = "third_party/dear_imgui",
     build_file = "build_tools/third_party/dear_imgui/BUILD.overlay",
+    path = "third_party/dear_imgui",
 )
 
-maybe(new_local_repository,
+maybe(
+    new_local_repository,
     name = "renderdoc_api",
-    path = "third_party/renderdoc_api",
     build_file = "build_tools/third_party/renderdoc_api/BUILD.overlay",
+    path = "third_party/renderdoc_api",
 )
 
 # Bootstrap TensorFlow deps last so that ours can take precedence.
diff --git a/build_tools/embed_data/BUILD b/build_tools/embed_data/BUILD
index 6ec80be..6e5482a 100644
--- a/build_tools/embed_data/BUILD
+++ b/build_tools/embed_data/BUILD
@@ -34,6 +34,7 @@
 
 cc_embed_data(
     name = "testembed1",
+    # do not sort
     srcs = [
         "file1.txt",
         "data/file2.txt",
diff --git a/docs/GetStarted/cmake_options_and_variables.md b/docs/GetStarted/cmake_options_and_variables.md
index cd179d4..8e121fc 100644
--- a/docs/GetStarted/cmake_options_and_variables.md
+++ b/docs/GetStarted/cmake_options_and_variables.md
@@ -143,8 +143,8 @@
 cross-compiling, it's the artifact names being load-bearing. The artifact names
 are used to express dependencies across CMake invocation boundary (remember that
 we cannot access targets defined in another CMake invocation); the
-package-prefixed CMake target names are just cutom targets depending on the host
-artfact.
+package-prefixed CMake target names are just custom targets depending on the
+host artifact.
 
 #### `IREE_HOST_BINARY_ROOT`:FILEPATH
 
diff --git a/experimental/ModelBuilder/VulkanWrapperPass.cpp b/experimental/ModelBuilder/VulkanWrapperPass.cpp
index 1e6469d..b20d2fc 100644
--- a/experimental/ModelBuilder/VulkanWrapperPass.cpp
+++ b/experimental/ModelBuilder/VulkanWrapperPass.cpp
@@ -161,8 +161,7 @@
-  // Calculate the number of groups to dispatch based on the workload size
-  // and the workgroup size picked by the tiling pass.
+  // `workloadSize` now holds the number of workgroups to dispatch in each
+  // dimension (computed by the callers); clamp every count to at least one.
   for (int i = 0; i < 3; i++) {
-    auto dispatchSize =
-        std::max(int64_t(1), workloadSize[i] / workgroupSize[i]);
+    auto dispatchSize = std::max(int64_t(1), workloadSize[i]);
     Value numGroups = builder.create<ConstantIndexOp>(loc, dispatchSize);
     arguments.push_back(numGroups);
   }
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index 6a84782..a942673 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -183,6 +183,39 @@
 )
 
 cc_binary(
+    name = "bench-matmul-gpu",
+    srcs = ["BenchMatMulVectorGPU.cpp"],
+    tags = [
+        "noga",
+    ],
+    deps = [
+        "//experimental/ModelBuilder",
+        "//experimental/ModelBuilder:ModelRunner",
+        "//experimental/ModelBuilder:VulkanLaunchWrapper",
+        "//iree/base:initializer",
+        "//iree/compiler/Conversion/CodegenUtils",
+        "//iree/compiler/Conversion/LinalgToSPIRV",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:GPUToVulkanTransforms",
+        "@llvm-project//mlir:GPUTransforms",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMTransforms",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:LinalgToLLVM",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SPIRVDialect",
+        "@llvm-project//mlir:StandardToSPIRVConversions",
+        "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:VectorOps",
+        "@llvm-project//mlir:mlir_c_runner_utils",
+        # mlir_runner_utils with iostream needed for printMemRef atm
+        "@llvm-project//mlir:mlir_runner_utils",
+    ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+)
+
+cc_binary(
     name = "test-simple-mlir",
     srcs = ["TestSimpleMLIR.cpp"],
     deps = [
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
new file mode 100644
index 0000000..25346e3
--- /dev/null
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -0,0 +1,255 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Benchmark for an i8 x i8 -> i32 matrix multiply that is tiled and vectorized
+// with MatmulCodegenStrategy, lowered to SPIR-V and launched through the Vulkan
+// runtime wrapper. It sweeps over a range of workgroup tile sizes; pass
+// --correctness to run a smaller problem and compare against a CPU reference.
+
+#include <cstdio>
+#include <string>
+#include <utility>
+
+#include "experimental/ModelBuilder/ModelBuilder.h"
+#include "experimental/ModelBuilder/ModelRunner.h"
+#include "experimental/ModelBuilder/VulkanWrapperPass.h"
+#include "iree/base/initializer.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/InitLLVM.h"
+#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/ExecutionEngine/RunnerUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;                    // NOLINT
+using namespace mlir::edsc;              // NOLINT
+using namespace mlir::edsc::intrinsics;  // NOLINT
+
+static llvm::cl::opt<std::string> vulkanWrapper(
+    "vulkan-wrapper", llvm::cl::desc("Vulkan wrapper library"),
+    llvm::cl::value_desc("filename"), llvm::cl::init("-"));
+
+static llvm::cl::opt<bool> correctness(
+    "correctness",
+    llvm::cl::desc(
+        "Compare the result to value calculated on CPU. We will use a smaller "
+        "matrix multiply in this case to avoid long runtime."),
+    llvm::cl::init(false));
+
+// Adds the passes that lower the tiled and vectorized kernel to SPIR-V, wrap
+// it in a Vulkan launch of `numWorkgroups` workgroups (one count per
+// dimension) and lower the remaining module to the LLVM dialect. `args` are
+// the memref types of the kernel arguments.
+static void addLoweringPasses(mlir::PassManager &pm,
+                              llvm::ArrayRef<int64_t> numWorkgroups,
+                              llvm::ArrayRef<Type> args) {
+  pm.addPass(mlir::iree_compiler::createVectorToGPUPass());
+  pm.addPass(mlir::createLowerAffinePass());
+  pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::iree_compiler::createConvertToSPIRVPass());
+
+  auto &spirvModulePM = pm.nest<mlir::spirv::ModuleOp>();
+  spirvModulePM.addPass(mlir::createSetSpirvABIPass());
+  spirvModulePM.addPass(mlir::spirv::createLowerABIAttributesPass());
+  spirvModulePM.addPass(mlir::createCanonicalizerPass());
+  spirvModulePM.addPass(mlir::createCSEPass());
+  spirvModulePM.addPass(
+      mlir::spirv::createUpdateVersionCapabilityExtensionPass());
+
+  pm.addPass(mlir::createAddVulkanLaunchWrapperPass(numWorkgroups, args));
+  mlir::LowerToLLVMOptions llvmOptions = {
+      /*useBarePtrCallConv=*/false,
+      /*emitCWrappers=*/true,
+      /*indexBitwidth=*/mlir::kDeriveIndexBitwidthFromDataLayout};
+  pm.addPass(createLowerToLLVMPass(llvmOptions));
+  pm.addPass(mlir::createConvertVulkanLaunchFuncToVulkanCallsPass());
+}
+
+// Builds, compiles and runs an m x n x k matmul kernel tiled by
+// (tileM, tileN, tileK). When `correctness` is true, the result is also
+// computed on the CPU and every mismatching element is printed.
+void matMul(int m, int n, int k, int tileM, int tileN, int tileK,
+            bool correctness) {
+  const int warpSize = 32;
+  const int resRows = m;
+  const int resColumns = n;
+  const int reductionSize = k;
+  StringLiteral funcName = "kernel_matmul";
+  ModelBuilder modelBuilder;
+
+  auto typeA =
+      modelBuilder.getMemRefType({resRows, reductionSize}, modelBuilder.i8);
+  auto typeB =
+      modelBuilder.getMemRefType({reductionSize, resColumns}, modelBuilder.i8);
+  auto typeC = modelBuilder.getMemRefType({resRows, resColumns},
+                                          modelBuilder.getI32Type());
+  // 1. Build the kernel.
+  {
+    modelBuilder.addGPUAttr();
+    FuncOp kernelFunc = modelBuilder.makeFunction(
+        funcName, {}, {typeA, typeB, typeC}, MLIRFuncOpConfig());
+    // Right now we map one workgroup to one warp. The ABI attribute must be
+    // created in the same MLIRContext as the function it is attached to.
+    MLIRContext *ctx = modelBuilder.getContext();
+    kernelFunc.setAttr(spirv::getEntryPointABIAttrName(),
+                       spirv::getEntryPointABIAttr({warpSize, 1, 1}, ctx));
+    OpBuilder b(&kernelFunc.getBody());
+    ScopedContext scope(b, kernelFunc.getLoc());
+
+    auto A = kernelFunc.getArgument(0);
+    auto B = kernelFunc.getArgument(1);
+    auto C = kernelFunc.getArgument(2);
+
+    linalg_matmul(TypeRange{}, ValueRange{A, B, C});
+    std_ret();
+  }
+
+  // 2. Compile the function, pass in runtime support library to the execution
+  // engine for vector.print.
+  ModelRunner runner(modelBuilder.getModuleRef(),
+                     ModelRunner::Target::GPUTarget);
+  CompilationOptions options;
+  options.loweringPasses = [&](mlir::PassManager &pm) {
+    MatmulCodegenStrategy strategy;
+    // Use hardcoded value for cooperative matrix size. Those will be pulled
+    // from device properties eventually.
+    const int cooperativeMatrixM = 8;
+    const int cooperativeMatrixN = 8;
+    const int cooperativeMatrixK = 32;
+    // Swap the order of the parallel loops because PLoopToGPU pattern assigns
+    // dimension in reverse order of the loop.
+    // TODO(thomasraooux) LICM is disabled due to limitation in SPIR-V
+    strategy
+        .tile<linalg::MatmulOp>(
+            linalg::LinalgTilingOptions()
+                .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+                .setTileSizes({tileM, tileN, tileK})
+                .setInterchange({1, 0, 2}))
+        .setHoistInvariantCode(false)
+        .vectorize<linalg::MatmulOp>()
+        .unrollVector<vector::ContractionOp>(
+            {cooperativeMatrixM, cooperativeMatrixN, cooperativeMatrixK});
+    modelBuilder.getModuleRef()->walk(
+        [&](FuncOp fn) { strategy.transform(fn); });
+    addLoweringPasses(pm, {resRows / tileM, resColumns / tileN, 1},
+                      {typeA, typeB, typeC});
+  };
+  runner.compile(options, {vulkanWrapper});
+
+  // 3. Allocate data within data structures that interoperate with the MLIR ABI
+  // conventions used by codegen.
+  auto oneInit = [](unsigned idx, uint8_t *ptr) { ptr[idx] = 2 * idx + 1; };
+  auto incInit = [](unsigned idx, uint8_t *ptr) { ptr[idx] = idx; };
+  auto zeroInit = [](unsigned idx, uint32_t *ptr) { ptr[idx] = 0; };
+  auto A = makeInitializedStridedMemRefDescriptor<uint8_t, 2>(
+      {resRows, reductionSize}, oneInit);
+  auto B = makeInitializedStridedMemRefDescriptor<uint8_t, 2>(
+      {reductionSize, resColumns}, incInit);
+  auto C = makeInitializedStridedMemRefDescriptor<uint32_t, 2>(
+      {resRows, resColumns}, zeroInit);
+  auto CPURes = makeInitializedStridedMemRefDescriptor<uint32_t, 2>(
+      {resRows, resColumns}, zeroInit);
+
+  // If checking correctness, compute the expected result on the CPU.
+  if (correctness) {
+    for (int i = 0; i < resRows; i++) {
+      for (int j = 0; j < resColumns; j++) {
+        uint32_t acc = (*C)[i][j];
+        for (int k = 0; k < reductionSize; k++) {
+          uint32_t a = (*A)[i][k];
+          uint32_t b = (*B)[k][j];
+          acc += a * b;
+        }
+        (*CPURes)[i][j] = acc;
+      }
+    }
+  }
+
+  // 4. Call the funcOp named `funcName`.
+  auto err = runner.invoke(std::string(funcName) + "_wrapper", A, B, C);
+  // invoke() can fail at runtime, so report the error instead of using
+  // llvm_unreachable, which is undefined behavior in release builds.
+  if (err) llvm::report_fatal_error(std::move(err));
+
+  if (correctness) {
+    bool correct = true;
+    for (int i = 0; i < resRows; i++) {
+      for (int j = 0; j < resColumns; j++) {
+        if ((*CPURes)[i][j] != (*C)[i][j]) {
+          correct = false;
+          printf("mismatch at index(%i, %i) was expecting %u but got %u\n", i,
+                 j, (*CPURes)[i][j], (*C)[i][j]);
+        }
+      }
+    }
+    if (correct) printf("pass\n");
+  }
+}
+
+// Runs the benchmark for every (tileM, tileN) pair with both sizes in
+// {8, 16, 32, 64, 128}, using tileK = 32.
+int main(int argc, char **argv) {
+  ModelBuilder::registerAllDialects();
+  iree::Initializer::RunInitializers();
+  // Allow LLVM setup through command line and parse the
+  // test specific option for a runtime support library.
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv, "BenchMatMulVectorGPU\n");
+  int m = 4096;
+  int n = 4096;
+  int k = 4096;
+  if (correctness) {
+    m = 256;
+    n = 256;
+    k = 256;
+  }
+  printf("Matrix size: %ix%ix%i\n", m, n, k);
+  int tileK = 32;
+  for (int tileM = 8; tileM <= 128; tileM *= 2) {
+    for (int tileN = 8; tileN <= 128; tileN *= 2) {
+      printf("tileM=%i tileN=%i tileK=%i\n", tileM, tileN, tileK);
+      // For non-power of two tile sizes, round up the matrix size to
+      // be an even multiple of the tile size.
+      // TODO(thomasraoux): enable non power of two tiles once affine.min
+      // folding is fixed.
+      auto paddedM = (m + tileM - 1) / tileM * tileM;
+      auto paddedN = (n + tileN - 1) / tileN * tileN;
+      auto paddedK = (k + tileK - 1) / tileK * tileK;
+
+      matMul(paddedM, paddedN, paddedK, tileM, tileN, tileK, correctness);
+    }
+  }
+}
diff --git a/experimental/ModelBuilder/test/CMakeLists.txt b/experimental/ModelBuilder/test/CMakeLists.txt
index 94dbcb0..e775c8f 100644
--- a/experimental/ModelBuilder/test/CMakeLists.txt
+++ b/experimental/ModelBuilder/test/CMakeLists.txt
@@ -204,6 +204,43 @@
 
 iree_cc_binary(
   NAME
+    bench-matmul-gpu
+  OUT
+    bench-matmul-gpu
+  SRCS
+    "BenchMatMulVectorGPU.cpp"
+  DEPS
+    LLVMSupport
+    MLIRAllDialects
+    MLIRExecutionEngine
+    MLIRGPU
+    MLIRGPUToVulkanTransforms
+    MLIRIR
+    MLIRLinalgOps
+    MLIRLinalgToLLVM
+    MLIRLinalgTransforms
+    MLIRParser
+    MLIRPass
+    MLIRSPIRV
+    MLIRStandardToLLVM
+    MLIRStandardToSPIRVTransforms
+    MLIRTransformUtils
+    MLIRVector
+    MLIRmlir_runner_utils
+    experimental::ModelBuilder
+    experimental::ModelBuilder::ModelRunner
+    experimental::ModelBuilder::VulkanLaunchWrapper
+    iree::base::initializer
+    iree::compiler::Conversion::CodegenUtils
+    iree::compiler::Conversion::LinalgToSPIRV
+    iree::hal::llvmjit::llvmjit_driver_module
+    iree::hal::vmla::vmla_driver_module
+    iree::hal::vulkan::vulkan_driver_module
+    vulkan-runtime-wrappers
+)
+
+iree_cc_binary(
+  NAME
     test-simple-mlir
   OUT
     test-simple-mlir
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index e16a8a2..b31c663 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -114,8 +114,16 @@
     spirvModulePM.addPass(
         mlir::spirv::createUpdateVersionCapabilityExtensionPass());
 
-    pm.addPass(
-        mlir::createAddVulkanLaunchWrapperPass({width, height, 1}, args));
+    int numWorkgroupX =
+        vWorkgroupSizes.empty()
+            ? 1
+            : (width + vWorkgroupSizes[0] - 1) / vWorkgroupSizes[0];
+    int numWorkgroupY =
+        vWorkgroupSizes.size() < 2
+            ? 1
+            : (height + vWorkgroupSizes[1] - 1) / vWorkgroupSizes[1];
+    pm.addPass(mlir::createAddVulkanLaunchWrapperPass(
+        {numWorkgroupX, numWorkgroupY, 1}, args));
     mlir::LowerToLLVMOptions llvmOptions = {
         /*useBarePtrCallConv =*/false,
         /*emitCWrappers = */ true,
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index 091967a..c7df737 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -57,14 +57,8 @@
     "vulkan-wrapper", llvm::cl::desc("Vulkan wrapper library"),
     llvm::cl::value_desc("filename"), llvm::cl::init("-"));
 
-static llvm::cl::opt<bool> useCooperativeMatrix(
-    "cooperative-matrix",
-    llvm::cl::desc("Run cooperative matrix tests, this requires hardware "
-                   "supporting cooperative matrix extension"),
-    llvm::cl::init(false));
-
 static void addLoweringPasses(mlir::PassManager &pm,
-                              llvm::ArrayRef<int64_t> workloadSize,
+                              llvm::ArrayRef<int64_t> workgroupSize,
                               llvm::ArrayRef<Type> args) {
   pm.addPass(mlir::iree_compiler::createVectorToGPUPass());
   pm.addPass(mlir::createLowerAffinePass());
@@ -81,7 +75,7 @@
   spirvModulePM.addPass(
       mlir::spirv::createUpdateVersionCapabilityExtensionPass());
 
-  pm.addPass(mlir::createAddVulkanLaunchWrapperPass(workloadSize, args));
+  pm.addPass(mlir::createAddVulkanLaunchWrapperPass(workgroupSize, args));
   mlir::LowerToLLVMOptions llvmOptions = {
       /*useBarePtrCallConv=*/false,
       /*emitCWrappers=*/true,
@@ -130,7 +124,7 @@
                      ModelRunner::Target::CPUTarget);
   CompilationOptions options;
   auto lowering = [&](mlir::PassManager &pm) {
-    addLoweringPasses(pm, {warpSize, 1, 1}, {typeA, typeB, typeC});
+    addLoweringPasses(pm, {1, 1, 1}, {typeA, typeB, typeC});
   };
   options.loweringPasses = lowering;
   runner.compile(options, {vulkanWrapper});
@@ -154,91 +148,6 @@
   ::impl::printMemRef(*C);
 }
 
-void testCooperativeMatMul() {
-  const int warpSize = 32;
-  // Pick twice the size of cooperative matrix to test that the matmul gets
-  // tiled correctly.
-  const int resRows = 8 * 2;
-  const int resColumns = 8 * 2;
-  const int reductionSize = 32 * 2;
-  StringLiteral funcName = "kernel_matmul";
-  MLIRContext context;
-  ModelBuilder modelBuilder;
-
-  auto typeA =
-      modelBuilder.getMemRefType({resRows, reductionSize}, modelBuilder.i8);
-  auto typeB =
-      modelBuilder.getMemRefType({reductionSize, resColumns}, modelBuilder.i8);
-  auto typeC = modelBuilder.getMemRefType({resRows, resColumns},
-                                          modelBuilder.getI32Type());
-  // 1. Build the kernel.
-  {
-    modelBuilder.addGPUAttr();
-    FuncOp kernelFunc = modelBuilder.makeFunction(
-        funcName, {}, {typeA, typeB, typeC}, MLIRFuncOpConfig());
-    // Right now we map one workgroup to one warp.
-    kernelFunc.setAttr(spirv::getEntryPointABIAttrName(),
-                       spirv::getEntryPointABIAttr({warpSize, 1, 1}, &context));
-    OpBuilder b(&kernelFunc.getBody());
-    ScopedContext scope(b, kernelFunc.getLoc());
-
-    auto A = kernelFunc.getArgument(0);
-    auto B = kernelFunc.getArgument(1);
-    auto C = kernelFunc.getArgument(2);
-
-    linalg_matmul(TypeRange{}, ValueRange{A, B, C});
-    std_ret();
-  }
-
-  // 2. Compile the function, pass in runtime support library to the execution
-  // engine for vector.print.
-  ModelRunner runner(modelBuilder.getModuleRef(),
-                     ModelRunner::Target::GPUTarget);
-  CompilationOptions options;
-  options.loweringPasses = [&](mlir::PassManager &pm) {
-    MatmulCodegenStrategy strategy;
-    // Use hardcoded value for cooperative matrix size. Those will be pulled
-    // from device properties eventually.
-    const int cooperativeMatrixM = 8;
-    const int cooperativeMatrixK = 8;
-    const int cooperativeMatrixN = 32;
-    // TODO(thomasraooux) LICM is disabled due to limitation in SPIR-V
-    strategy
-        .tile<linalg::MatmulOp>(
-            linalg::LinalgTilingOptions()
-                .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
-                .setTileSizes({cooperativeMatrixM, cooperativeMatrixK,
-                               cooperativeMatrixN}))
-        .setHoistInvariantCode(false)
-        .vectorize<linalg::MatmulOp>();
-    modelBuilder.getModuleRef()->walk(
-        [&](FuncOp fn) { strategy.transform(fn); });
-    addLoweringPasses(pm, {resRows, resColumns, 1}, {typeA, typeB, typeC});
-  };
-  runner.compile(options, {vulkanWrapper});
-
-  // 3. Allocate data within data structures that interoperate with the MLIR ABI
-  // conventions used by codegen.
-  auto oneInit = [](unsigned idx, uint8_t *ptr) { ptr[idx] = 2 * idx + 1; };
-  auto incInit = [](unsigned idx, uint8_t *ptr) { ptr[idx] = idx; };
-  auto zeroInit = [](unsigned idx, uint32_t *ptr) { ptr[idx] = 0; };
-  auto A = makeInitializedStridedMemRefDescriptor<uint8_t, 2>(
-      {resRows, reductionSize}, oneInit);
-  auto B = makeInitializedStridedMemRefDescriptor<uint8_t, 2>(
-      {reductionSize, resColumns}, incInit);
-  auto C = makeInitializedStridedMemRefDescriptor<uint32_t, 2>(
-      {resRows, resColumns}, zeroInit);
-
-  // 4. Call the funcOp named `funcName`.
-  auto err = runner.invoke(std::string(funcName) + "_wrapper", A, B, C);
-  if (err) llvm_unreachable("Error running function.");
-
-  // 5. Dump content of input and output buffer for testing with FileCheck.
-  ::impl::printMemRef(*A);
-  ::impl::printMemRef(*B);
-  ::impl::printMemRef(*C);
-}
-
 int main(int argc, char **argv) {
   ModelBuilder::registerAllDialects();
   iree::Initializer::RunInitializers();
@@ -261,7 +170,4 @@
   // CHECK: 59,  63,  67,  71,  75,  79,  83,  87,  91,  95,  99,  103,  107,
   // CHECK: 111,  115,  119,  123,  127]
   testVecAdd();
-
-  if (useCooperativeMatrix)
-    testCooperativeMatMul();
 }
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index c27d3d0..e9c7e49 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -30,6 +30,7 @@
         "__init__.py",
         "tf_test_driver.py",
         "tf_test_utils.py",
+        "tf_utils.py",
     ],
     deps = INTREE_TENSORFLOW_PY_DEPS + [
         "//integrations/tensorflow/bindings/python:pathsetup",  # build_cleaner: keep
@@ -49,3 +50,15 @@
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
+
+iree_py_test(
+    name = "tf_utils_test",
+    srcs = [
+        "tf_utils.py",
+        "tf_utils_test.py",
+    ],
+    python_version = "PY3",
+    deps = INTREE_TENSORFLOW_PY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 3b55853..22e0392 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -21,15 +21,14 @@
 
 import collections
 import os
-import random
 import re
-import tempfile
 
 from absl import flags
 from absl import logging
 import numpy as np
 from pyiree import rt
 from pyiree.tf import compiler
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 flags.DEFINE_string("target_backends", None,
@@ -40,115 +39,20 @@
     "--test_tmpdir")
 FLAGS = flags.FLAGS
 
-ORIGINAL_SAVED_MODEL_PATH_ATTR = "_ORIGINAL_SAVED_MODEL_PATH"
 
-# Per test directory where debug artifacts are dumped.
-global_debug_dir = None
+def _setup_test_debug_dir(test_name):
+  global global_debug_dir
 
+  # Use test_tmpdir (which defaults to '/tmp/absl_testing/') if FLAGS.debug_dir
+  # is not provided.
+  parent = FLAGS.debug_dir if FLAGS.debug_dir is not None else FLAGS.test_tmpdir
+  global_debug_dir = os.path.join(parent, test_name)
 
-def set_random_seed(seed=0):
-  """Set random seed for tf, np and random."""
-  tf.random.set_seed(seed)
-  random.seed(seed)
-  np.random.seed(seed)
-
-
-def save_and_compile_tf_module(tf_module, exported_names=(),
-                               target_backends=()):
-  """Saves and compiles a TensorFlow tf.Module.
-
-  Note that if the module has the special _ORIGINAL_SAVED_MODEL_PATH attribute,
-  then it will be compiled directly from that path instead of saved and then
-  loaded.
-
-  Args:
-    tf_module: A tf.Module.
-    exported_names: Iterable of dotted function names to consider for
-      compilation.
-    target_backends: Iterable of string backend names to compile for.
-
-  Returns:
-    An _IreeCompiledModule.
-  """
-
-  def compile_from_path(sm_path):
-    compiler_context = compiler.Context()
-    # Break up the compilation so we can save debug artifacts.
-    compiler_module = compiler.tf_load_saved_model(
-        sm_path,
-        exported_names=exported_names,
-        compiler_context=compiler_context,
-        pass_pipeline=())
-
-    # Save the input MLIR module.
-    flattened_target_backends = re.sub("[^0-9a-zA-Z_]+", "_",
-                                       "__".join(target_backends))
-    if global_debug_dir:
-      mlir_path = os.path.join(global_debug_dir,
-                               "raw__%s.mlir" % flattened_target_backends)
-      logging.info("Saving raw TF input MLIR to: %s", mlir_path)
-      with open(mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
-
-    # Now run the passes manually that tf_load_saved_model would usually do.
-    compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
-    if global_debug_dir:
-      mlir_path = os.path.join(global_debug_dir,
-                               "input__%s.mlir" % flattened_target_backends)
-      logging.info("Saving IREE input MLIR to: %s", mlir_path)
-      with open(mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
-
-    compiled_module = compiler_module.compile(target_backends=target_backends)
-    if global_debug_dir:
-      compiled_path = os.path.join(
-          global_debug_dir, "compiled__%s.vmfb" % flattened_target_backends)
-      logging.info("Saving compiled IREE module to: %s", compiled_path)
-      with open(compiled_path, "wb") as f:
-        f.write(compiled_module)
-
-    return compiled_module
-
-  if hasattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR):
-    # Compile directly from the original path.
-    sm_path = getattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR)
-    logging.info(
-        "Compiling from original saved_model path (not round-tripping): %s",
-        sm_path)
-    return compile_from_path(sm_path)
-  else:
-    options = tf.saved_model.SaveOptions(save_debug_info=True)
-    if FLAGS.debug_dir is None:
-      # Round-trip through a temporary directory.
-      with tempfile.TemporaryDirectory() as sm_path:
-        tf.saved_model.save(tf_module, sm_path, options=options)
-        return compile_from_path(sm_path)
-    else:
-      # Use the supplied directory.
-      sm_path = os.path.join(FLAGS.debug_dir, "SavedModel")
-      tf.saved_model.save(tf_module, sm_path, options=options)
-      return compile_from_path(sm_path)
-
-
-def load_tf_module(path):
-  """Wrapper around tf.saved_model.load which preserves the path.
-
-  Args:
-    path: The path to load from.
-
-  Returns:
-    The loaded module with an extra property _ORIGINAL_SAVED_MODEL_PATH added.
-    This is used on subsequent compiles to load directly from the original
-    path, which gives us unmolested access to the original debug information,
-    which TensorFlow tends to lose on round-trip.
-  """
-  tf_module = tf.saved_model.load(path)
-  assert not hasattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR), (
-      "Saved model (%s) already has attribute %s" %
-      (path, ORIGINAL_SAVED_MODEL_PATH_ATTR))
-  setattr(tf_module, ORIGINAL_SAVED_MODEL_PATH_ATTR, path)
-  return tf_module
+  # Create the directory (it may already exist when --debug_dir is reused).
+  try:
+    os.makedirs(global_debug_dir, exist_ok=True)
+  except IOError:
+    logging.exception("Error creating debug dir for: %s", global_debug_dir)
 
 
 class CompiledModule(object):
@@ -226,10 +130,11 @@
 
   def __init__(self, ctor, exported_names, backend):
     super().__init__(ctor, exported_names, backend)
-    self._iree_module_blob = save_and_compile_tf_module(
+    self._iree_module_blob = tf_utils.compile_tf_module(
         ctor(),
         exported_names=exported_names,
-        target_backends=backend.iree_compiler_targets)
+        target_backends=backend.iree_compiler_targets,
+        artifacts_dir=global_debug_dir)
     self._iree_module = rt.VmModule.from_flatbuffer(self._iree_module_blob)
 
   def instantiate(self):
@@ -438,98 +343,75 @@
       return self
 
     def save(self):
-      if FLAGS.debug_dir:
-        for i in range(len(self)):
-          result = self[i]  # output generated by a model
-          field = self._fields[i]  # backend name
-          fname = os.path.join(FLAGS.debug_dir, "output_{}".format(field))
-          with open(fname, "w") as file:
-            # content of txt file can be converted to py objects by eval(txt)
-            file.write(str(result))
+      for i in range(len(self)):
+        result = self[i]  # output generated by a model
+        field = self._fields[i]  # backend name
+        fname = os.path.join(global_debug_dir, "output_{}".format(field))
+        with open(fname, "w") as file:
+          # content of txt file can be converted to py objects by eval(txt)
+          file.write(str(result))
       return self
 
   return MultiResults
 
 
-def _instantiate_modules(compiled_modules_dict):
-  """Given a dict of modules, instantiates them.
+def _instantiate_backends(compiled_backends):
+  """Instantiates compiled backends into a VirtualBackendsClass instance.
 
   Args:
-    compiled_modules_dict: Dictionary of
-        {module_name:{backend_name:CompiledModule}} that should be instantiated.
+    compiled_backends: Dictionary of backend_name:ModuleInstance.
 
   Returns:
-    namedtuple mapping module_key:VirtualBackendsClass for every module
-    in compiled_modules_dict. The VirtualBackendsClass is a dynamically
+    a VirtualBackendsClass instance. The VirtualBackendsClass is a dynamically
     generated namedtuple mapping backend_name:ModuleInstance, where the
     ModuleInstance allows attribute resolution of public functions on the
-    module. The VirtualBackendsClass also contributes some convenience
-    methods for selecting all or a subset of matching backend modules.
+    module. The VirtualBackendsClass also contributes some convenience methods
+    for selecting all or a subset of matching backend modules.
   """
+  tuple_class = collections.namedtuple("VirtualBackendsTuple",
+                                       compiled_backends.keys())
 
-  def instantiate_backends(module_dict):
-    """Creates a VirtualBackend namedtuple class for a dict.
+  class VirtualBackendsClass(tuple_class):
+    """Adds multi() and all helpers for creating virtual modules."""
 
-    Args:
-      module_dict: Dictionary of backend_name:ModuleInstance.
+    def multi(self, match_spec="."):
+      """Selects multiple backends that match a regular expression."""
+      return _VirtualModuleInstance(self._asdict(), match_spec)
 
-    Returns:
-      namedtuple subclass with a field for every backend and special
-      all and multi() helpers.
-    """
-    tuple_class = collections.namedtuple("VirtualBackendsTuple",
-                                         module_dict.keys())
+    @property
+    def all(self):
+      """Shorthand for multi() which selects all backends."""
+      return self.multi()
 
-    class VirtualBackendsClass(tuple_class):
-      """Adds a __call__ method that creates a virtual module."""
-
-      def multi(self, match_spec="."):
-        """Selects multiple backends that match a regular expression."""
-        return _VirtualModuleInstance(self._asdict(), match_spec)
-
-      @property
-      def all(self):
-        """Shorthand for multi() which selects all backends."""
-        return self.multi()
-
-    return VirtualBackendsClass(
-        *[m.instantiate() for m in module_dict.values()])
-
-  module_keys = [k for (k, _) in compiled_modules_dict.items()]
-  module_insts = [
-      instantiate_backends(module_dict)
-      for (_, module_dict) in compiled_modules_dict.items()
-  ]
-  tuple_class = collections.namedtuple("Modules", module_keys)
-  return tuple_class(*module_insts)
+  return VirtualBackendsClass(
+      *[m.instantiate() for m in compiled_backends.values()])
 
 
-def compile_modules(**kwargs):
-  """Decorator applied to a SavedModelTestCase subclass to compile modules.
+def compile_module(module_ctor, exported_names=()):
+  """SavedModelTestCase decorator that compiles a tf.Module.
+
+  A CompiledModule is created for each backend in --target_backends. They can
+  be accessed individually via self.compiled_modules.backend_name or as a union
+  via self.get_module().
 
   Args:
-    **kwargs: name/Module constructor mappings. Each such arg will be added to
-      the classes 'compiled_modules' field.
+    module_ctor: tf.Module subclass or function which returns a tf.Module
+      subclass instance.
+    exported_names: optional iterable of strings representing the exported names
+      to keep. Used primarily for Keras models (e.g. exported_names=["predict"])
 
   Returns:
     Class decorator function.
   """
 
   def decorator(cls):
-    """Decorator function."""
-    assert issubclass(cls, SavedModelTestCase), (
-        "The 'compile_modules' decorator must be applied to a "
-        "SavedModelTestCase derived class.")
-    if not cls._modules_to_compile:
-      cls._modules_to_compile = {}
-    for name, ctor in kwargs.items():
-      assert name not in cls._modules_to_compile, (
-          "@compile_modules called with duplicate module names '%s'" % (name,))
-      exported_names = ()
-      if isinstance(ctor, tuple):
-        ctor, exported_names = ctor
-      cls._modules_to_compile[name] = (ctor, exported_names)
-
+    """Stores the module to compile on cls; raises TypeError if misapplied."""
+    if not issubclass(cls, SavedModelTestCase):
+      raise TypeError(
+          "The 'compile_module' decorator must be applied to a "
+          "SavedModelTestCase derived class, which %s is not." % cls)
+    cls._module_ctor = module_ctor
+    cls._exported_names = exported_names
     return cls
 
   return decorator
@@ -617,13 +499,13 @@
 class SavedModelTestCase(tf.test.TestCase):
   """Tests against a SavedModel."""
 
-  # Will be initialized to a dict by the @compile_modules decorator.
-  # The dict maps module name to (ctor, exported_names, backend_names).
-  _modules_to_compile = None
+  # Will be initialized by the @compile_module decorator.
+  _module_ctor = None
+  _exported_names = ()
 
-  # Will be initialized in setUpClass to a dict of (name, CompiledModule)
-  # instances mirroring _modules_to_compile.
-  compiled_modules = None
+  # Will be initialized in setUpClass to a dict of
+  # {backend_name: CompiledModule}.
+  _compiled_backends_dict = None
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
@@ -632,42 +514,27 @@
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
-    cls.compiled_modules = {}
-    if cls._modules_to_compile:
-      for name, (ctor, exported_names) in cls._modules_to_compile.items():
+    if cls._module_ctor is not None:
+      # Setup the debug directory for this test. Creates a global variable
+      # `global_debug_dir`.
+      _setup_test_debug_dir(test_name=cls.__name__)
 
-        # Setup the debug directory.
-        debug_parent_dir = FLAGS.debug_dir
-        if not debug_parent_dir:
-          debug_parent_dir = FLAGS.test_tmpdir
-        debug_parent_dir = os.path.join(debug_parent_dir, cls.__name__)
+      # Setup crash reproducer for the test.
+      crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
+      compiler.Context.default_crash_reproducer_path = crash_reproducer_path
 
-        try:
-          os.makedirs(debug_parent_dir)
-        except IOError:
-          logging.exception("Error creating crash reproducer dir for: %s",
-                            debug_parent_dir)
+      # Create a CompiledModule for each backend.
+      try:
+        backends = get_backends()
+        cls._compiled_backends_dict = {}
+        for backend in backends:
+          cls._compiled_backends_dict[backend.name] = CompiledModule.create(
+              cls._module_ctor, cls._exported_names, backend)
 
-        # Setup crash reproducer and global debug dir.
-        crash_reproducer_path = os.path.join(debug_parent_dir,
-                                             name + "_reproducer.mlir")
-        compiler.Context.default_crash_reproducer_path = crash_reproducer_path
-        global global_debug_dir
-        global_debug_dir = debug_parent_dir
-
-        try:
-          # Compile.
-          backends = get_backends()
-          cls.compiled_modules[name] = dict([
-              (backend.name, CompiledModule.create(ctor, exported_names,
-                                                   backend))
-              for backend in backends
-          ])
-        finally:
-          # Disable crash reproducer (to avoid inadvertently overwriting this
-          # path on a subsequent interaction).
-          compiler.Context.default_crash_reproducer_path = None
-          global_debug_dir = None
+      finally:
+        # Disable crash reproducer (to avoid inadvertently overwriting this
+        # path on a subsequent interaction).
+        compiler.Context.default_crash_reproducer_path = None
 
   @classmethod
   def tearDownClass(cls):
@@ -675,4 +542,10 @@
 
   def setUp(self):
     super().setUp()
-    self.modules = _instantiate_modules(self.compiled_modules)
+    # No backends are compiled when the class lacks @compile_module.
+    self.compiled_modules = _instantiate_backends(
+        self._compiled_backends_dict or {})
+
+  def get_module(self):
+    """Returns the union of all compiled backends as a virtual module."""
+    return self.compiled_modules.all
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
new file mode 100644
index 0000000..cdf26d9
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -0,0 +1,125 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for interop with TensorFlow."""
+
+import os
+import random
+import re
+import tempfile
+
+from absl import flags
+from absl import logging
+import numpy as np
+from pyiree.tf import compiler
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+
+def set_random_seed(seed=0):
+  """Set random seed for tf, np and random."""
+  tf.random.set_seed(seed)
+  random.seed(seed)
+  np.random.seed(seed)
+
+
+def compile_tf_module(tf_module,
+                      target_backends=(),
+                      exported_names=(),
+                      artifacts_dir=None):
+  """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+
+  If artifacts_dir is provided then the following artifacts will be saved:
+    saved_model:
+      A TF SavedModel directory containing the files used to translate the
+      tf.Module into an IREE module.
+    tf_input__backends.mlir:
+      MLIR for the module in TF's input dialect.
+    iree_input__backends.mlir:
+      The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
+    compiled__backends.vmfb:
+      A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
+  Here 'backends' is a '__' delimited list of iree backends (e.g. vmla__llvm_ir)
+
+  Args:
+    tf_module: A tf.Module.
+    target_backends: Iterable of string backend names to compile for.
+    exported_names: Iterable of dotted function names to consider for
+      compilation.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved.
+
+  Returns:
+    A compiled IREE module blob.
+  """
+
+  def _compile_from_path(sm_path):
+    """Helper function for compile_tf_module."""
+    # We break up the compilation here so we can save intermediary artifacts.
+    compiler_context = compiler.Context()
+
+    if artifacts_dir is not None:
+      normalized_backends = []
+      for backend in target_backends:
+        # Remove unusual characters and ensure names don't end or start in "_".
+        backend = re.sub("[^0-9a-zA-Z_]+", "_", backend)
+        normalized_backends.append(backend.strip("_"))
+      backends_string = "__".join(normalized_backends)
+
+    # Convert the tf_module into raw TF input MLIR.
+    compiler_module = compiler.tf_load_saved_model(
+        sm_path,
+        exported_names=exported_names,
+        compiler_context=compiler_context,
+        pass_pipeline=())
+
+    if artifacts_dir is not None:
+      tf_mlir_path = os.path.join(artifacts_dir,
+                                  f"tf_input__{backends_string}.mlir")
+      logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+      with open(tf_mlir_path, "w") as f:
+        f.write(compiler_module.to_asm())
+
+    # Now run the passes manually that tf_load_saved_model would usually do.
+    compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+    if artifacts_dir is not None:
+      iree_mlir_path = os.path.join(artifacts_dir,
+                                    f"iree_input__{backends_string}.mlir")
+      logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+      with open(iree_mlir_path, "w") as f:
+        f.write(compiler_module.to_asm())
+
+    compiled_module = compiler_module.compile(target_backends=target_backends)
+    if artifacts_dir is not None:
+      compiled_path = os.path.join(artifacts_dir,
+                                   f"compiled__{backends_string}.vmfb")
+      logging.info("Saving compiled IREE module to: %s", compiled_path)
+      with open(compiled_path, "wb") as f:
+        f.write(compiled_module)
+
+    return compiled_module
+
+  options = tf.saved_model.SaveOptions(save_debug_info=True)
+  if artifacts_dir is not None:
+    # Save the saved model alongside the other compilation artifacts.
+    sm_path = os.path.join(artifacts_dir, "saved_model")
+    tf.saved_model.save(tf_module, sm_path, options=options)
+    return _compile_from_path(sm_path)
+  else:
+    # Round-trip the saved model through a temporary directory.
+    with tempfile.TemporaryDirectory() as sm_path:
+      tf.saved_model.save(tf_module, sm_path, options=options)
+      return _compile_from_path(sm_path)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
new file mode 100644
index 0000000..89a0011
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -0,0 +1,63 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.tf_utils."""
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+from pyiree.tf.support import tf_utils
+import tensorflow as tf
+
+
+class ConstantModule(tf.Module):
+
+  @tf.function(input_signature=[])
+  def meaning(self):
+    return tf.constant([42.])
+
+
+class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.named_parameters([
+      {
+          'testcase_name': 'single_backend',
+          'target_backends': ['vmla'],
+      },
+      {
+          'testcase_name': 'multiple_backends',
+          'target_backends': ['vmla', 'llvm'],
+      },
+  ])
+  def test_artifact_saving(self, target_backends):
+    with tempfile.TemporaryDirectory() as artifacts_dir:
+      tf_module = ConstantModule()
+      iree_compiled_module = tf_utils.compile_tf_module(
+          tf_module,
+          target_backends=target_backends,
+          artifacts_dir=artifacts_dir)
+
+      artifacts_to_check = [
+          'saved_model',
+          f'tf_input__{"__".join(target_backends)}.mlir',
+          f'iree_input__{"__".join(target_backends)}.mlir',
+          f'compiled__{"__".join(target_backends)}.vmfb',
+      ]
+      for artifact in artifacts_to_check:
+        self.assertTrue(os.path.exists(os.path.join(artifacts_dir, artifact)))
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 43a215f..f5bd3b7 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -82,7 +82,6 @@
 # keep sorted
 VULKAN_FAILING = [
     "broadcasting_test.py",
-    "concat_test.py",  # TODO(b/160616675)
     "depth_conv_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index bcd7d88..deef755 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -23,6 +23,23 @@
 The test suites can be run excluding Vulkan by specifying
 `--test_tag_filters="-driver=vulkan"` in the `bazel test` invocation.
 
+## Compiling `tf.Module`s
+
+Compatible TensorFlow modules can be compiled to specific IREE backends using
+`tf_utils.compile_tf_module`. This also optionally saves compilation artifacts
+to a specified directory. These artifacts include: MLIR across various
+lowerings, a TensorFlow SavedModel, and the compiled VM FlatBuffer.
+
+When using Keras models or tf.Modules with functions that IREE can't compile,
+`exported_names` should be specified. For example:
+
+```python
+vmla_module_blob = tf_utils.compile_tf_module(
+    tf_module=SomeKerasModelModule(),
+    target_backends=["vmla"],
+    exported_names=['predict'])
+```
+
 ## Running tests
 
 For locally running tests and iterating on backend development, `bazel run` is
@@ -52,7 +69,7 @@
 you specify `tf` backend only, then we will also test `tf` vs `tf` to capture
 any model initialization/randomization issues (it is a special case for debug
 purpose). For reproducibility of the unit tests we set random seed of `tf` and
-`numpy` by calling `tf_test_utils.set_random_seed()` before model creation.
+`numpy` by calling `tf_utils.set_random_seed()` before model creation.
 
 ## Test Suites
 
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index b9fdc2a..f9f8d8c 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -38,7 +38,7 @@
         variance_epsilon=1e-4)
 
 
-@tf_test_utils.compile_modules(bn=BatchNormModule)
+@tf_test_utils.compile_module(BatchNormModule)
 class BatchNormTest(tf_test_utils.SavedModelTestCase):
 
   def test_batch_norm_inference(self):
@@ -49,8 +49,7 @@
     variance = np.random.random((16,)).astype(np.float32) * 1e-3
     offset = np.random.random((16,)).astype(np.float32) * 1e-3
     scale = np.random.random((16,)).astype(np.float32) * 1e-3
-    r = self.modules.bn.all.batch_norm_inference(x, mean, variance, offset,
-                                                 scale)
+    r = self.get_module().batch_norm_inference(x, mean, variance, offset, scale)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index 01c20a6..cde2fd6 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -28,23 +28,23 @@
     return lhs + rhs
 
 
-@tf_test_utils.compile_modules(m=BroadcastingModule)
+@tf_test_utils.compile_module(BroadcastingModule)
 class BroadcastingTest(tf_test_utils.SavedModelTestCase):
 
   def test_add_same_shape(self):
-    m = self.modules.m.all
+    m = self.get_module()
     dst = m.add(tf.random.uniform([4]), tf.random.uniform([4]))
     dst.print().assert_all_close()
 
 
 # TODO(silvasean): Make these work.
 #   def test_add_broadcast_lhs(self):
-#     m = self.modules.m.all
+#     m = self.get_module()
 #     dst = m.add(tf.random.uniform([1]), tf.random.uniform([4]))
 #     dst.print().assert_all_close()
 #
 #   def test_add_broadcast_rhs(self):
-#     m = self.modules.m.all
+#     m = self.get_module()
 #     dst = m.add(tf.random.uniform([4]), tf.random.uniform([1]))
 #     dst.print().assert_all_close()
 
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index 8b7a856..a9f9759 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -15,6 +15,7 @@
 """Test concat op."""
 
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -49,36 +50,36 @@
     return tf.concat([a, b], axis=2)
 
 
-@tf_test_utils.compile_modules(mat=ConcatOpsModule)
+@tf_test_utils.compile_module(ConcatOpsModule)
 class ConcatOpsTest(tf_test_utils.SavedModelTestCase):
 
   def test_concat_zero_dim(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 0], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
     dst.assert_all_close()
 
   def concat0axis(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
     dst.assert_all_close()
 
   def concat1axis(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
     dst.assert_all_close()
 
   def concat2axis(self):
-    tf_test_utils.set_random_seed()
-    m = self.modules.mat.all
+    tf_utils.set_random_seed()
+    m = self.get_module()
     a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
     dst = m.concat_zero_dim(a, b)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index d579e9b..0223e8c 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -38,17 +38,17 @@
     return i
 
 
-@tf_test_utils.compile_modules(control_flow=ControlFlowModule)
+@tf_test_utils.compile_module(ControlFlowModule)
 class ControlFlowTest(tf_test_utils.SavedModelTestCase):
 
   def test_short_sequence(self):
     input_array = numpy.array(9., dtype=numpy.float32)
-    result = self.modules.control_flow.all.collatz(input_array)
+    result = self.get_module().collatz(input_array)
     result.print().assert_all_close()
 
   def test_long_sequence(self):
     input_array = numpy.array(178., dtype=numpy.float32)
-    result = self.modules.control_flow.all.collatz(input_array)
+    result = self.get_module().collatz(input_array)
     result.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index a4997e6..f72b11d 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -98,73 +98,73 @@
     return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
 
 
-@tf_test_utils.compile_modules(conv2d=Conv2dModule)
+@tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.SavedModelTestCase):
 
   def test_id_batch_size_1(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.ones([1, 1, 1, 1], dtype=np.float32)
-    r = self.modules.conv2d.all.conv2d_1451x1111_valid(i, k)
+    r = self.get_module().conv2d_1451x1111_valid(i, k)
     r.print().assert_all_close()
 
   def test_id_batch_size_2(self):
     i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
     k = np.ones([1, 1, 1, 1], dtype=np.float32)
-    r = self.modules.conv2d.all.conv2d_2451x1111_valid(i, k)
+    r = self.get_module().conv2d_2451x1111_valid(i, k)
     r.print().assert_all_close()
 
   def test_asym_kernel(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.modules.conv2d.all.conv2d_1451x2311_valid(i, k)
+    r = self.get_module().conv2d_1451x2311_valid(i, k)
     r.print().assert_all_close()
 
   def test_padding(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.modules.conv2d.all.conv2d_1451x2311_same(i, k)
+    r = self.get_module().conv2d_1451x2311_same(i, k)
     r.print().assert_all_close()
 
   def test_batched_padding(self):
     i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
     k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.modules.conv2d.all.conv2d_2451x2311_same(i, k)
+    r = self.get_module().conv2d_2451x2311_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_reduce(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.ones([3, 2, 2, 1], dtype=np.float32)
-    r = self.modules.conv2d.all.conv2d_1452x3221_same(i, k)
+    r = self.get_module().conv2d_1452x3221_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_inflate(self):
     i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
     k = np.arange(2, dtype=np.float32).reshape([1, 1, 1, 2])
-    r = self.modules.conv2d.all.conv2d_1451x1112_same(i, k)
+    r = self.get_module().conv2d_1451x1112_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_mix(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.arange(4, dtype=np.float32).reshape([1, 1, 2, 2])
-    r = self.modules.conv2d.all.conv2d_1452x1122_same(i, k)
+    r = self.get_module().conv2d_1452x1122_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_padded(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_1452x2223_same(i, k)
+    r = self.get_module().conv2d_1452x2223_same(i, k)
     r.print().assert_all_close()
 
   def test_feature_unpadded(self):
     i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_1452x2223_valid(i, k)
+    r = self.get_module().conv2d_1452x2223_valid(i, k)
     r.print().assert_all_close()
 
   def test_batched_feature_unpadded(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_2452x2223_valid(i, k)
+    r = self.get_module().conv2d_2452x2223_valid(i, k)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 361b55f..cdf4d1e 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -38,19 +38,19 @@
         img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
 
-@tf_test_utils.compile_modules(conv2d=Conv2dModule)
+@tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.SavedModelTestCase):
 
   def test_batched_feature_unpadded(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.modules.conv2d.all.conv2d_2452x2223_valid(i, k)
+    r = self.get_module().conv2d_2452x2223_valid(i, k)
     r.print().assert_all_close()
 
   def test_batched_feature_unpadded_smae(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
-    r = self.modules.conv2d.all.conv2d_2452x2223_same(i, k)
+    r = self.get_module().conv2d_2452x2223_same(i, k)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 5f9c667..04de603 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 HIDDEN_1_DIM = 256
@@ -35,7 +36,7 @@
                input_dim=28 * 28,
                classes=10):
     super().__init__()
-    tf_test_utils.set_random_seed()
+    tf_utils.set_random_seed()
     self.hidden_1_dim = hidden_1_dim
     self.hidden_2_dim = hidden_2_dim
     self.input_dim = input_dim
@@ -64,11 +65,11 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_modules(mlp=(Mlp, ["predict"]))
+@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
 class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
 
   def test_dynamic_batch(self):
-    m = self.modules.mlp.all
+    m = self.get_module()
     np.random.seed(12345)
     x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
     m.predict(x).print().assert_all_close()
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 17da1cd..66f3c06 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 HIDDEN_1_DIM = 256
@@ -31,7 +32,7 @@
                input_dim=28 * 28,
                classes=10):
     super().__init__()
-    tf_test_utils.set_random_seed()
+    tf_utils.set_random_seed()
     self.hidden_1_dim = hidden_1_dim
     self.hidden_2_dim = hidden_2_dim
     self.input_dim = input_dim
@@ -60,11 +61,11 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_modules(mlp=(Mlp, ["predict"]))
+@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
 class DynamicMlpTest(tf_test_utils.SavedModelTestCase):
 
   def test_dynamic_batch(self):
-    m = self.modules.mlp.all
+    m = self.get_module()
     np.random.seed(12345)
     x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
     m.predict(x).print().assert_all_close()
diff --git a/integrations/tensorflow/e2e/explicit_backend_test.py b/integrations/tensorflow/e2e/explicit_backend_test.py
index 1cccd8a..903b34c 100644
--- a/integrations/tensorflow/e2e/explicit_backend_test.py
+++ b/integrations/tensorflow/e2e/explicit_backend_test.py
@@ -29,7 +29,7 @@
     return a * b
 
 
-@tf_test_utils.compile_modules(simple_arithmetic=SimpleArithmeticModule)
+@tf_test_utils.compile_module(SimpleArithmeticModule)
 class ExplicitBackendTest(tf_test_utils.SavedModelTestCase):
 
   def test_explicit(self):
@@ -39,9 +39,9 @@
     # Demonstrates simple, one by one invocation of functions against
     # different explicit backends. Individual backends can be accessed off of
     # the module by name ('tf', 'iree_vmla' below).
-    tf_c = self.modules.simple_arithmetic.tf.simple_mul(a, b)
+    tf_c = self.compiled_modules.tf.simple_mul(a, b)
     print("TF Result:", tf_c)
-    iree_c = self.modules.simple_arithmetic.iree_vmla.simple_mul(a, b)
+    iree_c = self.compiled_modules.iree_vmla.simple_mul(a, b)
     print("IREE Result:", iree_c)
     self.assertAllClose(tf_c, iree_c)
 
@@ -53,18 +53,18 @@
     # which takes a regex string matching backend names. This also returns a
     # MultiResults tuple with actual results keyed by backend name. These also
     # have convenience methods like print() and assert_all_close().
-    vmod = self.modules.simple_arithmetic.multi("tf|iree")
+    vmod = self.compiled_modules.multi("tf|iree")
     r = vmod.simple_mul(a, b)
     r.print().assert_all_close()
 
-  def test_all(self):
+  def test_get_module(self):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
     b = np.array([400., 5., 6., 7.], dtype=np.float32)
 
-    # Evaluating against all backends can be done with the special 'all'
-    # backend name. This also returns a MultiResults tuple with actual results
-    # keyed by backend name.
-    r = self.modules.simple_arithmetic.all.simple_mul(a, b)
+    # Evaluating against all backends can be done with self.get_module(). This
+    # also returns a MultiResults tuple with actual results keyed by backend
+    # name.
+    r = self.get_module().simple_mul(a, b)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/exported_names_test.py b/integrations/tensorflow/e2e/exported_names_test.py
deleted file mode 100644
index 2d7e447..0000000
--- a/integrations/tensorflow/e2e/exported_names_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class DontExportEverything(tf.Module):
-
-  @tf.function(input_signature=[])
-  def exported_fn(self):
-    return tf.constant([42.])
-
-  # No input_signature, so it cannot be imported by the SavedModel importer.
-  # We need to ensure that
-  @tf.function
-  def unreachable_fn(self, x):
-    return x
-
-
-# To pass a set of exported names for the module, instead of passing just a
-# module ctor, instead pass a pair `(ctor, [list, of, exported, names])`.
-@tf_test_utils.compile_modules(
-    dont_export_everything=(DontExportEverything, ["exported_fn"]))
-class DontExportEverythingTest(tf_test_utils.SavedModelTestCase):
-
-  def test_dont_export_everything(self):
-    self.modules.dont_export_everything.all.exported_fn().assert_all_close()
-
-
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
-    tf.enable_v2_behavior()
-  tf.test.main()
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 82b2af5..8ef96a9 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -30,14 +30,14 @@
     return tf.fill(dims, value)
 
 
-@tf_test_utils.compile_modules(fill=FillModule)
+@tf_test_utils.compile_module(FillModule)
 class FillTest(tf_test_utils.SavedModelTestCase):
 
   def test_fill(self):
     dims = np.array([2, 3], dtype=np.int32)
     value = np.array(9., dtype=np.float32)
 
-    result = self.modules.fill.all.fill(dims, value)
+    result = self.get_module().fill(dims, value)
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index de532cd..8d2a0ba 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -48,31 +48,31 @@
     return tf.gather(params, indices, axis=2, batch_dims=1)
 
 
-@tf_test_utils.compile_modules(gather=GatherModule)
+@tf_test_utils.compile_module(GatherModule)
 class GatherTest(tf_test_utils.SavedModelTestCase):
 
   def test_gather_axis0_scalar(self):
     indices = np.array(2, dtype=np.int32)
     params = np.arange(32, dtype=np.float32).reshape(4, 8)
-    result = self.modules.gather.all.gather_axis0_scalar(params, indices)
+    result = self.get_module().gather_axis0_scalar(params, indices)
     result.print().assert_all_close()
 
   def test_gather_axis0_batch0(self):
     indices = np.array([2, 3], dtype=np.int32)
     params = np.arange(32, dtype=np.float32).reshape(4, 8)
-    result = self.modules.gather.all.gather_axis0_batch0(params, indices)
+    result = self.get_module().gather_axis0_batch0(params, indices)
     result.print().assert_all_close()
 
-  def test_gahter_axis1_batch0(self):
+  def test_gather_axis1_batch0(self):
     indices = np.array([2, 3], dtype=np.int32)
     params = np.arange(4 * 7 * 8, dtype=np.float32).reshape(4, 7, 8)
-    result = self.modules.gather.all.gather_axis1_batch0(params, indices)
+    result = self.get_module().gather_axis1_batch0(params, indices)
     result.print().assert_all_close()
 
-  def test_gahter_axis2_batch1(self):
+  def test_gather_axis2_batch1(self):
     indices = np.array([[2], [3], [0], [1]], dtype=np.int32)
     params = np.arange(4 * 7 * 8 * 2, dtype=np.float32).reshape(4, 7, 8, 2)
-    result = self.modules.gather.all.gather_axis2_batch1(params, indices)
+    result = self.get_module().gather_axis2_batch1(params, indices)
     result.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 03a129d..ec592e7 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -22,6 +22,7 @@
     "//bindings/python:build_defs.oss.bzl",
     "INTREE_TENSORFLOW_PY_DEPS",
     "NUMPY_DEPS",
+    "iree_py_binary",
 )
 load(
     "//integrations/tensorflow/e2e/keras:iree_vision_test_suite.bzl",
@@ -66,7 +67,7 @@
 """
 
 [
-    py_binary(
+    iree_py_binary(
         name = src.replace(".py", "_manual"),
         srcs = [src],
         main = src,
@@ -219,7 +220,7 @@
 # 32x32. These models are not optimized for accuracy or latency (they are for
 # debugging only). They have the same neural net topology with keras vision
 # models trained on imagenet data sets
-py_binary(
+iree_py_binary(
     name = "train_vision_models_on_cifar",
     srcs = ["train_vision_models_on_cifar.py"],
     python_version = "PY3",
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index 12f56d1..0d34d97 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -18,6 +18,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 NUM_UNITS = 10
@@ -27,7 +28,7 @@
 
 
 def lstm_module():
-  tf_test_utils.set_random_seed()
+  tf_utils.set_random_seed()
   inputs = tf.keras.layers.Input(batch_size=NUM_BATCH, shape=INPUT_SHAPE[1:])
   outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
   model = tf.keras.Model(inputs, outputs)
@@ -39,11 +40,11 @@
   return module
 
 
-@tf_test_utils.compile_modules(lstm=(lstm_module, ["predict"]))
+@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
 class LstmTest(tf_test_utils.SavedModelTestCase):
 
   def test_lstm(self):
-    m = self.modules.lstm.all
+    m = self.get_module()
     m.predict(
         tf.constant(
             np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 5def0e7..671c31b 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 NUM_UNITS = 10
@@ -24,7 +25,7 @@
 
 
 def lstm_module():
-  tf_test_utils.set_random_seed()
+  tf_utils.set_random_seed()
   inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
   outputs = tf.keras.layers.LSTM(units=NUM_UNITS, return_sequences=True)(inputs)
   model = tf.keras.Model(inputs, outputs)
@@ -36,11 +37,11 @@
   return module
 
 
-@tf_test_utils.compile_modules(lstm=(lstm_module, ["predict"]))
+@tf_test_utils.compile_module(lstm_module, exported_names=["predict"])
 class LstmTest(tf_test_utils.SavedModelTestCase):
 
   def test_lstm(self):
-    m = self.modules.lstm.all
+    m = self.get_module()
     m.predict(
         tf.constant(
             np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 68cfa73..6675956 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -14,13 +14,10 @@
 # limitations under the License.
 """Test keras Model training."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 from sklearn.preprocessing import PolynomialFeatures
 import tensorflow as tf
 
@@ -51,7 +48,7 @@
       model for linear regression
     """
 
-    tf_test_utils.set_random_seed()
+    tf_utils.set_random_seed()
 
     # build a single layer model
     inputs = tf.keras.layers.Input((input_dim))
@@ -78,8 +75,8 @@
     return loss_value
 
 
-@tf_test_utils.compile_modules(
-    train_module=(ModelTrain.CreateModule, ["TrainStep"]))
+@tf_test_utils.compile_module(
+    ModelTrain.CreateModule, exported_names=["TrainStep"])
 class ModelTrainTest(tf_test_utils.SavedModelTestCase):
 
   def generate_regression_data(self, size=8):
@@ -103,7 +100,7 @@
 
     targets = np.expand_dims(targets, axis=1)
     # run one iteration of training step
-    result = self.modules.train_module.all.TrainStep(inputs, targets)
+    result = self.get_module().TrainStep(inputs, targets)
     result.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 9164d9b..1804739 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -17,6 +17,7 @@
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 FLAGS = flags.FLAGS
@@ -88,7 +89,7 @@
 
 def models():
   tf.keras.backend.set_learning_phase(False)
-  tf_test_utils.set_random_seed()
+  tf_utils.set_random_seed()
 
   input_shape = get_input_shape(FLAGS.data, FLAGS.model)
   # keras model receives images size as input,
@@ -127,7 +128,7 @@
   return module
 
 
-@tf_test_utils.compile_modules(applications=(models, ['predict']))
+@tf_test_utils.compile_module(models, exported_names=['predict'])
 class AppTest(tf_test_utils.SavedModelTestCase):
 
   def test_application(self):
@@ -135,8 +136,7 @@
     input_data = np.random.rand(np.prod(np.array(input_shape))).astype(
         np.float32)
     input_data = input_data.reshape(input_shape)
-    self.modules.applications.all.predict(input_data).print().assert_all_close(
-        atol=1e-6)
+    self.get_module().predict(input_data).print().assert_all_close(atol=1e-6)
 
 
 if __name__ == '__main__':
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index 682df56..d326db5 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -33,14 +33,14 @@
     return tf.linspace(start, stop, num)
 
 
-@tf_test_utils.compile_modules(linspace=LinSpaceModule)
+@tf_test_utils.compile_module(LinSpaceModule)
 class LinspaceTest(tf_test_utils.SavedModelTestCase):
 
   def test_linspace(self):
     start = np.array(10., dtype=np.float32)
     stop = np.array(12., dtype=np.float32)
 
-    result = self.modules.linspace.all.linspace(start, stop)
+    result = self.get_module().linspace(start, stop)
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 0fa7205..4886b7a 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -94,11 +94,11 @@
     return tf.reshape(in_the_set, shape=[view_pixels, view_pixels])
 
 
-@tf_test_utils.compile_modules(mandelbrot=MandelbrotModule)
+@tf_test_utils.compile_module(MandelbrotModule)
 class MandelbrotTest(tf_test_utils.SavedModelTestCase):
 
   def test_mandelbrot(self):
-    mandelbrot = self.modules.mandelbrot.all
+    mandelbrot = self.get_module()
 
     # Basic view of the entire set.
     pixels = mandelbrot.calculate(-0.7, 0.0, 3.0, 400, 100)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index f5d3538..b27d1d1 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -38,27 +38,27 @@
     return tf.math.mod(x, 2.0)
 
 
-@tf_test_utils.compile_modules(math=MathModule)
+@tf_test_utils.compile_module(MathModule)
 class MathTest(tf_test_utils.SavedModelTestCase):
 
   def test_abs(self):
     a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
-    r = self.modules.math.all.abs(a)
+    r = self.get_module().abs(a)
     r.print().assert_all_close()
 
   def test_cos(self):
     a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
-    r = self.modules.math.all.cos(a)
+    r = self.get_module().cos(a)
     r.print().assert_all_close()
 
   def test_log(self):
     a = np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32)
-    r = self.modules.math.all.log(a)
+    r = self.get_module().log(a)
     r.print().assert_all_close()
 
   def test_mod(self):
     a = np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)
-    r = self.modules.math.all.mod(a)
+    r = self.get_module().mod(a)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index a604fac..b29a198 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -70,58 +70,58 @@
     return tf.matmul(lhs, rhs)
 
 
-@tf_test_utils.compile_modules(mat=MatrixOpsModule)
+@tf_test_utils.compile_module(MatrixOpsModule)
 class MatrixOpsTest(tf_test_utils.SavedModelTestCase):
 
   def test_basic_matmul(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.basic_matmul(tf.random.uniform([4, 2]), tf.random.uniform([2, 4]))
     dst.assert_all_close()
 
   def test_matmul_lhs_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_lhs_batch(
         tf.random.uniform([3, 4, 2]), tf.random.uniform([2, 4]))
     dst.assert_all_close()
 
   def test_matmul_rhs_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_rhs_batch(
         tf.random.uniform([4, 2]), tf.random.uniform([3, 2, 4]))
     dst.assert_all_close()
 
   def test_matmul_broadcast_singleton_dimension(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_broadcast_singleton_dimension(
         tf.random.uniform([1, 4, 2]), tf.random.uniform([3, 2, 4]))
     dst.assert_all_close()
 
   def test_matmul_high_rank_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_high_rank_batch(
         tf.random.uniform([1, 7, 4, 2]), tf.random.uniform([7, 1, 2, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_matching_batch(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic(
         tf.random.uniform([2, 2, 3]), tf.random.uniform([2, 3, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_broadcast_lhs(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic(
         tf.random.uniform([1, 2, 3]), tf.random.uniform([2, 3, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_broadcast_rhs(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic(
         tf.random.uniform([2, 2, 3]), tf.random.uniform([1, 3, 4]))
     dst.assert_all_close()
 
   def test_matmul_dynamic_rank_broadcasting(self):
-    m = self.modules.mat.all
+    m = self.get_module()
     dst = m.matmul_dynamic_lhs_batch(
         tf.random.uniform([7, 2, 3]), tf.random.uniform([3, 4]))
     dst.assert_all_close()
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index 342adff..1d703c0 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -28,12 +28,11 @@
     return self.counter.assign_add(value)
 
 
-@tf_test_utils.compile_modules(resource_ops=ResourcesOpsModule)
+@tf_test_utils.compile_module(ResourcesOpsModule)
 class ResourcesOpsTest(tf_test_utils.SavedModelTestCase):
 
   def test_add_assign(self):
-    result = self.modules.resource_ops.all.add_assign(
-        np.array(9., dtype=np.float32))
+    result = self.get_module().add_assign(np.array(9., dtype=np.float32))
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 32dac0d..ea48711 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -164,10 +164,10 @@
     return dict(list(base_config.items()) + list(config.items()))
 
 
-class StatefulRingBufferM(tf.Module):
+class StatefulRingBufferModule(tf.Module):
 
   def __init__(self):
-    super(StatefulRingBufferM, self).__init__()
+    super(StatefulRingBufferModule, self).__init__()
     state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE]
     self.rb = StatefulRingBuffer(state_shape=state_shape)
 
@@ -177,26 +177,27 @@
     return self.rb(x)
 
 
-@tf_test_utils.compile_modules(rb=(StatefulRingBufferM, ["predict"]))
+@tf_test_utils.compile_module(
+    StatefulRingBufferModule, exported_names=["predict"])
 class StatefulRingBufferTest(tf_test_utils.SavedModelTestCase):
 
-  def test_statefulringbuffer(self):
+  def test_stateful_ringbuffer(self):
     input1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    result1 = self.modules.rb.all.predict(input1)
+    result1 = self.get_module().predict(input1)
     output1 = np.array([[1.0, 2.0]], dtype=np.float32)
     assert np.allclose(result1, output1)
 
     # ring buffer is not filled yet,
     # so data from first cycle will be returned
     input2 = np.array([[3.0, 4.0]], dtype=np.float32)
-    result2 = self.modules.rb.all.predict(input2)
+    result2 = self.get_module().predict(input2)
     output2 = np.array([[1.0, 2.0]], dtype=np.float32)
     assert np.allclose(result2, output2)
 
     # on 3rd cycle we overwrite oldest data
     # and return data from 2nd cycle
     input3 = np.array([[5.0, 6.0]], dtype=np.float32)
-    result3 = self.modules.rb.all.predict(input3)
+    result3 = self.get_module().predict(input3)
     output3 = np.array([[3.0, 4.0]], dtype=np.float32)
     assert np.allclose(result3, output3)
 
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index d10e26d..cdd3277 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -48,31 +48,28 @@
     return tf.tensor_scatter_nd_update(tensor, indices, updates)
 
 
-@tf_test_utils.compile_modules(scatter_update=ScatterUpdateModule)
+@tf_test_utils.compile_module(ScatterUpdateModule)
 class ScatterUpdateTest(tf_test_utils.SavedModelTestCase):
 
   def test_scatter_update_1D(self):
     tensor = tf.ones([8], dtype=tf.int32)
     indices = tf.constant([[4], [5], [6]])
     updates = tf.constant([9, 10, 11])
-    result = self.modules.scatter_update.all.scatter_update_1D(
-        tensor, indices, updates)
+    result = self.get_module().scatter_update_1D(tensor, indices, updates)
     result.assert_all_close()
 
   def test_scatter_update_2D(self):
     tensor = tf.ones([4, 3], dtype=tf.int32)
     indices = tf.constant([[1, 0], [2, 1], [3, 2]])
     updates = tf.constant([2, 5, 8])
-    result = self.modules.scatter_update.all.scatter_update_2D(
-        tensor, indices, updates)
+    result = self.get_module().scatter_update_2D(tensor, indices, updates)
     result.assert_all_close()
 
   def test_scatter_update_2D_slice(self):
     tensor = tf.ones([4, 3], dtype=tf.int32)
     indices = tf.constant([[1]])
     updates = tf.constant([[2, 3, 4]])
-    result = self.modules.scatter_update.all.scatter_update_2D_slice(
-        tensor, indices, updates)
+    result = self.get_module().scatter_update_2D_slice(tensor, indices, updates)
     result.assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index eaa1b30..0c5941d 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -36,13 +36,13 @@
     return tf.matmul(a, b)
 
 
-@tf_test_utils.compile_modules(simple_arithmetic=SimpleArithmeticModule)
+@tf_test_utils.compile_module(SimpleArithmeticModule)
 class SimpleArithmeticTest(tf_test_utils.SavedModelTestCase):
 
   def test_simple_mul(self):
     a = np.array([1., 2., 3., 4.], dtype=np.float32)
     b = np.array([400., 5., 6., 7.], dtype=np.float32)
-    r = self.modules.simple_arithmetic.all.simple_mul(a, b)
+    r = self.get_module().simple_mul(a, b)
     r.print().assert_all_close()
 
   def test_simple_matmul(self):
@@ -50,7 +50,7 @@
     # Note: scaling by a small value to increase numerical stability.
     a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
     b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
-    r = self.modules.simple_arithmetic.all.simple_matmul(a, b)
+    r = self.get_module().simple_matmul(a, b)
     r.print().assert_all_close()
 
 
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index cc8bfd5..45eba4f 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -32,11 +32,11 @@
     return self.counter
 
 
-@tf_test_utils.compile_modules(stateful=Stateful)
+@tf_test_utils.compile_module(Stateful)
 class StatefulTest(tf_test_utils.SavedModelTestCase):
 
   def test_stateful(self):
-    m = self.modules.stateful.all
+    m = self.get_module()
     m.inc_by(tf.constant(1.))
     m.get_state().print().assert_all_close()
 
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index cae9b54..b663fc8 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -62,10 +62,10 @@
     return dict(list(base_config.items()) + list(config.items()))
 
 
-class SlidingWindowM(tf.Module):
+class SlidingWindowModule(tf.Module):
 
   def __init__(self):
-    super(SlidingWindowM, self).__init__()
+    super(SlidingWindowModule, self).__init__()
     state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE]
     self.sw = SlidingWindow(state_shape=state_shape)
 
@@ -75,17 +75,17 @@
     return self.sw(x)
 
 
-@tf_test_utils.compile_modules(sw=(SlidingWindowM, ["predict"]))
+@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
 class SlidingWindowTest(tf_test_utils.SavedModelTestCase):
 
   def test_slidingwindow(self):
     input1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    result1 = self.modules.sw.all.predict(input1)
+    result1 = self.get_module().predict(input1)
     output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32)
     assert np.allclose(result1, output1)
 
     input2 = np.array([[3.0, 4.0]], dtype=np.float32)
-    result2 = self.modules.sw.all.predict(input2)
+    result2 = self.get_module().predict(input2)
     output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
     assert np.allclose(result2, output2)
 
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index b4105a6..ac590ff 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -40,20 +40,20 @@
     return tf.strings.reduce_join(wps, 1)
 
 
-@tf_test_utils.compile_modules(strings=StringsModule)
+@tf_test_utils.compile_module(StringsModule)
 class StringsTest(tf_test_utils.SavedModelTestCase):
 
   def test_print_ids(self):
     input_ids = np.asarray(
         [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
          [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
-    self.modules.strings.all.print_ids(input_ids)
+    self.get_module().print_ids(input_ids)
 
   def test_strings_to_ids(self):
     input_ids = np.asarray(
         [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
          [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
-    result = self.modules.strings.all.strings_to_ids(input_ids)
+    result = self.get_module().strings_to_ids(input_ids)
     result.assert_all_equal()
 
 
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 83ae28f..f8ea811 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -68,27 +68,27 @@
     return ta.stack()
 
 
-@tf_test_utils.compile_modules(tensorlist=TensorListModule)
+@tf_test_utils.compile_module(TensorListModule)
 class TensorListTest(tf_test_utils.SavedModelTestCase):
 
   def test_identity_through_tensorlist(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.identity_through_tensorlist(tf.constant(42.))
     result.print().assert_all_close()
 
   def test_add_through_tensorlist(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.add_through_tensorlist(tf.constant(42.), tf.constant(43.))
     result.print().assert_all_close()
 
   def test_slice_first_element_with_from_tensor(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.slice_first_element_with_from_tensor(
         tf.range(STATIC_SIZE, dtype=tf.float32))
     result.print().assert_all_close()
 
   def test_slice_first_element_with_from_tensor_high_rank(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.slice_first_element_with_from_tensor_high_rank(
         tf.broadcast_to(
             tf.range(STATIC_SIZE, dtype=tf.float32),
@@ -96,7 +96,7 @@
     result.print().assert_all_close()
 
   def test_concat_with_tensorlist_stack(self):
-    m = self.modules.tensorlist.all
+    m = self.get_module()
     result = m.concat_with_tensorlist_stack(tf.constant(42.), tf.constant(43.))
     result.print().assert_all_close()
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index e80ce79..17e4ec7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -335,11 +335,6 @@
   void runOnOperation() override;
   ConvertToSPIRVPass() {}
   ConvertToSPIRVPass(const ConvertToSPIRVPass &pass) {}
-  Option<bool> useCooperativeMatrix{
-      *this, "use-cooperative-matrix",
-      llvm::cl::desc("Experimental: Lower vector contract to cooperative "
-                     "matrix operations"),
-      llvm::cl::init(false)};
 };
 }  // namespace
 
@@ -419,12 +414,9 @@
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
   // Pull in builtin func to spv.func conversion.
   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
-
-  if (useCooperativeMatrix) {
-    auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
-    populateVectorToSPIRVPatterns(context, typeConverter, patterns,
-                                  cooperativeMatrixAnalysis);
-  }
+  auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>();
+  populateVectorToSPIRVPatterns(context, typeConverter, patterns,
+                                cooperativeMatrixAnalysis);
   patterns.insert<HALInterfaceLoadConstantConverter, IREEPlaceholderConverter,
                   LinalgReshapeConverter>(context, typeConverter);
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index 55a147b..046b10d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-convert-to-spirv=use-cooperative-matrix %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-convert-to-spirv %s | IreeFileCheck %s
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
   // CHECK: spv.globalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/internal/LLVMAOTTargetLinker.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/internal/LLVMAOTTargetLinker.cpp
index b883ef7..e5108d6 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/internal/LLVMAOTTargetLinker.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/internal/LLVMAOTTargetLinker.cpp
@@ -14,6 +14,8 @@
 
 #include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTargetLinker.h"
 
+#include "iree/base/status.h"
+
 namespace mlir {
 namespace iree_compiler {
 namespace IREE {
@@ -27,7 +29,12 @@
   ASSIGN_OR_RETURN(sharedLibFile, iree::file_io::GetTempFile("dylibfile"));
   std::string linkingCmd =
       linkerToolPath + " -shared " + archiveFile + " -o " + sharedLibFile;
-  system(linkingCmd.c_str());
+  // Note: system() returns a raw, platform-specific status, not an exit code.
+  int systemRet = system(linkingCmd.c_str());
+  if (systemRet != 0) {
+    return iree::InternalErrorBuilder(IREE_LOC)
+           << linkingCmd << " failed with status " << systemRet;
+  }
   return iree::file_io::GetFileContents(sharedLibFile);
 }
 
diff --git a/iree/hal/vulkan/vma_allocator.cc b/iree/hal/vulkan/vma_allocator.cc
index f9895fb..d160254 100644
--- a/iree/hal/vulkan/vma_allocator.cc
+++ b/iree/hal/vulkan/vma_allocator.cc
@@ -135,6 +135,11 @@
     VmaAllocationCreateFlags flags) {
   IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal");
 
+  // Guard against the corner case where the requested buffer size is 0. The
+  // application is unlikely to do anything when requesting a 0-byte buffer; but
+  // it can happen in real world use cases. So we should at least not crash.
+  if (allocation_size == 0) allocation_size = 4;
+
   VkBufferCreateInfo buffer_create_info;
   buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
   buffer_create_info.pNext = nullptr;
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc
index f726d3e..9f26284 100644
--- a/iree/tools/run_mlir_main.cc
+++ b/iree/tools/run_mlir_main.cc
@@ -122,6 +122,13 @@
     llvm::cl::ZeroOrMore,
 };
 
+static llvm::cl::opt<std::string> input_values_file_flag{
+    "input-value-file",
+    llvm::cl::desc("Provides a file for input shapes and optional values (see "
+                   "run_module_main.cc for details)"),
+    llvm::cl::init(""),
+};
+
 static llvm::cl::opt<bool> run_flag{
     "run",
     llvm::cl::desc("Runs the module (vs. just compiling and verifing)"),
@@ -258,6 +265,23 @@
   return binary_contents;
 }
 
+// Returns the non-empty lines of `filename`; each line is one input value.
+StatusOr<std::vector<std::string>> GetInputValues(const std::string& filename) {
+  std::string error_message;
+  auto file = mlir::openInputFile(filename, &error_message);
+  if (!file) {
+    return NotFoundErrorBuilder(IREE_LOC) << "Unable to open input file '"
+                                          << filename << "': " << error_message;
+  }
+  llvm::SmallVector<llvm::StringRef, 8> source_buffers;
+  file->getBuffer().split(source_buffers, /*Separator=*/"\n", /*MaxSplit=*/-1,
+                          /*KeepEmpty=*/false);
+  std::vector<std::string> res;
+  res.reserve(source_buffers.size());
+  for (llvm::StringRef line : source_buffers) res.push_back(line.str());
+  return res;
+}
+
 // Evaluates a single function in its own fiber, printing the results to stdout.
 Status EvaluateFunction(iree_vm_context_t* context,
                         iree_hal_allocator_t* allocator,
@@ -267,11 +291,23 @@
 
   std::cout << "EXEC @" << export_name << std::endl;
   ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
-  auto input_values_list = absl::MakeConstSpan(
-      input_values_flag.empty() ? nullptr : &input_values_flag.front(),
-      input_values_flag.size());
-  ASSIGN_OR_RETURN(auto* input_list, ParseToVariantList(input_descs, allocator,
-                                                        input_values_list));
+  iree_vm_variant_list_t* input_list = nullptr;
+  if (!input_values_file_flag.empty()) {
+    if (!input_values_flag.empty()) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Expected only one of input_values_file_flag and "
+                "input_values_flag to be set";
+    }
+    ASSIGN_OR_RETURN(auto input_values, GetInputValues(input_values_file_flag));
+    ASSIGN_OR_RETURN(input_list,
+                     ParseToVariantList(input_descs, allocator, input_values));
+  } else {
+    auto input_values_list = absl::MakeConstSpan(
+        input_values_flag.empty() ? nullptr : &input_values_flag.front(),
+        input_values_flag.size());
+    ASSIGN_OR_RETURN(input_list, ParseToVariantList(input_descs, allocator,
+                                                    input_values_list));
+  }
 
   ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
   // Prepare outputs list to accept the results from the invocation.