Merge main -> google

* 6c1869bc Add license header to linker scripts (#2724)
* 6d32fe81 Updates all e2e tests to use TracedModuleTestCase (#2736)
* 7fea2dd1 Revert "[android] Ignore adb push status for the executable (#2627)" (#2735)
* da48d03c Adds a class for tracing module calls and example tests. (#2660)
* d66150d8 Enables `conv_test` on vulkan. (#2723)
* a3b73bf6 Add .python-version to .gitignore (#2727)
* 6d2bfd35 Enables passing MobileNet and MobileNetV2 vision tests (#2725)
* 763dc098 Fix ModelBuilder build (#2726)
* 528a86a7 Update the codegen pipeline example snippets to better demonstrate (#2709)
* 91a05ff5 Merge google -> main (#2719)
* f0307a87 Revert "Merge google -> main" (#2718)
* 77d2f8b4 Merge google -> main (#2711)
* aafabbcb Add linker scripts to cmake pyiree builds to hide symbols. (#2707)
* 1918d5c7 Updated gather to run on LLVM and SPIRV (#2701)
* ec38aa00 Add a VM to EmitC conversion and CModule Target (#2536)
* 32fe2c9b Revert "Add a VM to EmitC conversion and CModule Target (#2536)" (#2703)
* e7f90b61 Revert addition of DYLIB-LLVM-AOT as a default backend (#2702)
* 48efada8 Dump input values to a file. (#2683)
* 0fd8a550 Mark shell scripts as executable (#2699)
* 966fe78f Add dep "//iree/base:localfile" to iree-run-mlir (#2687)
* 64e29879 Merge google -> main (#2698)
* ced7477f Update scripts for packaging wheels from bazel.
* 1c05e0ae Disable dylib driver for android build (#2692)
* e00a85f8 Disable a few gcc warnings that are chatty.
* 74af1f01 Add a config for gcc. (#2696)
* eec62a1e Remove dylib-llvm-aot tests (#2694)
* 11bd7a00 Fix stdalign include in Android (#2688)
* b9026bd6 Enable dylib-llvm-aot backend e2e/xla_op tests (#2639)
* 8ed9f73f Added Complex operation lowerings to pre-target conversion (#2662)
* 8224247a Working around duplicate func->llvm.func rewrites. (#2685)
* f2d6649c Bump pybind11 to head. (#2592)
* 9ac5241d [NFC] Remove empty ArrayRef<NamedAttribute> arg in op creation (#2686)
* 36abdfde [ModelBuilder] Revert spurious path change (#2684)
* c96bbb1d Merge pull request #2568 from google/benvanik-flatcc-vm
* f80864f8 [ModelBuilder] Add MemRefUtil helpers for padding. (#2577)
* 1425737d Fix flatcc target name and binary path (#2599)
* 8c18b8c3 Refactoring bytecode_module to use flatcc instead of flatbuffers C++. This imp..
* 3762c132 Linking in the flatcc bytecode_module_def files. This ensures they are correct..
* eb53de3e Building bytecode_module_def with flatcc.

PiperOrigin-RevId: 324296127
diff --git a/.gitignore b/.gitignore
index 7db7ef9..96a5b66 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,6 @@
 compile_commands.json
 .cache/clangd
 .clangd/
+
+# Pyenv files
+.python-version
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 48c4948..b761338 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -95,7 +95,8 @@
   set( IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS} )
   # For cross compilation towards Android, we don't want the LLVM JIT HAL driver.
   if(ANDROID)
-    list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM)
+    # TODO(ataei): Enable dylib/dylib-llvm-aot for android.
+    list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM DyLib)
   endif()
 endif()
 message(STATUS "Building HAL drivers ${IREE_HAL_DRIVERS_TO_BUILD}")
@@ -115,7 +116,7 @@
 
 # List of all target backends to be built by default:
 set(IREE_ALL_TARGET_BACKENDS
-  # TODO(scotttodd): LLVMAOT
+  # TODO(#2645): Add DYLIB-LLVM-AOT when it doesn't require an env var
   LLVM-IR
   Vulkan-SPIRV
   VMLA
@@ -356,7 +357,7 @@
   # We need flatc to generate some source code. When cross-compiling, we need
   # to make sure the flatc binary is configured under host environment.
   iree_declare_host_excutable(flatc BUILDONLY)
-  iree_declare_host_excutable(flatcc BUILDONLY)
+  iree_declare_host_excutable(flatcc_cli BUILDONLY)
 
   # Set the FLATBUFFERS_FLATC_EXECUTABLE. It controls where to find the flatc
   # binary in BuildFlatBuffers().
@@ -371,14 +372,17 @@
     DEPENDS iree_host_build_flatc
     COMMENT "Installing host flatc..."
   )
-  add_custom_target(iree_host_flatcc
+  add_custom_target(iree_host_flatcc_cli
     COMMAND
       "${CMAKE_COMMAND}" -E copy_if_different
-        "${IREE_HOST_BINARY_ROOT}/third_party/flatcc/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
-        "${IREE_HOST_BINARY_ROOT}/bin"
-    DEPENDS iree_host_build_flatcc
+        "${PROJECT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
+        "${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli"
+    DEPENDS iree_host_build_flatcc_cli
     COMMENT "Installing host flatcc..."
   )
+else()
+  # TODO: unify flatc and flatcc handling to the same mechanism.
+  add_executable(iree_host_flatcc_cli ALIAS flatcc_cli)
 endif()
 
 if(${IREE_BUILD_COMPILER})
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 58303cb..e719d5a 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -7,7 +7,7 @@
 cd4e8d7f6f5ef108919f9f53db35ac73d1edea3d third_party/llvm-project
 17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
 80885f899e12d55a45561ef758eea47bb340dbf1 third_party/mlir-emitc
-80d452484c5409444b0ec19383faa84bb7a4d351 third_party/pybind11
+d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
 9f53ba413e6fc879236dcaa3e008915973d67a4f third_party/ruy
 a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
diff --git a/bindings/python/pyiree/compiler/CMakeLists.txt b/bindings/python/pyiree/compiler/CMakeLists.txt
index b7bc2cc..929e621 100644
--- a/bindings/python/pyiree/compiler/CMakeLists.txt
+++ b/bindings/python/pyiree/compiler/CMakeLists.txt
@@ -26,6 +26,8 @@
     PyExtCompiler
   MODULE_NAME
     binding
+  UNIX_LINKER_SCRIPT
+    "unix_version.lds"
   SRCS
     "initialize_module.cc"
   PYEXT_DEPS
diff --git a/bindings/python/pyiree/compiler/unix_version.lds b/bindings/python/pyiree/compiler/unix_version.lds
new file mode 100644
index 0000000..68ef766
--- /dev/null
+++ b/bindings/python/pyiree/compiler/unix_version.lds
@@ -0,0 +1,19 @@
+/* Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+{
+  global: PyInit_binding;
+  local: *;
+};
diff --git a/bindings/python/pyiree/rt/CMakeLists.txt b/bindings/python/pyiree/rt/CMakeLists.txt
index 1ec5eaf..225dccd 100644
--- a/bindings/python/pyiree/rt/CMakeLists.txt
+++ b/bindings/python/pyiree/rt/CMakeLists.txt
@@ -50,6 +50,8 @@
   MODULE_NAME binding
   SRCS
     "initialize_module.cc"
+  UNIX_LINKER_SCRIPT
+    "unix_version.lds"
   PYEXT_DEPS
     ::PyExtRtLib
     bindings::python::pyiree::common::PyextCommonLib
diff --git a/bindings/python/pyiree/rt/unix_version.lds b/bindings/python/pyiree/rt/unix_version.lds
new file mode 100644
index 0000000..68ef766
--- /dev/null
+++ b/bindings/python/pyiree/rt/unix_version.lds
@@ -0,0 +1,19 @@
+/* Copyright 2020 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+{
+  global: PyInit_binding;
+  local: *;
+};
diff --git a/build_tools/bazel/build_bindings.sh b/build_tools/bazel/build_bindings.sh
index b4a417e..4ca7ddb 100755
--- a/build_tools/bazel/build_bindings.sh
+++ b/build_tools/bazel/build_bindings.sh
@@ -43,7 +43,7 @@
 declare -a test_env_args=(
   --test_env=IREE_LLVMJIT_DISABLE=$IREE_LLVMJIT_DISABLE
   --test_env=IREE_VULKAN_DISABLE=$IREE_VULKAN_DISABLE
-  --test_env=IREE_LLVMAOT_LINKER_PATH
+  --action_env=IREE_LLVMAOT_LINKER_PATH=$IREE_LLVMAOT_LINKER_PATH
 )
 
 declare -a default_build_tag_filters=("-nokokoro")
diff --git a/build_tools/bazel/build_core.sh b/build_tools/bazel/build_core.sh
index 89d94e2..bd04942 100755
--- a/build_tools/bazel/build_core.sh
+++ b/build_tools/bazel/build_core.sh
@@ -42,6 +42,7 @@
 declare -a test_env_args=(
   --test_env=IREE_LLVMJIT_DISABLE=$IREE_LLVMJIT_DISABLE
   --test_env=IREE_VULKAN_DISABLE=$IREE_VULKAN_DISABLE
+  --action_env=IREE_LLVMAOT_LINKER_PATH=$IREE_LLVMAOT_LINKER_PATH
 )
 
 declare -a default_build_tag_filters=("-nokokoro")
diff --git a/build_tools/bazel/build_tensorflow.sh b/build_tools/bazel/build_tensorflow.sh
index be82227..3634c4d 100755
--- a/build_tools/bazel/build_tensorflow.sh
+++ b/build_tools/bazel/build_tensorflow.sh
@@ -42,6 +42,7 @@
 declare -a test_env_args=(
   --test_env=IREE_LLVMJIT_DISABLE=$IREE_LLVMJIT_DISABLE
   --test_env=IREE_VULKAN_DISABLE=$IREE_VULKAN_DISABLE
+  --action_env=IREE_LLVMAOT_LINKER_PATH=$IREE_LLVMAOT_LINKER_PATH
 )
 # Pass in VK_ICD_FILENAMES if exists so that the Vulkan loader can find the
 # Vulkan implementation.
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc
index dba02ed..220b868 100644
--- a/build_tools/bazel/iree.bazelrc
+++ b/build_tools/bazel/iree.bazelrc
@@ -36,6 +36,23 @@
 build --define open_source_build=true
 
 ###############################################################################
+# Options for "generic_gcc" builds
+###############################################################################
+
+# C++14 standard version is required.
+build:generic_gcc --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
+
+# Default to adding back asserts in optimized builds.
+# This is a good compromise between runtime performance and debuggability.
+build:generic_gcc --copt=-UNDEBUG
+
+# Disable warnings we don't care about or that generally have a low signal/noise
+# ratio.
+build:generic_gcc --copt=-Wno-unused-but-set-parameter
+build:generic_gcc --copt=-Wno-comment
+build:generic_gcc --copt=-Wno-attributes
+
+###############################################################################
 # Options for "generic_clang" builds: these options should generally apply to
 # either clang or gcc and are curated based on need.
 ###############################################################################
diff --git a/build_tools/cmake/flatbuffer_c_library.cmake b/build_tools/cmake/flatbuffer_c_library.cmake
index 916cfd9..eb63567 100644
--- a/build_tools/cmake/flatbuffer_c_library.cmake
+++ b/build_tools/cmake/flatbuffer_c_library.cmake
@@ -100,7 +100,7 @@
   endforeach()
   list(TRANSFORM _OUTS PREPEND "${CMAKE_CURRENT_BINARY_DIR}/")
 
-  iree_get_executable_path(_FLATCC_BIN flatcc)
+  iree_get_executable_path(_FLATCC_BIN flatcc_cli)
   add_custom_command(
     OUTPUT
       ${_OUTS}
@@ -115,7 +115,7 @@
     MAIN_DEPENDENCY
       ${_RULE_SRCS}
     DEPENDS
-      ${_FLATCC_BIN}
+      iree_host_flatcc_cli
       ${_RULE_SRCS}
     COMMAND_EXPAND_LISTS
   )
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index a7fd1f1..706f6a6 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -98,7 +98,7 @@
 function(iree_pyext_module)
   cmake_parse_arguments(ARG
     ""
-    "NAME;MODULE_NAME"
+    "NAME;MODULE_NAME;UNIX_LINKER_SCRIPT"
     "SRCS;COPTS;DEPS;PYEXT_DEPS"
     ${ARGN})
   _setup_iree_pyext_names()
@@ -127,6 +127,14 @@
         SUFFIX "${IREE_MULTIPY_${V}_SUFFIX}${IREE_MULTIPY_${V}_EXTENSION}"
     )
 
+    # Link flags.
+    if(UNIX AND NOT APPLE)  # Apple does not support linker scripts.
+      if(ARG_UNIX_LINKER_SCRIPT)
+        set_target_properties(${VER_NAME} PROPERTIES LINK_FLAGS
+          "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/${ARG_UNIX_LINKER_SCRIPT}")
+      endif()
+    endif()
+
     iree_pyext_pybind11_options(${VER_NAME})
     target_include_directories(${VER_NAME}
       PUBLIC
diff --git a/build_tools/cmake/run_android_test.sh b/build_tools/cmake/run_android_test.sh
index 0c7510d..292890c 100755
--- a/build_tools/cmake/run_android_test.sh
+++ b/build_tools/cmake/run_android_test.sh
@@ -35,7 +35,7 @@
 set -x
 set -e
 
-adb push $TEST_EXECUTABLE $TEST_ANDROID_ABS_DIR/$(basename $TEST_EXECUTABLE) 1>/dev/null
+adb push $TEST_EXECUTABLE $TEST_ANDROID_ABS_DIR/$(basename $TEST_EXECUTABLE)
 
 if [ -n "$TEST_DATA" ]; then
   adb push $TEST_DATA $TEST_ANDROID_ABS_DIR/$(basename $TEST_DATA)
diff --git a/build_tools/embed_data/generate_cc_embed_data.cc b/build_tools/embed_data/generate_cc_embed_data.cc
index 21275cc..fed4f13 100644
--- a/build_tools/embed_data/generate_cc_embed_data.cc
+++ b/build_tools/embed_data/generate_cc_embed_data.cc
@@ -113,11 +113,8 @@
   f << "#include <cstddef>\n";
   GenerateTocStruct(f);
   GenerateNamespaceOpen(f);
-  f << "static const struct ::iree::FileToc toc[] = {\n";
-  assert(input_files.size() == toc_files.size());
   for (size_t i = 0, e = input_files.size(); i < e; ++i) {
-    f << "  {";
-    f << "\"" << absl::CEscape(toc_files[i]) << "\",\n";
+    f << "alignas(alignof(void*)) static char const file_" << i << "[] = {\n";
     std::string contents;
     if (!SlurpFile(input_files[i], &contents)) {
       std::cerr << "Error reading file " << input_files[i] << "\n";
@@ -130,7 +127,16 @@
       f << "\"" << absl::CHexEscape(line) << "\"\n";
       remaining_contents = remaining_contents.substr(line.size());
     }
-    f << "\"\\0\", " << contents.size() << "},\n";
+    f << "};\n";
+  }
+  f << "static const struct ::iree::FileToc toc[] = {\n";
+  assert(input_files.size() == toc_files.size());
+  for (size_t i = 0, e = input_files.size(); i < e; ++i) {
+    f << "  {\n";
+    f << "    \"" << absl::CEscape(toc_files[i]) << "\",\n";
+    f << "    file_" << i << ",\n";
+    f << "    sizeof(file_" << i << ") - 1\n";
+    f << "  },\n";
   }
   f << "  {nullptr, nullptr, 0},\n";
   f << "};\n";
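
(Illustration, not part of the patch: with the generate_cc_embed_data change above, the generator first emits each embedded file as its own pointer-aligned char array and then builds the TOC from those arrays. A rough sketch of the emitted source for two hypothetical inputs, `a.txt` and `b.bin`, follows; the `iree::FileToc` field layout shown is assumed from how the entries are populated.)

```cpp
// Illustrative sketch only: approximates generator output for two inputs.
// The real output is produced by generate_cc_embed_data; FileToc's exact
// definition comes from GenerateTocStruct() and is assumed here.
#include <cstddef>

namespace iree {
struct FileToc {  // assumed layout: {name, data, size}
  const char* name;
  const char* data;
  size_t size;
};
}  // namespace iree

// Each file gets its own pointer-aligned array.
alignas(alignof(void*)) static char const file_0[] = {
    "contents of a.txt\n"
};
alignas(alignof(void*)) static char const file_1[] = {
    "\x00\x01\x02\x03"
};

static const struct ::iree::FileToc toc[] = {
    {
        "a.txt",
        file_0,
        sizeof(file_0) - 1  // drop the implicit trailing NUL
    },
    {
        "b.bin",
        file_1,
        sizeof(file_1) - 1
    },
    {nullptr, nullptr, 0},
};
```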
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build.sh
old mode 100644
new mode 100755
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh
old mode 100644
new mode 100755
diff --git a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
index 0900667..d8ddd33 100644
--- a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
+++ b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
@@ -40,7 +40,9 @@
     "lib/Dialect/mhlo/IR/lhlo_ops.cc"
     "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"
     "lib/Dialect/mhlo/transforms/legalize_control_flow.cc"
+    "lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"
     "lib/Dialect/mhlo/transforms/legalize_to_standard.cc"
+    "lib/Dialect/mhlo/transforms/lower_complex.cc"
     "lib/Dialect/mhlo/transforms/lower_general_dot.cc"
     "lib/Dialect/mhlo/transforms/materialize_broadcasts.cc"
     "lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc"
@@ -89,6 +91,7 @@
     tensorflow_mlir_hlo_hlo_ops_pattern_gen
     tensorflow_mlir_hlo_infer_fusibility_op_interface_gen
     tensorflow_mlir_hlo_legalize_to_standard_patterns_gen
+    tensorflow_mlir_hlo_lower_to_complex_inc_gen
     tensorflow_mlir_hlo_lhlo_ops_gen
     tensorflow_mlir_hlo_xla_canonicalize_gen
   PUBLIC
@@ -222,3 +225,18 @@
   OUTS
     -gen-rewriters lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc
 )
+
+external_tablegen_library(
+  PACKAGE
+    tensorflow
+  NAME
+    mlir_hlo_lower_to_complex_inc_gen
+  TBLGEN
+    MLIR
+  ROOT
+    ${TF_MLIR_HLO_SRC_ROOT}
+  SRCS
+    "lib/Dialect/mhlo/transforms/lower_complex_patterns.td"
+  OUTS
+    -gen-rewriters lib/Dialect/mhlo/transforms/generated_lower_complex.inc
+)
diff --git a/docs/design_docs/codegen_passes.md b/docs/design_docs/codegen_passes.md
index 83e37fc..4a1335c 100644
--- a/docs/design_docs/codegen_passes.md
+++ b/docs/design_docs/codegen_passes.md
@@ -27,13 +27,13 @@
     func @main_ex_dispatch() {
       %c0 = constant 0 : index
       %0 = hal.interface.load.tensor @legacy_io::@arg0,
-             offset = %c0 : tensor<4x5xf32>
+             offset = %c0 : tensor<32x24xf32>
       %1 = hal.interface.load.tensor @legacy_io::@arg1,
-             offset = %c0 : tensor<5x10xf32>
+             offset = %c0 : tensor<24x16xf32>
       %2 = "mhlo.dot"(%0, %1) {precision_config = ["DEFAULT", "DEFAULT"]} :
-             (tensor<4x5xf32>, tensor<5x10xf32>) -> tensor<4x10xf32>
+             (tensor<32x24xf32>, tensor<24x16xf32>) -> tensor<32x16xf32>
       hal.interface.store.tensor %2, @legacy_io::@ret0,
-        offset = %c0 : tensor<4x10xf32>
+        offset = %c0 : tensor<32x16xf32>
       return
     }
     hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -57,17 +57,18 @@
     func @main_ex_dispatch() {
       %c0 = constant 0 : index
       %0 = hal.interface.load.tensor @legacy_io::@arg0,
-             offset = %c0 : tensor<10x5xf32>
+             offset = %c0 : tensor<10x15xf32>
       %1 = hal.interface.load.tensor @legacy_io::@arg1,
-             offset = %c0 : tensor<10x5xf32>
+             offset = %c0 : tensor<10x15xf32>
       %2 = hal.interface.load.tensor @legacy_io::@arg2,
-             offset = %c0 : tensor<10x5xf32>
+             offset = %c0 : tensor<15xf32>
       %3 = "mhlo.add"(%0, %1) :
-         (tensor<10x5xf32>, tensor<10x5xf32>) -> tensor<10x5xf32>
-      %4 = "mhlo.multiply"(%3, %2) :
-         (tensor<10x5xf32>, tensor<10x5xf32>) -> tensor<10x5xf32>
-      hal.interface.store.tensor %4, @legacy_io::@ret0,
-        offset = %c0 : tensor<10x5xf32>
+         (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
+      %4 = "mhlo.broadcast"(%2) : (tensor<15xf32>) -> tensor<10x15xf32>
+      %5 = "mhlo.multiply"(%3, %4) :
+         (tensor<10x15xf32>, tensor<10x15xf32>) -> tensor<10x15xf32>
+      hal.interface.store.tensor %5, @legacy_io::@ret0,
+        offset = %c0 : tensor<10x15xf32>
       return
     }
     hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -130,11 +131,13 @@
 
 The first step is to convert MHLO operations to Linalg on tensors. This is done
 using the [HLOToLinalgPass][HLOToLinalgPass] from Tensorflow. An example of the
-conversion is shown below, where each of the `mhlo.add` and `mhlo.multiply`
-operations are converted to `linalg.generic` operations on tensors.
+conversion is shown below, where the `mhlo.add`, `mhlo.broadcast` and
+`mhlo.multiply` operations are converted to `linalg.generic` operations on
+tensors.
 
 ```mlir
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
 %3 = linalg.generic
        {args_in = 2 : i64, args_out = 1 : i64,
         indexing_maps = [#map0, #map0, #map0],
@@ -142,15 +145,22 @@
      ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
        %5 = addf %arg0, %arg1 : f32
        linalg.yield %5 : f32
-     } : tensor<10x5xf32>, tensor<10x5xf32> -> tensor<10x5xf32>
+     } : tensor<10x15xf32>, tensor<10x15xf32> -> tensor<10x15xf32>
 %4 = linalg.generic
+       {args_in = 1 : i64, args_out = 1 : i64,
+        indexing_maps = [#map1, #map0],
+        iterator_types = ["parallel", "parallel"]} %2 {
+     ^bb0(%arg0: f32):  // no predecessors
+       linalg.yield %arg0 : f32
+     }: tensor<15xf32> -> tensor<10x15xf32>
+%5 = linalg.generic
        {args_in = 2 : i64, args_out = 1 : i64,
         indexing_maps = [#map0, #map0, #map0],
-        iterator_types = ["parallel", "parallel"]} %3, %2 {
+        iterator_types = ["parallel", "parallel"]} %3, %4 {
      ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
        %5 = mulf %arg0, %arg1 : f32
        linalg.yield %5 : f32
-     }: tensor<10x5xf32>, tensor<10x5xf32> -> tensor<10x5xf32>
+     }: tensor<10x15xf32>, tensor<10x15xf32> -> tensor<10x15xf32>
 ```
 
 <a name="snippet3"></a> Snippet 3 : MHLO to Linalg conversion for
@@ -190,15 +200,16 @@
 
 ```mlir
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
 %3 = linalg.generic
        {args_in = 3 : i64, args_out = 1 : i64,
-        indexing_maps = [#map0, #map0, #map0, #map0],
+        indexing_maps = [#map0, #map0, #map1, #map0],
         iterator_types = ["parallel", "parallel"]} %0, %1, %2 {
      ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):  // no predecessors
        %4 = addf %arg0, %arg1 : f32
        %5 = mulf %4, %arg2 : f32
        linalg.yield %5 : f32
-     }: tensor<?x5xf32>, tensor<?x5xf32>, tensor<?x5xf32> -> tensor<?x5xf32>
+     }: tensor<10x15xf32>, tensor<10x15xf32>, tensor<15xf32> -> tensor<10x15xf32>
 ```
 
 <a name="snippet4"></a> Snippet 4: Fusion of Linalg operation on tensors for
@@ -265,15 +276,15 @@
 ```mlir
 func @main_ex_dispatch() {
   %0 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@ret0} : memref<4x10xf32>
+         {binding = @legacy_io::@ret0} : memref<32x16xf32>
   %c0 = constant 0 : index
   %1 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg0} : memref<4x5xf32>
+         {binding = @legacy_io::@arg0} : memref<32x24xf32>
   %2 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg1} : memref<5x10xf32>
+         {binding = @legacy_io::@arg1} : memref<24x16xf32>
   %cst = constant 0.000000e+00 : f32
   linalg.matmul(%1, %2, %0) :
-    memref<4x5xf32>, memref<5x10xf32>, memref<4x10xf32>
+    memref<32x24xf32>, memref<24x16xf32>, memref<32x16xf32>
   return
 }
 ```
@@ -283,25 +294,26 @@
 
 ```mlir
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
 func @main_ex_dispatch() {
   %0 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@ret0} : memref<10x5xf32>
+         {binding = @legacy_io::@ret0} : memref<10x15xf32>
   %c0 = constant 0 : index
   %1 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg0} : memref<10x5xf32>
+         {binding = @legacy_io::@arg0} : memref<10x15xf32>
   %2 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg1} : memref<10x5xf32>
+         {binding = @legacy_io::@arg1} : memref<10x15xf32>
   %3 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg2} : memref<10x5xf32>
+         {binding = @legacy_io::@arg2} : memref<15xf32>
   linalg.generic
     {args_in = 3 : i64, args_out = 1 : i64,
-     indexing_maps = [#map0, #map0, #map0],
+     indexing_maps = [#map0, #map0, #map1, #map0],
      iterator_types = ["parallel", "parallel"]} %1, %2, %3, %0 {
   ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32):  // no predecessors
     %4 = addf %arg0, %arg1 : f32
     %5 = mulf %4, %arg2 : f32
     linalg.yield %5 : f32
-  }: memref<10x5xf32>, memref<10x5xf32>, memref<10x5xf32>, memref<10x5xf32>
+  }: memref<10x15xf32>, memref<10x15xf32>, memref<15xf32>, memref<10x15xf32>
   return
 }
 ```
@@ -350,18 +362,20 @@
   attributes {
     spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
   %cst = constant 0.000000e+00 : f32
+  %c32 = constant 32 : index
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
   %c0 = constant 0 : index
   %c4 = constant 4 : index
-  %c10 = constant 10 : index
   %0 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@ret0} : memref<4x10xf32>
+         {binding = @legacy_io::@ret0} : memref<32x16xf32>
   %1 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg0} : memref<4x5xf32>
+         {binding = @legacy_io::@arg0} : memref<32x24xf32>
   %2 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg1} : memref<5x10xf32>
-  linalg.fill(%0, %cst) : memref<4x10xf32>, f32
-  scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c4, %c10) step (%c8, %c8) {
-    scf.for %arg2 = %c0 to %c5 step %c4 {
+         {binding = @legacy_io::@arg1} : memref<24x16xf32>
+  linalg.fill(%0, %cst) : memref<32x16xf32>, f32
+  scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c32, %c16) step (%c8, %c8) {
+    scf.for %arg2 = %c0 to %c24 step %c4 {
       ...
       %5 = subview %1[%arg0, %arg2]...
       ...
@@ -440,19 +454,21 @@
 func @matmul_tile()
   attributes {
     spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
-  %c96 = constant 96 : index
+  %c32 = constant 32 : index
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
   %c4 = constant 4 : index
   %c8 = constant 8 : index
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg0} : memref<96x96xf32>
+         {binding = @legacy_io::@arg0} : memref<32x24xf32>
   %1 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg1} : memref<96x96xf32>
+         {binding = @legacy_io::@arg1} : memref<24x16xf32>
   %2 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@ret0} : memref<96x96xf32>
-  scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c96, %c96) step (%c8, %c8) {
-    scf.for %arg2 = %c0 to %c96 step %c4 {
+         {binding = @legacy_io::@ret0} : memref<32x16xf32>
+  scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c32, %c16) step (%c8, %c8) {
+    scf.for %arg2 = %c0 to %c24 step %c4 {
       ...
       %5 = subview %0[%arg0, %arg2]...
       ...
@@ -518,38 +534,40 @@
 func @main_ex_dispatch_0_dispatch_1()
   attributes {
     spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
-  %c5 = constant 5 : index
+  %c24 = constant 24 : index
   %c8 = constant 8 : index
   %c4 = constant 4 : index
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@ret0} : memref<4x10xf32>
+         {binding = @legacy_io::@ret0} : memref<32x16xf32>
   %1 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg0} : memref<4x5xf32>
+         {binding = @legacy_io::@arg0} : memref<32x24xf32>
   %2 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg1} : memref<5x10xf32>
+         {binding = @legacy_io::@arg1} : memref<24x16xf32>
   %3 = "gpu.block_id"() {dimension = "x"} : () -> index
-  %4 = muli %3, %c8 : index
-  scf.for %arg0 = %c0 to %c5 step %c4 {
+  %4 = "gpu.block_id"() {dimension = "y"} : () -> index
+  %5 = muli %4, %c8 : index
+  %6 = muli %3, %c8 : index
+  scf.for %arg0 = %c0 to %c24 step %c4 {
     ...
-    %9 = subview %1[0, %arg0]
+    %15 = subview %1[%5, %arg0]
     ...
-    %14 = subview %2[%arg0, %4]
-    %15 = subview %0[0, %4]
-    %16 = "gpu.thread_id"() {dimension = "x"} : () -> index
-    %17 = "gpu.thread_id"() {dimension = "y"} : () -> index
-    %18 = cmpi "slt", %17, %c4 : index
-    %19 = cmpi "slt", %16, %13 : index
-    %20 = and %18, %19 : i1
-    scf.if %20 {
-      scf.for %arg1 = %c0 to %8 step %c1 {
-        %21 = load %9[%17, %arg1] : memref<4x?xf32, #map0>
-        %22 = load %14[%arg1, %16] : memref<?x?xf32, #map1>
-        %23 = load %15[%17, %16] : memref<4x?xf32, #map1>
-        %24 = mulf %21, %22 : f32
-        %25 = addf %23, %24 : f32
-        store %25, %15[%17, %16] : memref<4x?xf32, #map1>
+    %20 = subview %2[%arg0, %6]
+    %21 = subview %0[%5, %6]
+    %22 = "gpu.thread_id"() {dimension = "x"} : () -> index
+    %23 = "gpu.thread_id"() {dimension = "y"} : () -> index
+    %24 = cmpi "slt", %23, %10 : index
+    %25 = cmpi "slt", %22, %19 : index
+    %26 = and %24, %25 : i1
+    scf.if %26 {
+      scf.for %arg1 = %c0 to %14 step %c1 {
+        %27 = load %15[%23, %arg1] : memref<?x?xf32, #map0>
+        %28 = load %20[%arg1, %22] : memref<?x?xf32, #map1>
+        %29 = load %21[%23, %22] : memref<?x?xf32, #map1>
+        %30 = mulf %27, %28 : f32
+        %31 = addf %29, %30 : f32
+        store %31, %21[%23, %22] : memref<?x?xf32, #map1>
       }
     }
   }
@@ -574,13 +592,13 @@
   %c50 = constant 50 : index
   %c5 = constant 5 : index
   %0 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@ret0} : memref<10x5xf32>
+         {binding = @legacy_io::@ret0} : memref<10x15xf32>
   %1 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg0} : memref<10x5xf32>
+         {binding = @legacy_io::@arg0} : memref<10x15xf32>
   %2 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg1} : memref<10x5xf32>
+         {binding = @legacy_io::@arg1} : memref<10x15xf32>
   %3 = iree.placeholder for "interface buffer"
-         {binding = @legacy_io::@arg2} : memref<10x5xf32>
+         {binding = @legacy_io::@arg2} : memref<15xf32>
   %4 = "gpu.block_id"() {dimension = "x"} : () -> index
   %5 = "gpu.block_dim"() {dimension = "x"} : () -> index
   %6 = "gpu.thread_id"() {dimension = "x"} : () -> index
@@ -590,12 +608,12 @@
   scf.if %9 {
     %10 = divi_signed %8, %c5 : index
     %11 = remi_signed %8, %c5 : index
-    %12 = load %1[%10, %11] : memref<10x5xf32>
-    %13 = load %2[%10, %11] : memref<10x5xf32>
-    %14 = load %3[%10, %11] : memref<10x5xf32>
+    %12 = load %1[%10, %11] : memref<10x15xf32>
+    %13 = load %2[%10, %11] : memref<10x15xf32>
+    %14 = load %3[%11] : memref<15xf32>
     %15 = addf %12, %13 : f32
     %16 = mulf %15, %14 : f32
-    store %16, %0[%10, %11] : memref<10x5xf32>
+    store %16, %0[%10, %11] : memref<10x15xf32>
   }
   return
 }
diff --git a/experimental/ModelBuilder/MemRefUtils.h b/experimental/ModelBuilder/MemRefUtils.h
index ebd6d1b..4d1e67b 100644
--- a/experimental/ModelBuilder/MemRefUtils.h
+++ b/experimental/ModelBuilder/MemRefUtils.h
@@ -74,21 +74,55 @@
 
 // Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
 // This is an implementation detail that is kept in sync with MLIR codegen
+// conventions.  Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+template <typename T, int N>
+typename std::enable_if<(N >= 1), StridedMemRefType<T, N> *>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+                            const std::array<int64_t, N> &shape,
+                            const std::array<int64_t, N> &shapeAlloc,
+                            AllocFunType allocFun = &::malloc) {
+  StridedMemRefType<T, N> *descriptor = static_cast<StridedMemRefType<T, N> *>(
+      allocFun(sizeof(StridedMemRefType<T, N>)));
+  descriptor->basePtr = static_cast<T *>(ptr);
+  descriptor->data = static_cast<T *>(alignedPtr);
+  descriptor->offset = 0;
+  std::copy(shape.begin(), shape.end(), descriptor->sizes);
+  auto strides = makeStrides<N>(shapeAlloc);
+  std::copy(strides.begin(), strides.end(), descriptor->strides);
+  return descriptor;
+}
+
+// Mallocs a StridedMemRefDescriptor<T, 0>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
+// conventions.  Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+template <typename T, int N>
+typename std::enable_if<(N == 0), StridedMemRefType<T, 0> *>::type
+makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
+                            const std::array<int64_t, N> &shape = {},
+                            const std::array<int64_t, N> &shapeAlloc = {},
+                            AllocFunType allocFun = &::malloc) {
+  StridedMemRefType<T, 0> *descriptor = static_cast<StridedMemRefType<T, 0> *>(
+      allocFun(sizeof(StridedMemRefType<T, 0>)));
+  descriptor->basePtr = static_cast<T *>(ptr);
+  descriptor->data = static_cast<T *>(alignedPtr);
+  descriptor->offset = 0;
+  return descriptor;
+}
+
+// Mallocs a StridedMemRefDescriptor<T, N>* that matches the MLIR ABI.
+// This is an implementation detail that is kept in sync with MLIR codegen
 // conventions.
 template <typename T, int N>
 typename std::enable_if<(N >= 1), StridedMemRefType<T, N> *>::type
 makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
                             const std::array<int64_t, N> &shape,
-                            AllocFunType alloc = &::malloc) {
-  StridedMemRefType<T, N> *descriptor = static_cast<StridedMemRefType<T, N> *>(
-      alloc(sizeof(StridedMemRefType<T, N>)));
-  descriptor->basePtr = static_cast<T *>(ptr);
-  descriptor->data = static_cast<T *>(alignedPtr);
-  descriptor->offset = 0;
-  std::copy(shape.begin(), shape.end(), descriptor->sizes);
-  auto strides = makeStrides<N>(shape);
-  std::copy(strides.begin(), strides.end(), descriptor->strides);
-  return descriptor;
+                            AllocFunType allocFun = &::malloc) {
+  return makeStridedMemRefDescriptor<T, N>(ptr, alignedPtr, shape, shape,
+                                           allocFun);
 }
 
 // Mallocs a StridedMemRefDescriptor<T, 0>* (i.e. a pointer to scalar) that
@@ -98,13 +132,9 @@
 typename std::enable_if<(N == 0), StridedMemRefType<T, 0> *>::type
 makeStridedMemRefDescriptor(void *ptr, void *alignedPtr,
                             const std::array<int64_t, N> &shape = {},
-                            AllocFunType alloc = &::malloc) {
-  StridedMemRefType<T, 0> *descriptor = static_cast<StridedMemRefType<T, 0> *>(
-      alloc(sizeof(StridedMemRefType<T, 0>)));
-  descriptor->basePtr = static_cast<T *>(ptr);
-  descriptor->data = static_cast<T *>(alignedPtr);
-  descriptor->offset = 0;
-  return descriptor;
+                            AllocFunType allocFun = &::malloc) {
+  return makeStridedMemRefDescriptor<T, N>(ptr, alignedPtr, shape, shape,
+                                           allocFun);
 }
 
 // Mallocs an UnrankedMemRefType<T>* that contains a ranked
@@ -113,9 +143,9 @@
 template <typename T, int N>
 ::UnrankedMemRefType<T> *allocUnrankedDescriptor(
     void *data, void *alignedData, const std::array<int64_t, N> &shape,
-    AllocFunType alloc = &::malloc) {
+    AllocFunType allocFun = &::malloc) {
   ::UnrankedMemRefType<T> *res = static_cast<::UnrankedMemRefType<T> *>(
-      alloc(sizeof(::UnrankedMemRefType<T>)));
+      allocFun(sizeof(::UnrankedMemRefType<T>)));
   res->rank = N;
   res->descriptor = makeStridedMemRefDescriptor<T, N>(data, alignedData, shape);
   return res;
@@ -157,14 +187,14 @@
 // and greater than the size of T. By default the alignment is sizeof(T).
 template <typename T>
 std::pair<void *, void *> allocAligned(
-    size_t nElements, AllocFunType alloc = &::malloc,
+    size_t nElements, AllocFunType allocFun = &::malloc,
     llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>()) {
   assert(sizeof(T) < (1ul << 32) && "Elemental type overflows");
   auto size = nElements * sizeof(T);
   auto desiredAlignment = alignment.getValueOr(pow2msb(sizeof(T)));
   assert((desiredAlignment & (desiredAlignment - 1)) == 0);
   assert(desiredAlignment >= sizeof(T));
-  void *data = alloc(size + desiredAlignment);
+  void *data = allocFun(size + desiredAlignment);
   uintptr_t addr = reinterpret_cast<uintptr_t>(data);
   uintptr_t rem = addr % desiredAlignment;
   void *alignedData =
@@ -194,24 +224,48 @@
 }
 
 // Entry point to allocate a dense buffer with a given `shape` and initializer
-// of type PointwiseInitializer. Can optionally take specific `alloc` and `free`
+// of type PointwiseInitializer. Additionally takes a `shapeAlloc` array which
+// is used instead of `shape` to allocate "more aligned" data and compute the
+// corresponding strides.
+// Can optionally take specific alloc and free functions.
+//
+// Example:
+// When called with `shape = [128, 127]` and `shapeAlloc = [128, 128]`, this
+// allocates a memref with `128*128*sizeof(T)` bytes, `sizes = [128, 127]` and
+// `strides = [128, 1]`.
+template <typename T, int N, typename FreeFunType = decltype(&::free)>
+std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>
+makeInitializedStridedMemRefDescriptor(
+    const std::array<int64_t, N> &shape,
+    const std::array<int64_t, N> &shapeAlloc, LinearInitializer<T> init,
+    llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
+    AllocFunType allocFun = &::malloc, FreeFunType freeFun = &::free) {
+  for (unsigned i = 0; i < N; ++i)
+    assert(shape[i] <= shapeAlloc[i] &&
+           "shapeAlloc must be greater than or equal to shape");
+  int64_t nElements = 1;
+  for (int64_t s : shapeAlloc) nElements *= s;
+  auto allocated = allocAligned<T>(nElements, allocFun, alignment);
+  auto *data = static_cast<T *>(allocated.first);
+  auto *alignedData = static_cast<T *>(allocated.second);
+  for (unsigned i = 0; i < nElements; ++i) init(i, alignedData);
+  return std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>(
+      detail::makeStridedMemRefDescriptor<T, N>(data, alignedData, shape,
+                                                shapeAlloc, allocFun),
+      freeFun);
+}
+
+// Entry point to allocate a dense buffer with a given `shape` and initializer
+// of type PointwiseInitializer. Can optionally take specific alloc and free
 // functions.
 template <typename T, int N, typename FreeFunType = decltype(&::free)>
 std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>
 makeInitializedStridedMemRefDescriptor(
     const std::array<int64_t, N> &shape, LinearInitializer<T> init,
     llvm::Optional<uint64_t> alignment = llvm::Optional<uint64_t>(),
-    AllocFunType alloc = &::malloc, FreeFunType freeFun = &::free) {
-  int64_t nElements = 1;
-  for (int64_t s : shape) nElements *= s;
-  auto allocated = allocAligned<T>(nElements, alloc, alignment);
-  auto *data = static_cast<T *>(allocated.first);
-  auto *alignedData = static_cast<T *>(allocated.second);
-  for (unsigned i = 0; i < nElements; ++i) init(i, alignedData);
-  return std::unique_ptr<StridedMemRefType<T, N>, FreeFunType>(
-      detail::makeStridedMemRefDescriptor<T, N>(data, alignedData, shape,
-                                                alloc),
-      freeFun);
+    AllocFunType allocFun = &::malloc, FreeFunType freeFun = &::free) {
+  return makeInitializedStridedMemRefDescriptor<T, N>(
+      shape, shape, init, alignment, allocFun, freeFun);
 }
 
 }  // namespace mlir
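
(Usage note, not part of the patch: the new MemRefUtils overloads above let callers over-allocate along chosen dimensions while keeping the logical sizes. A minimal sketch follows, assuming the ModelBuilder headers are on the include path and that `LinearInitializer<float>` accepts a callable taking an index and the aligned data pointer; both are assumptions, not guaranteed by this patch.)

```cpp
// Minimal usage sketch (assumptions noted above; not part of the patch).
#include "experimental/ModelBuilder/MemRefUtils.h"

int main() {
  // Logical shape 128x127, but allocate 128x128 so the strides come out as
  // [128, 1] and each row starts at an aligned offset.
  auto init = [](unsigned i, float *data) { data[i] = 0.0f; };
  auto memref = mlir::makeInitializedStridedMemRefDescriptor<float, 2>(
      /*shape=*/{128, 127}, /*shapeAlloc=*/{128, 128}, init);
  // Expect memref->sizes == {128, 127} and memref->strides == {128, 1}.
  (void)memref;
  return 0;
}
```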
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index a4a9062..706f419 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -235,8 +235,7 @@
     func = builder.create<FuncOp>(
         module.getLoc(), functionName,
         FunctionType::get(SmallVector<Type, 4>(values.getTypes()), returnTypes,
-                          builder.getContext()),
-        ArrayRef<NamedAttribute>{});
+                          builder.getContext()));
   }
   return std_call(builder.getSymbolRefAttr(func), returnTypes, values);
 }
diff --git a/experimental/ModelBuilder/VulkanWrapperPass.cpp b/experimental/ModelBuilder/VulkanWrapperPass.cpp
index b20d2fc..c9ba950 100644
--- a/experimental/ModelBuilder/VulkanWrapperPass.cpp
+++ b/experimental/ModelBuilder/VulkanWrapperPass.cpp
@@ -95,10 +95,9 @@
   vulkanLaunchTypes.insert(vulkanLaunchTypes.end(), args.begin(), args.end());
 
   // Declare vulkan launch function.
-  builder.create<FuncOp>(
-      loc, kVulkanLaunch,
-      FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{}, loc->getContext()),
-      ArrayRef<NamedAttribute>{});
+  builder.create<FuncOp>(loc, kVulkanLaunch,
+                         FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{},
+                                           loc->getContext()));
   return success();
 }
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index f2c6151..750aaba 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -18,9 +18,15 @@
 # pylint: disable=protected-access
 # pylint: disable=unsupported-assignment-operation
 
-import collections
+# This file uses the following abbreviations:
+#   ref: reference - for the reference CompiledModule
+#   tar: target - for one of the target CompiledModules
+
+import copy
+import inspect
 import os
-import re
+import sys
+import tempfile
 
 from absl import flags
 from absl import logging
@@ -29,244 +35,329 @@
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
+flags.DEFINE_string("reference_backend", "tf",
+                    "The backend to treat as a source of truth.")
 flags.DEFINE_string("target_backends", None,
                     "Explicit comma-delimited list of target backends.")
 flags.DEFINE_string(
-    "debug_dir", None,
-    "Specifies a directory to dump debug artifacts to. Defaults to "
-    "--test_tmpdir")
+    "artifacts_dir", None,
+    "Specifies a directory to dump compilation artifacts and traces to. "
+    "Defaults to the OS's tempdir.")
+flags.DEFINE_bool(
+    "summarize", True,
+    "Summarize the inputs and outputs of each module trace logged to disk.")
 FLAGS = flags.FLAGS
+NUMPY_LINEWIDTH = 120
 
 
-def _setup_test_debug_dir(test_name):
-  global global_debug_dir
+def _setup_artifacts_dir(module_name):
+  parent_dir = FLAGS.artifacts_dir
+  if parent_dir is None:
+    parent_dir = os.path.join(tempfile.gettempdir(), "iree", "modules")
+  artifacts_dir = os.path.join(parent_dir, module_name)
+  logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir)
 
-  # Use test_tempdir (which defaults to '/tmp/absl_testing/') if FLAGS.debug_dir
-  # is not provided.
-  parent = FLAGS.debug_dir if FLAGS.debug_dir is not None else FLAGS.test_tmpdir
-  global_debug_dir = os.path.join(parent, test_name)
-
-  # Create the directory.
+  # If the artifacts already exist then we overwrite/update them.
   try:
-    os.makedirs(global_debug_dir)
+    # Use try/except instead of os.path.exists to address a race condition
+    # between multiple test targets.
+    os.makedirs(artifacts_dir)
   except IOError:
-    logging.exception("Error creating debug dir for: %s", global_debug_dir)
+    pass
+  return artifacts_dir
 
 
-class _VirtualModuleInstance(object):
-  """Wraps a namedtuple of modules and represents a union of them."""
+def _parse_target_backends(target_backends):
+  """Decodes a comma-delimited string of backends into BackendInfo objects."""
+  backends = []
+  for backend_name in target_backends.split(","):
+    if backend_name not in tf_utils.BackendInfo.ALL.keys():
+      raise ValueError(
+          "Invalid backend specification string '{}', unexpected name '{}';"
+          " valid names are '{}'".format(target_backends, backend_name,
+                                         tf_utils.BackendInfo.ALL.keys()))
+    backends.append(tf_utils.BackendInfo.ALL[backend_name])
+  return backends
 
-  def __init__(self, named_modules, match_spec):
-    self._named_modules = named_modules
-    self._match_spec = match_spec
+
+def get_target_backends():
+  """Gets the BackendInfo instances to compare with the reference backend.
+
+  By default all backends in BackendInfo will be used. Specific backends to
+  run on can be specified using the `--target_backends` flag.
+
+  Returns:
+    Sequence of BackendInfo that should be used.
+  """
+  if FLAGS.target_backends is not None:
+    logging.info("Using backends from command line: %s", FLAGS.target_backends)
+    backends = _parse_target_backends(FLAGS.target_backends)
+  else:
+    # If no backends are specified, use them all.
+    backends = list(tf_utils.BackendInfo.ALL.values())
+  return backends
+
+
+def _indent(input_str, indentation=2):
+  """Indents a string by the specified number of spaces, defaulting to 2."""
+  spaces = " " * indentation
+  lines = input_str.split("\n")
+  # Prepend spaces to each non-empty line.
+  lines = [f"{spaces}{line}" if len(line) else line for line in lines]
+  return "\n".join(lines)
+
+
+class ModuleCall:
+
+  def __init__(self, method_name, inputs, outputs, rtol=1e-6, atol=1e-6):
+    """Records the details of a call to a CompiledModule."""
+    self.method = method_name
+
+    # Deepcopy to safeguard against mutation.
+    self.inputs = copy.deepcopy(inputs)
+    if outputs is not None:
+      outputs = copy.deepcopy(outputs)
+    else:
+      outputs = tuple()
+    self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
+
+    self.rtol = rtol
+    self.atol = atol
+
+  def get_tolerances(self):
+    """Gets the floating point tolerances associated with this call."""
+    return self.rtol, self.atol
+
+  def __str__(self):
+    prior_printoptions = np.get_printoptions()
+    np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
+
+    header = f"Method: {self.method}"
+    inputs = "\n".join(_indent(str(value)) for value in self.inputs)
+    outputs = "\n".join(_indent(str(value)) for value in self.outputs)
+    tolerances = _indent(f"rtol={self.rtol}, atol={self.atol}")
+    body = f"Inputs:\n{inputs}\nOutputs:\n{outputs}\nTolerances:\n{tolerances}"
+    result = f"{header}\n{_indent(body)}"
+
+    np.set_printoptions(**prior_printoptions)
+    return result
+
+
+class Trace:
+  """Stores the inputs and outputs of a series of calls to a module."""
+
+  def __init__(self, module, function):
+    """Extracts metadata from module and function and initializes.
+
+    Example usage:
+      def forward_pass(...):
+        ...
+      module = IreeCompiledModule(...)
+      trace = Trace(module, forward_pass)
+      forward_pass(TracedModule(module, trace))
+
+    Args:
+      module: the module whose outputs this trace will record.
+      function: the function that the module will be traced on.
+    """
+    # Extract metadata from module and function.
+    self.module_name = module.module_name
+    self.backend = module.backend
+    self.function_name = function.__name__
+    self.function_sourcefile = inspect.getsourcefile(function)
+    source, start_line = inspect.getsourcelines(function)
+    self.function_line_numbers = (start_line, start_line + len(source))
+    self.function_source = "".join(source)
+
+    self.calls = []
+
+  def __str__(self):
+    header = (f"Trace of {self.module_name} compiled to '{self.backend}' "
+              f"on function '{self.function_name}':")
+    # Give each call a number so it's easier to compare between multiple traces.
+    calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
+    calls = _indent("\n".join(calls))
+    return f"{header}\n{calls}"
+
+  def __iter__(self):
+    for call in self.calls:
+      yield call
+
+  @staticmethod
+  def compare_traces(ref_trace, tar_trace):
+    traces_match = True
+
+    # Check that all method invocations match.
+    ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
+    tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
+    if ref_methods != tar_methods:
+      # Raise a ValueError instead of returning False since this is an
+      # unexpected error.
+      raise ValueError(
+          "The reference and target traces have different call structures:\n"
+          f"Reference: {ref_methods}\nTarget:    {tar_methods}")
+
+    for ref_call, tar_call in zip(ref_trace, tar_trace):
+      logging.info("Comparing calls to '%s'", ref_call.method)
+      rtol, atol = ref_call.get_tolerances()
+
+      inputs_match = Trace._check_same(ref_call.inputs, tar_call.inputs, rtol,
+                                       atol)
+      if not inputs_match:
+        logging.error("Inputs did not match.")
+      outputs_match = Trace._check_same(ref_call.outputs, tar_call.outputs,
+                                        rtol, atol)
+      if not outputs_match:
+        logging.error("Outputs did not match.")
+      calls_match = inputs_match and outputs_match
+
+      if not calls_match:
+        logging.error("Comparison between '%s' and '%s' failed on method '%s'",
+                      ref_trace.backend, tar_trace.backend, ref_call.method)
+        logging.error("Reference call '%s':\n%s", ref_trace.backend, ref_call)
+        logging.error("Target call '%s':\n%s", tar_trace.backend, tar_call)
+
+      traces_match = traces_match and calls_match
+    return traces_match
+
+  @staticmethod
+  def _check_same(ref, tar, rtol, atol):
+    """Checks that ref and tar have identical datastructures and values."""
+    # Check for matching types.
+    if not isinstance(tar, type(ref)):
+      logging.error(
+          "Expected ref and tar to have the same type but got '%s' and '%s'",
+          type(ref), type(tar))
+      return False
+
+    if ref is None:
+      # Nothing to compare (e.g. the called method had no outputs).
+      return True
+
+    # Recursive check for dicts.
+    if isinstance(ref, dict):
+      if ref.keys() != tar.keys():
+        logging.error(
+            "Expected ref and tar to have the same keys, but got '%s' and '%s'",
+            ref.keys(), tar.keys())
+        return False
+      # Check that all of the dictionaries' values are the same.
+      for key in ref:
+        if not Trace._check_same(ref[key], tar[key], rtol, atol):
+          return False
+
+    # Recursive check for iterables.
+    elif isinstance(ref, list) or isinstance(ref, tuple):
+      if len(ref) != len(tar):
+        logging.error(
+            "Expected ref and tar to have the same length, but got %s and %s",
+            len(ref), len(tar))
+        return False
+      # Check that all of the iterables' values are the same.
+      for i in range(len(ref)):
+        if not Trace._check_same(ref[i], tar[i], rtol, atol):
+          return False
+
+    # Base check for numpy arrays.
+    elif isinstance(ref, np.ndarray):
+      if ref.dtype != tar.dtype:
+        logging.error(
+            "Expected ref and tar to have the same dtype, but got %s and %s",
+            ref.dtype, tar.dtype)
+        return False
+      if np.issubdtype(ref.dtype, np.floating):
+        return np.allclose(ref, tar, rtol=rtol, atol=atol)
+      else:
+        return np.array_equal(ref, tar)
+
+    # Base check for native number types.
+    elif isinstance(ref, (int, float)):
+      return ref == tar
+
+    # If outputs end up here then an extra branch for that type should be added.
+    else:
+      raise TypeError(f"Encountered results with unexpected type {type(ref)}")
+    return True
+
+  def _get_trace_dir(self, artifacts_dir):
+    trace_dir = os.path.join(artifacts_dir, "traces")
+    if not os.path.exists(trace_dir):
+      os.makedirs(trace_dir)
+    return trace_dir
+
+  def save_plaintext(self, artifacts_dir, summarize=True):
+    """Saves a human-readable string representation of this trace to disk.
+
+    Args:
+      artifacts_dir: the base directory to save the trace in.
+      summarize: a bool controlling whether numpy should summarize the inputs
+        and outputs if they're large. Setting this to False is very slow for
+        large outputs.
+    """
+    prior_printoptions = np.get_printoptions()
+    np.set_printoptions(
+        linewidth=NUMPY_LINEWIDTH,
+        threshold=None if summarize else sys.maxsize,
+        edgeitems=10)  # Can show more items since they won't clutter the logs.
+
+    trace_dir = self._get_trace_dir(artifacts_dir)
+    path = os.path.join(trace_dir, f"{self.function_name}__{self.backend}.txt")
+    with open(path, "w") as f:
+      f.write(str(self))
+      f.write("\n")
+
+    np.set_printoptions(**prior_printoptions)
+
+
+class TracedModule:
+
+  def __init__(self, module, trace):
+    """Wraps a CompiledModule so that all inputs and outputs are traced.
+
+    The TracedModule returned will have an API almost identical to that of the
+    passed CompiledModule. The only changes is that if the keywords `rtol` or
+    `atol` are passed to one of the CompiledModule's methods, then they will be
+    used to set the tolerance for comparing that call to the same call in
+    another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)`
+    would be the same as calling `module.add(a, b)`.
+
+    Args:
+      module: the CompiledModule to trace.
+      trace: the Trace to record calls to this module with.
+    """
+    self._module = module
+    self._trace = trace
+
+  def _trace_call(self, method, method_name):
+    """Decorates a CompiledModule method to capture its inputs and outputs."""
+
+    def call(*args, **kwargs):
+      # Pop manually specified tolerances from the kwargs (if any).
+      tolerances = {}
+      tolerances["rtol"] = kwargs.pop("rtol", None)
+      tolerances["atol"] = kwargs.pop("atol", None)
+      # Only pass these to ModuleCall if they were specified by the user.
+      tolerances = {k: v for k, v in tolerances.items() if v is not None}
+
+      # Run the method and record the details of the call.
+      outputs = method(*args, **kwargs)
+      self._trace.calls.append(
+          ModuleCall(method_name, args, outputs, **tolerances))
+      return outputs
+
+    return call
 
   def __getattr__(self, attr):
-    match_modules = {
-        k: v
-        for k, v in self._named_modules.items()
-        if re.search(self._match_spec, k)
-    }
-    if not match_modules:
-      raise AttributeError(
-          "Module match spec '%s' did not match anything. (Have %r)" %
-          (self._match_spec, self._named_modules.keys()))
-    # Resolve functions on each.
-    match_functions = {}
-    for backend, module in match_modules.items():
-      try:
-        match_functions[backend] = getattr(module, attr)
-      except:
-        raise AttributeError(
-            "Could not resolve function '%s' on backend module '%s'" %
-            (attr, backend))
-    return _VirtualFunctionWrapper(match_functions)
-
-
-class _VirtualFunctionWrapper(object):
-  """Wrapper around a virtual dict of functions."""
-
-  def __init__(self, backend_function_dict):
-    self._backend_function_dict = backend_function_dict
-
-  def __call__(self, *args, **kwargs):
-    all_results = {
-        backend: f(*args, **kwargs)
-        for backend, f in self._backend_function_dict.items()
-    }
-    # Turn it into a named tuple so we get nice class-like access to it.
-    results_tuple_class = collections.namedtuple("Results", all_results.keys())
-    return _make_multi_result_class(results_tuple_class)(*all_results.values())
-
-
-def _recursive_check_same(result_ref, result_tgt, rtol=1e-6, atol=1e-6):
-  same = True
-  if not isinstance(result_tgt, type(result_ref)):
-    raise ValueError("Types of the outputs must be the same, but have '{}' and "
-                     "'{}'".format(type(result_ref), type(result_tgt)))
-  if isinstance(result_ref, dict):
-    if result_ref.keys() != result_tgt.keys():
-      raise ValueError("Outputs must have the same structure, but have '{}' and"
-                       " '{}'".format(result_ref.keys(), result_tgt.keys()))
-    for key in result_ref.keys():
-      same = same and _recursive_check_same(result_ref[key], result_tgt[key],
-                                            rtol, atol)
-      if not same:
-        return False  # no need to go further they are different
-  elif isinstance(result_ref, list):
-    if len(result_ref) != len(result_tgt):
-      raise ValueError("Outputs must have the same structure, but have '{}' and"
-                       " '{}'".format(result_ref, result_tgt))
-    for i in range(len(result_ref)):
-      same = same and _recursive_check_same(result_ref[i], result_tgt[i], rtol,
-                                            atol)
-      if not same:
-        return False  # no need to go further they are different
-  elif isinstance(result_ref, np.ndarray):
-    if isinstance(result_ref.flat[0], np.floating):
-      return np.allclose(result_ref, result_tgt, rtol=rtol, atol=atol)
+    # Try to resolve it as an attr on self._module.
+    if not hasattr(self._module, attr):
+      raise AttributeError(f"The compiled module does not have attr '{attr}'")
+    module_attr = getattr(self._module, attr)
+    if not hasattr(module_attr, "__call__"):
+      # e.g. traced_module.backend
+      return module_attr
     else:
-      return np.array_equal(result_ref, result_tgt)
-  else:
-    # this one need more checks
-    return result_ref == result_tgt
-  return same
-
-
-def _collect_disagreements_recursively(mr, rtol=1e-6, atol=1e-6):
-  """Compare result structs recursively and search for disagreements.
-
-  Args:
-    mr: A MultiResults namedtuple where each entry corresponds to a backend set
-      of results.
-    rtol: The relative tolerance parameter.
-    atol: The absolute tolerance parameter.
-
-  Returns:
-    An equivalent MultiResults where each entry is an array of result names
-    that disagree.
-  """
-  has_disagreement = False
-  disagreement_list = [list() for _ in mr]
-  for i in range(len(mr)):
-    result_ref = mr[i]
-    for j in range(len(mr)):
-      if i < j:
-        continue  # Don't check self and reverse comparisons
-      result_tgt = mr[j]
-      if not _recursive_check_same(result_ref, result_tgt, rtol, atol):
-        has_disagreement = True
-        disagreement_list[i].append(mr._fields[j])
-  disagreements_tuple = collections.namedtuple("Disagreements", mr._fields)
-  return has_disagreement, disagreements_tuple(*disagreement_list)
-
-
-def _collect_disagreements(mr, predicate):
-  """Verifies that result structs.
-
-  Args:
-    mr: A MultiResults namedtuple where each entry corresponds to a backend set
-      of results.
-    predicate: A predicate function which takes (a, b) and returns whether they
-      should be considered equivalent.
-
-  Returns:
-    An equivalent MultiResults where each entry is an array of result names
-    that disagree.
-  """
-  has_disagreement = False
-  disagreement_list = [list() for _ in mr]
-  for i in range(len(mr)):
-    result_ref = mr[i]
-    for j in range(len(mr)):
-      if i == j:
-        continue  # Don't check self.
-      result_tgt = mr[j]
-      if not predicate(result_ref, result_tgt):
-        has_disagreement = True
-        disagreement_list[i].append(mr._fields[j])
-  disagreements_tuple = collections.namedtuple("Disagreements", mr._fields)
-  return has_disagreement, disagreements_tuple(*disagreement_list)
-
-
-def _make_multi_result_class(named_tuple_class):
-  """Makes a class that wraps a mapping of backend results."""
-
-  class MultiResults(named_tuple_class):
-    """Wraps a mapping of results."""
-
-    def assert_all_close(self, rtol=1e-6, atol=1e-6):
-      predicate = (lambda a, b: np.allclose(a, b, rtol=rtol, atol=atol))
-      has_disagreement, disagreements = _collect_disagreements(self, predicate)
-      assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
-                                    (disagreements, self))
-      return self
-
-    def assert_all_equal(self):
-      predicate = np.array_equal
-      has_disagreement, disagreements = _collect_disagreements(self, predicate)
-      assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
-                                    (disagreements, self))
-      return self
-
-    def assert_all_close_and_equal(self, rtol=1e-6, atol=1e-6):
-      # it is a special case when output can be a nestet map of dict(), list()
-      # with different types: int, float, string
-      # in this case int and string must be equal and for float we use rtol,atol
-      has_disagreement, disagreements = _collect_disagreements_recursively(
-          self, rtol, atol)
-      assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
-                                    (disagreements, self))
-      return self
-
-    def print(self):
-      print(self)
-      return self
-
-    def save(self):
-      for i in range(len(self)):
-        result = self[i]  # output generated by a model
-        field = self._fields[i]  # backend name
-        fname = os.path.join(global_debug_dir, "output_{}".format(field))
-        with open(fname, "w") as file:
-          # content of txt file can be converted to py objects by eval(txt)
-          file.write(str(result))
-      return self
-
-  return MultiResults
-
-
-def _instantiate_backends(compiled_backends):
-  """Creates a VirtualBackend namedtuple class for a dict.
-
-  Args:
-    compiled_backends: Dictionary of backend_name:ModuleInstance.
-
-  Returns:
-    a VirtualBackendsClass instance. The VirtualBackendsClass is a dynamically
-    generated namedtuple mapping backend_name:ModuleInstance, where the
-    ModuleInstance allows attribute resolution of public functions on the
-    module. The VirtualBackendsClass also contributes some convenience methods
-    for selecting all or a subset of matching backend modules.
-  """
-  tuple_class = collections.namedtuple("VirtualBackendsTuple",
-                                       compiled_backends.keys())
-
-  class VirtualBackendsClass(tuple_class):
-    """Adds a __call__ method that creates a virtual module."""
-
-    def multi(self, match_spec="."):
-      """Selects multiple backends that match a regular expression."""
-      return _VirtualModuleInstance(self._asdict(), match_spec)
-
-    @property
-    def all(self):
-      """Shorthand for multi() which selects all backends."""
-      return self.multi()
-
-  reinitialized_modules = [
-      module.create_reinitialized() for module in compiled_backends.values()
-  ]
-  return VirtualBackendsClass(*reinitialized_modules)
+      # e.g. traced_module.simple_mul(a, b)
+      return self._trace_call(module_attr, method_name=attr)
 
 
 def compile_module(module_class, exported_names=()):
@@ -288,10 +379,10 @@
 
   def decorator(cls):
     """Decorator Function."""
-    if not issubclass(cls, CompiledModuleTestCase):
+    if not issubclass(cls, TracedModuleTestCase):
       logging.exception(
           "The 'compile_module' decorator must be applied to a "
-          "CompiledModuleTestCase derived class, which %s is not.", cls)
+          "TracedModuleTestCase derived class, which %s is not.", cls)
     cls._module_class = module_class
     cls._exported_names = exported_names
     return cls
@@ -299,90 +390,108 @@
   return decorator
 
 
-def _parse_target_backends(target_backends):
-  """Decodes a comma-delimited string of backends into BackendInfo objects."""
-  backends = []
-  for backend_name in target_backends.split(","):
-    if backend_name not in tf_utils.BackendInfo.ALL.keys():
-      raise ValueError(
-          "Invalid backend specification string '{}', unexpected name '{}';"
-          " valid names are '{}'".format(target_backends, backend_name,
-                                         tf_utils.BackendInfo.ALL.keys()))
-    backends.append(tf_utils.BackendInfo.ALL[backend_name])
-  return backends
-
-
-def get_backends():
-  """Gets the BackendInfo instances to test.
-
-  By default all backends in BackendInfo will be used. Specific backends to
-  run on can be specified using the `--target_backends` flag. If only "tf" is
-  provided then it will be compared against itself.
-
-  Returns:
-    Sequence of BackendInfo that should be used.
-  """
-  if FLAGS.target_backends is not None:
-    logging.info("Using backends from command line: %s", FLAGS.target_backends)
-    backends = _parse_target_backends(FLAGS.target_backends)
-    # If tf is the only backend then we will test it itself by adding tf_also.
-    if len(backends) == 1 and "tf" == backends[0].name:
-      backends.append(tf_utils.BackendInfo.ALL["tf_also"])
-  else:
-    # If no backends are specified, use them all.
-    backends = list(tf_utils.BackendInfo.ALL.values())
-  return backends
-
-
-class CompiledModuleTestCase(tf.test.TestCase):
+class TracedModuleTestCase(tf.test.TestCase):
   """Compiles a tf.Module to multiple backends to test their correctness."""
-
   # Will be initialized by the @compile_module decorator.
   _module_class = None
   _exported_names = ()
 
-  # Will be initialized in setUpClass to a dict of
-  # {backend_name: CompiledModule}.
-  _compiled_backends_dict = None
+  # Will be initialized in setUpClass.
+  _ref_module = None
+  _tar_modules = None
+
+  @classmethod
+  def _compile(cls, backend_info):
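+    """Creates a CompiledModule for cls._module_class on the given backend."""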
+    return backend_info.CompiledModule(cls._module_class, backend_info,
+                                       cls._exported_names, cls._artifacts_dir)
 
   @classmethod
   def setUpClass(cls):
+    # Runs once before any of the unit tests.
     super().setUpClass()
     if cls._module_class is None:
       raise AttributeError(
           "setUpClass was called but no module was specified. Specify a module "
           "to compile via the @tf_test_utils.compile_module decorator.")
 
-    # Setup the debug directory for this test. Creates a global variable
-    # `global_debug_dir`.
-    _setup_test_debug_dir(test_name=cls.__name__)
+    # Setup the directory for saving compilation artifacts and traces.
+    cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
 
     # Setup crash reproducer for the test.
-    crash_reproducer_path = os.path.join(global_debug_dir, "reproducer.mlir")
+    crash_reproducer_path = os.path.join(cls._artifacts_dir, "reproducer.mlir")
     compiler.Context.default_crash_reproducer_path = crash_reproducer_path
 
-    # Create a CompiledModule for each backend.
+    # Create a CompiledModule for the reference backend and each target backend.
     try:
-      backends = get_backends()
-      cls._compiled_backends_dict = {}
-      for backend_info in backends:
-        compiled_backend = backend_info.CompiledModule(cls._module_class,
-                                                       backend_info,
-                                                       cls._exported_names,
-                                                       global_debug_dir)
-        cls._compiled_backends_dict[backend_info.name] = compiled_backend
+      ref_backend_info = tf_utils.BackendInfo.ALL[FLAGS.reference_backend]
+      cls._ref_module = cls._compile(ref_backend_info)
+
+      tar_backend_infos = get_target_backends()
+      cls._tar_modules = [
+          cls._compile(backend_info) for backend_info in tar_backend_infos
+      ]
     finally:
+      # TODO(meadowlark): Move this into tf_utils.compile_tf_module to prevent
+      # overwriting `reproducer.mlir`.
       # Disable crash reproducer (to avoid inadvertently overwriting this
-      # path on a subsequent interaction).
+      # path if there are multiple TestCases in the same file).
       compiler.Context.default_crash_reproducer_path = None
 
+  def setUp(self):
+    # Runs before each unit test.
+    super().setUp()
+    self._ref_module = self._ref_module.create_reinitialized()
+    self._tar_modules = [
+        module.create_reinitialized() for module in self._tar_modules
+    ]
+
+  def compare_backends(self, trace_function):
+    """Run the reference and target backends on trace_function and compare them.
+
+    Random seeds for tensorflow, numpy and python are set before each invocation
+    of trace_function.
+
+    Args:
+      trace_function: a function accepting a TracedModule as its argument.
+    """
+    # Create Traces for each backend.
+    ref_trace = Trace(self._ref_module, trace_function)
+    tar_traces = [Trace(module, trace_function) for module in self._tar_modules]
+
+    # Run trace_function once per module, recording its calls into each Trace.
+    tf_utils.set_random_seed()
+    trace_function(TracedModule(self._ref_module, ref_trace))
+    for module, trace in zip(self._tar_modules, tar_traces):
+      tf_utils.set_random_seed()
+      trace_function(TracedModule(module, trace))
+
+    # Compare each target backend's trace with the reference trace.
+    failed_backend_indices = []
+    for i, tar_trace in enumerate(tar_traces):
+      logging.info("Comparing the reference backend '%s' with '%s'",
+                   ref_trace.backend, tar_trace.backend)
+      traces_match = Trace.compare_traces(ref_trace, tar_trace)
+      if not traces_match:
+        failed_backend_indices.append(i)
+
+    # Save the results to disk before validating.
+    ref_trace.save_plaintext(self._artifacts_dir, FLAGS.summarize)
+    for tar_trace in tar_traces:
+      tar_trace.save_plaintext(self._artifacts_dir, FLAGS.summarize)
+
+    # Validate results.
+    if failed_backend_indices:
+      # Extract info for logging.
+      failed_backends = [tar_traces[i].backend for i in failed_backend_indices]
+      failure_info = (
+          "Comparision between the reference backend and the following targets "
+          f"failed: {failed_backends}. The errors above show the inputs and "
+          "outputs the non-matching calls.")
+
+      # `failed_backends` is non-empty here, so this assertion always fails,
+      # but it provides useful context in the logs.
+      self.assertEmpty(failed_backends, failure_info)
+
   @classmethod
   def tearDownClass(cls):
+    # Runs after all unit tests have completed.
     super().tearDownClass()
-
-  def setUp(self):
-    super().setUp()
-    self.compiled_modules = _instantiate_backends(self._compiled_backends_dict)
-
-  def get_module(self):
-    return self.compiled_modules.all
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 20ba522..f21521a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -17,9 +17,32 @@
 from absl.testing import parameterized
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow as tf
 
 
+class StatefulCountingModule(tf.Module):
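+  """A stateful tf.Module that mutates an internal counter, for trace tests."""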
+
+  def __init__(self):
+    self.count = tf.Variable([0.])
+
+  @tf.function(input_signature=[])
+  def increment(self):
+    self.count.assign_add(tf.constant([1.]))
+
+  @tf.function(input_signature=[])
+  def get_count(self):
+    return self.count
+
+  @tf.function(input_signature=[tf.TensorSpec([1])])
+  def increment_by(self, value):
+    self.count.assign_add(value)
+
+  @tf.function(input_signature=[])
+  def decrement(self):
+    self.count.assign_sub(tf.constant([1.]))
+
+
 class UtilsTests(tf.test.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters([
@@ -28,31 +51,31 @@
           'array_c': np.array([0, 1, 2]),
           'array_d': np.array(['0', '1', '2']),
           'array_e': np.array([0.0, 0.1, 0.2]),
-          'tgt_same': True,
+          'tar_same': True,
       },
       {
           'testcase_name': 'wrong int',
           'array_c': np.array([1, 1, 2]),
           'array_d': np.array(['0', '1', '2']),
           'array_e': np.array([0.0, 0.1, 0.2]),
-          'tgt_same': False,
+          'tar_same': False,
       },
       {
           'testcase_name': 'wrong string',
           'array_c': np.array([0, 1, 2]),
           'array_d': np.array(['a', '1', '2']),
           'array_e': np.array([0.0, 0.1, 0.2]),
-          'tgt_same': False,
+          'tar_same': False,
       },
       {
           'testcase_name': 'wrong float',
           'array_c': np.array([0, 1, 2]),
           'array_d': np.array(['0', '1', '2']),
           'array_e': np.array([1.0, 0.1, 0.2]),
-          'tgt_same': False,
+          'tar_same': False,
       },
   ])
-  def test_recursive_check_same(self, array_c, array_d, array_e, tgt_same):
+  def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
 
     ref = {
         'a':
@@ -65,7 +88,7 @@
             'e': np.array([0.0, 0.1, 0.2])
         }],
     }
-    tgt = {
+    tar = {
         'a': 1,
         'b': [{
             'c': array_c
@@ -75,8 +98,74 @@
             'e': array_e
         }],
     }
-    same = tf_test_utils._recursive_check_same(ref, tgt)
-    self.assertEqual(tgt_same, same)
+    same = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
+    self.assertEqual(tar_same, same)
+
+  def test_trace_inputs_and_outputs(self):
+
+    def trace_function(module):
+      # No inputs or outputs
+      module.increment()
+      # Only inputs
+      module.increment_by(np.array([81.], dtype=np.float32))
+      # Only outputs
+      module.get_count()
+
+    module = tf_utils.TfCompiledModule(StatefulCountingModule,
+                                       tf_utils.BackendInfo.ALL['tf'])
+    trace = tf_test_utils.Trace(module, trace_function)
+    trace_function(tf_test_utils.TracedModule(module, trace))
+
+    self.assertIsInstance(trace.calls[0].inputs, tuple)
+    self.assertEmpty(trace.calls[0].inputs)
+    self.assertIsInstance(trace.calls[0].outputs, tuple)
+    self.assertEmpty(trace.calls[0].outputs)
+
+    self.assertAllClose(trace.calls[1].inputs[0], [81.])
+    self.assertAllClose(trace.calls[2].outputs[0], [82.])
+
+  def test_nonmatching_methods(self):
+
+    def tf_function(module):
+      module.increment()
+      module.increment()
+
+    def vmla_function(module):
+      module.increment()
+      module.decrement()
+
+    tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
+                                          tf_utils.BackendInfo.ALL['tf'])
+    tf_trace = tf_test_utils.Trace(tf_module, tf_function)
+    tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
+
+    vmla_module = tf_utils.IreeCompiledModule(
+        StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+    vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
+    vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
+
+    with self.assertRaises(ValueError):
+      tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
+
+  def test_nonmatching_inputs(self):
+
+    def tf_function(module):
+      module.increment_by(np.array([42.], dtype=np.float32))
+
+    def vmla_function(module):
+      module.increment_by(np.array([22.], dtype=np.float32))
+
+    tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
+                                          tf_utils.BackendInfo.ALL['tf'])
+    tf_trace = tf_test_utils.Trace(tf_module, tf_function)
+    tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
+
+    vmla_module = tf_utils.IreeCompiledModule(
+        StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+    vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
+    vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
+
+    self.assertFalse(tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace))
 
 
 if __name__ == '__main__':
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 46a3785..f67119e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -29,6 +29,8 @@
 from pyiree.tf import compiler
 import tensorflow.compat.v2 as tf
 
+flags.DEFINE_bool("keep_saved_model", False,
+                  "Keep the SavedModel used by compile_tf_module on disk.")
 FLAGS = flags.FLAGS
 
 
@@ -39,6 +41,14 @@
   np.random.seed(seed)
 
 
+def uniform(shape, dtype=np.float32):
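+  """Returns a random ndarray of the given shape with values uniform in [0, 1)."""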
+  return np.random.uniform(size=shape).astype(dtype)
+
+
+def ndarange(shape, dtype=np.float32):
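+  """Returns an ndarray of the given shape filled with 0, 1, 2, ... in row-major order."""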
+  return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
+
+
 def backends_to_str(target_backends):
   """Creates a flattened and normalized string representing target_backends."""
   normalized_backends = []
@@ -49,6 +59,36 @@
   return "__".join(normalized_backends)
 
 
+def to_mlir_type(dtype):
+  """Returns a string that denotes the type `dtype` in MLIR style."""
+  bits = dtype.itemsize * 8
+  if np.issubdtype(dtype, np.integer):
+    return f"i{bits}"
+  elif np.issubdtype(dtype, np.floating):
+    return f"f{bits}"
+  else:
+    raise TypeError(f"Expected integer or floating type, but got {dtype}")
+
+
+def save_input_values(inputs, artifacts_dir=None):
+  """Saves input values with IREE tools format if `artifacts_dir` is set."""
+  result = []
+  for array in inputs:
+    shape = [str(dim) for dim in list(array.shape)]
+    shape.append(to_mlir_type(array.dtype))
+    shape = "x".join(shape)
+    values = " ".join([str(x) for x in array.flatten()])
+    result.append(f"{shape}={values}")
+  result = "\n".join(result)
+  if artifacts_dir is not None:
+    inputs_path = os.path.join(artifacts_dir, "inputs.txt")
+    logging.info("Saving IREE input values to: %s", inputs_path)
+    with open(inputs_path, "w") as f:
+      f.write(result)
+      f.write("\n")
+  return result
+
+
 def compile_tf_module(tf_module,
                       target_backends=(),
                       exported_names=(),
@@ -120,9 +160,14 @@
     return compiled_module
 
   options = tf.saved_model.SaveOptions(save_debug_info=True)
-  if artifacts_dir is not None:
+  if artifacts_dir is not None and FLAGS.keep_saved_model:
     # Save the saved model alongside the other compilation artifacts.
-    sm_path = os.path.join(artifacts_dir, "saved_model")
+
+    # Create a saved model for these target backends to avoid a race condition
+    # when running a test suite.
+    # TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
+    sm_path = os.path.join(artifacts_dir,
+                           f"saved_model__{backends_to_str(target_backends)}")
     tf.saved_model.save(tf_module, sm_path, options=options)
     return _compile_from_path(sm_path)
   else:
@@ -142,6 +187,10 @@
     self._exported_names = exported_names
     self._artifacts_dir = artifacts_dir
 
+    # Public attributes:
+    self.backend = self._backend_info.name
+    self.module_name = self._module_class.__name__
+
   def create_reinitialized(self):
     """Duplicates this module with its initial state without recompiling."""
     raise NotImplementedError()
@@ -153,7 +202,7 @@
   def __init__(self,
                module_class,
                backend_info,
-               exported_names=[],
+               exported_names=(),
                artifacts_dir=None,
                _create_reinitialized_args=None):
     """Compile a tf.Module to the target backend in backend_info.
@@ -166,10 +215,12 @@
         module_class's functions to compile. If exported_names is empty all
         functions will be compiled.
       artifacts_dir: an optional path to save compilation artifacts to.
+      _create_reinitialized_args: used internally.
     """
     super().__init__(module_class, backend_info, exported_names, artifacts_dir)
 
     if _create_reinitialized_args is None:
+      set_random_seed()
       self._module_blob = compile_tf_module(
           tf_module=module_class(),
           target_backends=backend_info.iree_compiler_targets,
@@ -222,7 +273,7 @@
   def __init__(self,
                module_class,
                backend_info,
-               exported_names=[],
+               exported_names=(),
                artifacts_dir=None):
     """Wrap a tf.Module in a TFCompiledModule facade.
 
@@ -236,6 +287,7 @@
         effect for this subclass as nothing is compiled.
     """
     super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+    set_random_seed()
     self._tf_module = module_class()
 
   def create_reinitialized(self):
@@ -245,7 +297,7 @@
 
   def __getattr__(self, attr):
     # Try to resolve it as a function.
-    exported = len(self._exported_names) == 0 or attr in self._exported_names
+    exported = not self._exported_names or attr in self._exported_names
     if not hasattr(self._tf_module, attr) or not exported:
       raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
     f = getattr(self._tf_module, attr)
@@ -261,6 +313,13 @@
   def __init__(self, f):
     self._f = f
 
+  def _convert_to_numpy(self, tensor):
+    result = tensor.numpy()
+    if np.isscalar(result):
+      # convert_to_tensor isn't reversible via .numpy()
+      result = np.array(result)
+    return result
+
   def __call__(self, *args, **kwargs):
     # TensorFlow will auto-convert all inbound args.
     results = self._f(*args, **kwargs)
@@ -270,7 +329,7 @@
     if not isinstance(results, tuple):
       results = (results,)
     return tf.nest.map_structure(
-        lambda t: t.numpy() if isinstance(t, tf.Tensor) else t,
+        lambda t: self._convert_to_numpy(t) if isinstance(t, tf.Tensor) else t,
         *results,
         check_types=False)
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index b1d9adb..cde0a08 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -21,6 +21,7 @@
 from absl.testing import parameterized
 from pyiree.tf.support import tf_utils
 import tensorflow as tf
+import numpy as np
 
 
 class ConstantModule(tf.Module):
@@ -65,7 +66,6 @@
           artifacts_dir=artifacts_dir)
 
       artifacts_to_check = [
-          'saved_model',
           'tf_input.mlir',
           'iree_input.mlir',
           f'compiled__{tf_utils.backends_to_str(target_backends)}.vmfb',
@@ -100,6 +100,18 @@
     # Test independent state.
     self.assertEqual([1.], module.get_count())
 
+  def test_to_mlir_type(self):
+    self.assertEqual('i8', tf_utils.to_mlir_type(np.dtype('int8')))
+    self.assertEqual('i32', tf_utils.to_mlir_type(np.dtype('int32')))
+    self.assertEqual('f32', tf_utils.to_mlir_type(np.dtype('float32')))
+    self.assertEqual('f64', tf_utils.to_mlir_type(np.dtype('float64')))
+
+  def test_save_input_values(self):
+    inputs = [np.array([1, 2], dtype=np.int32)]
+    self.assertEqual('2xi32=1 2', tf_utils.save_input_values(inputs))
+    inputs = [np.array([1, 2], dtype=np.float32)]
+    self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index d63b839..a3346e4 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -23,7 +23,6 @@
     "INTREE_TENSORFLOW_PY_DEPS",
     "NUMPY_DEPS",
     "iree_py_binary",
-    "iree_py_test",
 )
 load(
     "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
@@ -54,7 +53,6 @@
 # backends.
 # keep sorted
 SPECIAL_CASES = [
-    "explicit_backend_test.py",
     "linspace_test.py",
 ]
 
@@ -84,7 +82,6 @@
 VULKAN_FAILING = [
     "broadcasting_test.py",
     "control_flow_test.py",
-    "conv_test.py",
     "depth_conv_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
@@ -93,7 +90,7 @@
     "matrix_ops_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
-    "sliding_window_test.py",  # TODO(#2659)
+    "sliding_window_test.py",  # TODO(#2659) Failing on nvidia, passing on swiftshader.
     "strings_test.py",
 ]
 
@@ -187,20 +184,3 @@
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
-
-# This tests explicitly writing which backends to use in Python,
-# so overriding the backends can cause it to break.
-iree_py_test(
-    name = "explicit_backend_test",
-    srcs = ["explicit_backend_test.py"],
-    main = "explicit_backend_test.py",
-    python_version = "PY3",
-    tags = [
-        "driver=llvmjit",
-        "driver=vmla",
-        "driver=vulkan",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 04604d8..a213bda 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -21,7 +21,8 @@
 `iree_vulkan` from the list of backends to run the tests on).
 
 The test suites can be run excluding Vulkan by specifying
-`--test_tag_filters="-driver=vulkan"` in the `bazel test` invocation.
+`--test_tag_filters="-driver=vulkan"` in the `bazel test` invocation, or by
+adding `test --test_tag_filters="-driver=vulkan"` to your `user.bazelrc`.
 
 ## Compiling `tf.Module`s
 
@@ -44,6 +45,9 @@
 vmla_module.predict(...)
 ```
 
+By default, the TensorFlow SavedModels will not be kept on disk. This can be
+overridden via the `--keep_saved_model` flag.
+
 ## Running tests
 
 For locally running tests and iterating on backend development, `bazel run` is
@@ -53,11 +57,15 @@
 # Run math_test on all backends.
 bazel run :math_test_manual
 
-# Run math_test on the VMLA backend only.
+# Run math_test comparing TensorFlow to itself (e.g. to debug randomization).
+bazel run :math_test_manual -- --target_backends=tf
+
+# Run math_test comparing the VMLA backend and TensorFlow.
 bazel run :math_test_manual -- --target_backends=iree_vmla
 
-# Same as above, but add `tf` backend to cross-check numerical correctness.
-bazel run :math_test_manual -- --target_backends=tf,iree_vmla
+# Run math_test comparing the VMLA backend to itself multiple times.
+bazel run :math_test_manual -- \
+  --reference_backend=iree_vmla --target_backends=iree_vmla,iree_vmla
 
 # Run math_test and output on failure.
 bazel test :math_test_manual --test_output=errors
@@ -66,14 +74,43 @@
 bazel run :math_test_manual -- --test_output=streamed
 ```
 
-If you specify the same backend multiple times, for example
-`--target_backends=iree_vmla,iree_vmla`. The same backends are grouped and in
-this example `iree_vmla` will run once. If you specify `tf,iree_vmla` as
-backends, then we will test both backends and compare them with each other. If
-you specify `tf` backend only, then we will also test `tf` vs `tf` to capture
-any model initialization/randomization issues (it is a special case for debug
-purpose). For reproducibility of the unit tests we set random seed of `tf` and
-`numpy` by calling `tf_utils.set_random_seed()` before model creation.
+For reproducibility of the unit tests, `CompiledModule()` sets the random seeds
+of `tf`, `numpy` and `python` by calling `tf_utils.set_random_seed()` before
+model creation.
+
+## Writing Tests
+
+Our tests use a class `TracedModule` to capture and store all of the inputs and
+outputs of a `CompiledModule` in a `Trace`. Each unit test on a `TestCase` uses
+the `compare_backends` method. This method runs the function it is passed with
+a `TracedModule` once for the reference backend and once for each target
+backend. The inputs and outputs recorded by these runs are then checked for
+correctness, using the reference backend as the source of truth. For example:
+
+```python
+# Compile a `tf.Module` named `SimpleArithmeticModule` into a `CompiledModule`.
+@tf_test_utils.compile_module(SimpleArithmeticModule)
+# Inherit from `TracedModuleTestCase`.
+class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
+
+  # Unit test.
+  def test_simple_mul(self):
+
+    # Trace function.
+    def simple_mul(module):
+      # A random seed is automatically set before each call to `simple_mul`.
+      a = tf_utils.uniform([4])
+      b = np.array([400., 5., 6., 7.], dtype=np.float32)
+      # The inputs `a` and `b` are recorded along with the output `c`.
+      c = module.simple_mul(a, b)
+      # The inputs `a` and `c` are recorded along with the (unnamed) output
+      # that module.simple_mul returns.
+      module.simple_mul(a, c)
+
+    # Calls `simple_mul` once for each backend, recording the inputs and outputs
+    # to `module` and then comparing them.
+    self.compare_backends(simple_mul)
+```
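+
+Manually specified tolerances can be attached to individual calls by passing
+`rtol` and/or `atol` to the traced method. These keyword arguments are stripped
+off before the underlying module is invoked, so they only affect how that call
+is compared across backends. A minimal sketch, reusing `simple_mul` from the
+example above (the test name and tolerance value are purely illustrative):
+
+```python
+  def test_simple_mul_loose_tolerance(self):
+
+    def simple_mul_loose_tolerance(module):
+      a = tf_utils.uniform([4])
+      b = np.array([400., 5., 6., 7.], dtype=np.float32)
+      # This call is recorded and compared with a looser relative tolerance.
+      module.simple_mul(a, b, rtol=1e-3)
+
+    self.compare_backends(simple_mul_loose_tolerance)
+```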
 
 ## Test Suites
 
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index 75de16d..e16436d 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -39,18 +40,20 @@
 
 
 @tf_test_utils.compile_module(BatchNormModule)
-class BatchNormTest(tf_test_utils.CompiledModuleTestCase):
+class BatchNormTest(tf_test_utils.TracedModuleTestCase):
 
   def test_batch_norm_inference(self):
-    np.random.seed(12345)
-    # Note: scaling by a small value to increase numerical stability.
-    x = np.random.random((4, 16)).astype(np.float32) * 1e-3
-    mean = np.random.random((16,)).astype(np.float32) * 1e-3
-    variance = np.random.random((16,)).astype(np.float32) * 1e-3
-    offset = np.random.random((16,)).astype(np.float32) * 1e-3
-    scale = np.random.random((16,)).astype(np.float32) * 1e-3
-    r = self.get_module().batch_norm_inference(x, mean, variance, offset, scale)
-    r.print().assert_all_close()
+
+    def batch_norm_inference(module):
+      # Note: scaling by a small value to increase numerical stability.
+      x = tf_utils.uniform((4, 16)) * 1e-3
+      mean = tf_utils.uniform((16,)) * 1e-3
+      variance = tf_utils.uniform((16,)) * 1e-3
+      offset = tf_utils.uniform((16,)) * 1e-3
+      scale = tf_utils.uniform((16,)) * 1e-3
+      module.batch_norm_inference(x, mean, variance, offset, scale)
+
+    self.compare_backends(batch_norm_inference)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index 74880bd..c4f8e38 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -14,7 +14,9 @@
 # limitations under the License.
 """Test broadcasting support."""
 
+import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -29,24 +31,35 @@
 
 
 @tf_test_utils.compile_module(BroadcastingModule)
-class BroadcastingTest(tf_test_utils.CompiledModuleTestCase):
+class BroadcastingTest(tf_test_utils.TracedModuleTestCase):
 
   def test_add_same_shape(self):
-    m = self.get_module()
-    dst = m.add(tf.random.uniform([4]), tf.random.uniform([4]))
-    dst.print().assert_all_close()
 
+    def add_same_shape(module):
+      lhs = tf_utils.uniform([4])
+      rhs = tf_utils.uniform([4])
+      module.add(lhs, rhs)
 
-# TODO(silvasean): Make these work.
-#   def test_add_broadcast_lhs(self):
-#     m = self.get_module()
-#     dst = m.add(tf.random.uniform([1]), tf.random.uniform([4]))
-#     dst.print().assert_all_close()
-#
-#   def test_add_broadcast_rhs(self):
-#     m = self.get_module()
-#     dst = m.add(tf.random.uniform([4]), tf.random.uniform([1]))
-#     dst.print().assert_all_close()
+    self.compare_backends(add_same_shape)
+
+  def test_add_broadcast_lhs(self):
+
+    def add_broadcast_lhs(module):
+      lhs = tf_utils.uniform([1])
+      rhs = tf_utils.uniform([4])
+      module.add(lhs, rhs)
+
+    self.compare_backends(add_broadcast_lhs)
+
+  def test_add_broadcast_rhs(self):
+
+    def add_broadcast_rhs(module):
+      lhs = tf_utils.uniform([4])
+      rhs = tf_utils.uniform([1])
+      module.add(lhs, rhs)
+
+    self.compare_backends(add_broadcast_rhs)
+
 
 if __name__ == "__main__":
   if hasattr(tf, "enable_v2_behavior"):
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
new file mode 100644
index 0000000..102396b
--- /dev/null
+++ b/integrations/tensorflow/e2e/complex_test.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v2 as tf
+
+
+class ComplexModule(tf.Module):
+
+  def __init__(self):
+    pass
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([2], tf.float32),
+      tf.TensorSpec([2], tf.float32)
+  ])
+  def complex_exp(self, real, imag):
+    tensor = tf.complex(real, imag)
+    exp = tf.exp(tensor)
+    return tf.math.real(exp)
+
+
+@tf_test_utils.compile_module(ComplexModule)
+class ComplexTest(tf_test_utils.TracedModuleTestCase):
+
+  def test_complex(self):
+
+    def complex_exp(module):
+      real = np.array([2., 3.], dtype=np.float32)
+      imag = np.array([-1., 0.4], dtype=np.float32)
+      module.complex_exp(real, imag)
+
+    self.compare_backends(complex_exp)
+
+
+if __name__ == "__main__":
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+  tf.test.main()
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index b7a348c..187c6cb 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test concat op."""
 
+import numpy as np
 from pyiree.tf.support import tf_test_utils
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
@@ -51,39 +52,43 @@
 
 
 @tf_test_utils.compile_module(ConcatOpsModule)
-class ConcatOpsTest(tf_test_utils.CompiledModuleTestCase):
+class ConcatOpsTest(tf_test_utils.TracedModuleTestCase):
 
   def test_concat_zero_dim(self):
-    tf_utils.set_random_seed()
-    m = self.get_module()
-    a = tf.random.uniform([1, 5, 0], dtype=tf.float32)
-    b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    dst = m.concat_zero_dim(a, b)
-    dst.assert_all_close()
 
-  def concat0axis(self):
-    tf_utils.set_random_seed()
-    m = self.get_module()
-    a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    dst = m.concat_zero_dim(a, b)
-    dst.assert_all_close()
+    def concat_zero_dim(module):
+      a = tf_utils.uniform([1, 5, 0])
+      b = tf_utils.uniform([1, 5, 1])
+      module.concat_zero_dim(a, b)
 
-  def concat1axis(self):
-    tf_utils.set_random_seed()
-    m = self.get_module()
-    a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    dst = m.concat_zero_dim(a, b)
-    dst.assert_all_close()
+    self.compare_backends(concat_zero_dim)
 
-  def concat2axis(self):
-    tf_utils.set_random_seed()
-    m = self.get_module()
-    a = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    b = tf.random.uniform([1, 5, 1], dtype=tf.float32)
-    dst = m.concat_zero_dim(a, b)
-    dst.assert_all_close()
+  def test_concat0axis(self):
+
+    def concat0axis(module):
+      a = tf_utils.uniform([1, 5, 1])
+      b = tf_utils.uniform([1, 5, 1])
+      module.concat0axis(a, b)
+
+    self.compare_backends(concat0axis)
+
+  def test_concat1axis(self):
+
+    def concat1axis(module):
+      a = tf_utils.uniform([1, 5, 1])
+      b = tf_utils.uniform([1, 5, 1])
+      module.concat1axis(a, b)
+
+    self.compare_backends(concat1axis)
+
+  def test_concat2axis(self):
+
+    def concat2axis(module):
+      a = tf_utils.uniform([1, 5, 1])
+      b = tf_utils.uniform([1, 5, 1])
+      module.concat2axis(a, b)
+
+    self.compare_backends(concat2axis)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index 0c25fd6..bc0a328 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy
+import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
@@ -35,17 +35,23 @@
 
 
 @tf_test_utils.compile_module(ControlFlowModule)
-class ControlFlowTest(tf_test_utils.CompiledModuleTestCase):
+class ControlFlowTest(tf_test_utils.TracedModuleTestCase):
 
   def test_short_sequence(self):
-    input_array = numpy.array(9., dtype=numpy.float32)
-    result = self.get_module().collatz(input_array)
-    result.print().assert_all_close()
+
+    def short_sequence(module):
+      input_array = np.array(9., dtype=np.float32)
+      module.collatz(input_array)
+
+    self.compare_backends(short_sequence)
 
   def test_long_sequence(self):
-    input_array = numpy.array(178., dtype=numpy.float32)
-    result = self.get_module().collatz(input_array)
-    result.print().assert_all_close()
+
+    def long_sequence(module):
+      input_array = np.array(178., dtype=np.float32)
+      module.collatz(input_array)
+
+    self.compare_backends(long_sequence)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 61c46cf..9346a52 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -99,73 +100,109 @@
 
 
 @tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.CompiledModuleTestCase):
+class ConvTest(tf_test_utils.TracedModuleTestCase):
 
   def test_id_batch_size_1(self):
-    i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
-    k = np.ones([1, 1, 1, 1], dtype=np.float32)
-    r = self.get_module().conv2d_1451x1111_valid(i, k)
-    r.print().assert_all_close()
+
+    def id_batch_size_1(module):
+      i = tf_utils.ndarange([1, 4, 5, 1])
+      k = np.ones([1, 1, 1, 1], dtype=np.float32)
+      module.conv2d_1451x1111_valid(i, k)
+
+    self.compare_backends(id_batch_size_1)
 
   def test_id_batch_size_2(self):
-    i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
-    k = np.ones([1, 1, 1, 1], dtype=np.float32)
-    r = self.get_module().conv2d_2451x1111_valid(i, k)
-    r.print().assert_all_close()
 
-  def test_asym_kernel(self):
-    i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
-    k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.get_module().conv2d_1451x2311_valid(i, k)
-    r.print().assert_all_close()
+    def id_batch_size_2(module):
+      i = tf_utils.ndarange([2, 4, 5, 1])
+      k = np.ones([1, 1, 1, 1], dtype=np.float32)
+      module.conv2d_2451x1111_valid(i, k)
+
+    self.compare_backends(id_batch_size_2)
+
+  def test_asymmetric_kernel(self):
+
+    def asymmetric_kernel(module):
+      i = tf_utils.ndarange([1, 4, 5, 1])
+      k = np.array([[1, 4, 2], [-2, 0, 1]],
+                   dtype=np.float32).reshape(2, 3, 1, 1)
+      module.conv2d_1451x2311_valid(i, k)
+
+    self.compare_backends(asymmetric_kernel)
 
   def test_padding(self):
-    i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
-    k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.get_module().conv2d_1451x2311_same(i, k)
-    r.print().assert_all_close()
+
+    def padding(module):
+      i = tf_utils.ndarange([1, 4, 5, 1])
+      k = np.array([[1, 4, 2], [-2, 0, 1]],
+                   dtype=np.float32).reshape(2, 3, 1, 1)
+      module.conv2d_1451x2311_same(i, k)
+
+    self.compare_backends(padding)
 
   def test_batched_padding(self):
-    i = np.arange(40, dtype=np.float32).reshape([2, 4, 5, 1])
-    k = np.array([[1, 4, 2], [-2, 0, 1]], dtype=np.float32).reshape(2, 3, 1, 1)
-    r = self.get_module().conv2d_2451x2311_same(i, k)
-    r.print().assert_all_close()
+
+    def batched_padding(module):
+      i = tf_utils.ndarange([2, 4, 5, 1])
+      k = np.array([[1, 4, 2], [-2, 0, 1]],
+                   dtype=np.float32).reshape(2, 3, 1, 1)
+      module.conv2d_2451x2311_same(i, k)
+
+    self.compare_backends(batched_padding)
 
   def test_feature_reduce(self):
-    i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
-    k = np.ones([3, 2, 2, 1], dtype=np.float32)
-    r = self.get_module().conv2d_1452x3221_same(i, k)
-    r.print().assert_all_close()
+
+    def feature_reduce(module):
+      i = tf_utils.ndarange([1, 4, 5, 2])
+      k = np.ones([3, 2, 2, 1], dtype=np.float32)
+      module.conv2d_1452x3221_same(i, k)
+
+    self.compare_backends(feature_reduce)
 
   def test_feature_inflate(self):
-    i = np.arange(20, dtype=np.float32).reshape([1, 4, 5, 1])
-    k = np.arange(2, dtype=np.float32).reshape([1, 1, 1, 2])
-    r = self.get_module().conv2d_1451x1112_same(i, k)
-    r.print().assert_all_close()
+
+    def feature_inflate(module):
+      i = tf_utils.ndarange([1, 4, 5, 1])
+      k = tf_utils.ndarange([1, 1, 1, 2])
+      module.conv2d_1451x1112_same(i, k)
+
+    self.compare_backends(feature_inflate)
 
   def test_feature_mix(self):
-    i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
-    k = np.arange(4, dtype=np.float32).reshape([1, 1, 2, 2])
-    r = self.get_module().conv2d_1452x1122_same(i, k)
-    r.print().assert_all_close()
+
+    def feature_mix(module):
+      i = tf_utils.ndarange([1, 4, 5, 2])
+      k = tf_utils.ndarange([1, 1, 2, 2])
+      module.conv2d_1452x1122_same(i, k)
+
+    self.compare_backends(feature_mix)
 
   def test_feature_padded(self):
-    i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
-    k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.get_module().conv2d_1452x2223_same(i, k)
-    r.print().assert_all_close()
+
+    def feature_padded(module):
+      i = tf_utils.ndarange([1, 4, 5, 2])
+      k = tf_utils.ndarange([2, 2, 2, 3])
+      module.conv2d_1452x2223_same(i, k)
+
+    self.compare_backends(feature_padded)
 
   def test_feature_unpadded(self):
-    i = np.arange(40, dtype=np.float32).reshape([1, 4, 5, 2])
-    k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.get_module().conv2d_1452x2223_valid(i, k)
-    r.print().assert_all_close()
+
+    def feature_unpadded(module):
+      i = tf_utils.ndarange([1, 4, 5, 2])
+      k = tf_utils.ndarange([2, 2, 2, 3])
+      module.conv2d_1452x2223_valid(i, k)
+
+    self.compare_backends(feature_unpadded)
 
   def test_batched_feature_unpadded(self):
-    i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
-    k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.get_module().conv2d_2452x2223_valid(i, k)
-    r.print().assert_all_close()
+
+    def batched_feature_unpadded(module):
+      i = tf_utils.ndarange([2, 4, 5, 2])
+      k = tf_utils.ndarange([2, 2, 2, 3])
+      module.conv2d_2452x2223_valid(i, k)
+
+    self.compare_backends(batched_feature_unpadded)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 1e8a002..9d88bee 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -39,19 +40,25 @@
 
 
 @tf_test_utils.compile_module(Conv2dModule)
-class ConvTest(tf_test_utils.CompiledModuleTestCase):
+class ConvTest(tf_test_utils.TracedModuleTestCase):
 
   def test_batched_feature_unpadded(self):
-    i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
-    k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
-    r = self.get_module().conv2d_2452x2223_valid(i, k)
-    r.print().assert_all_close()
 
-  def test_batched_feature_unpadded_smae(self):
-    i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
-    k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
-    r = self.get_module().conv2d_2452x2223_same(i, k)
-    r.print().assert_all_close()
+    def batched_feature_unpadded(module):
+      i = tf_utils.ndarange([2, 4, 5, 2])
+      k = tf_utils.ndarange([2, 2, 2, 3])
+      module.conv2d_2452x2223_valid(i, k)
+
+    self.compare_backends(batched_feature_unpadded)
+
+  def test_batched_feature_unpadded_same(self):
+
+    def batched_feature_unpadded_same(module):
+      i = tf_utils.ndarange([2, 4, 5, 2])
+      k = tf_utils.ndarange([2, 4, 2, 3])
+      module.conv2d_2452x2223_same(i, k)
+
+    self.compare_backends(batched_feature_unpadded_same)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 64c51e9..eec6815 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -28,7 +28,7 @@
 CLASSES = 10
 
 
-class Mlp(tf.Module):
+class MlpRelu(tf.Module):
 
   def __init__(self,
                hidden_1_dim=256,
@@ -65,14 +65,16 @@
     return tf.nn.softmax(self.mlp(x))
 
 
-@tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
+@tf_test_utils.compile_module(MlpRelu, exported_names=["predict"])
+class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase):
 
   def test_dynamic_batch(self):
-    m = self.get_module()
-    np.random.seed(12345)
-    x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
-    m.predict(x).print().assert_all_close()
+
+    def dynamic_batch(module):
+      x = tf_utils.uniform([3, 28 * 28]) * 1e-3
+      module.predict(x)
+
+    self.compare_backends(dynamic_batch)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index 72d7f1f..0b70e84 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -62,13 +62,15 @@
 
 
 @tf_test_utils.compile_module(Mlp, exported_names=["predict"])
-class DynamicMlpTest(tf_test_utils.CompiledModuleTestCase):
+class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
 
   def test_dynamic_batch(self):
-    m = self.get_module()
-    np.random.seed(12345)
-    x = np.random.random([3, 28 * 28]).astype(np.float32) * 1e-3
-    m.predict(x).print().assert_all_close()
+
+    def dynamic_batch(module):
+      x = tf_utils.uniform([3, 28 * 28]) * 1e-3
+      module.predict(x)
+
+    self.compare_backends(dynamic_batch)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/explicit_backend_test.py b/integrations/tensorflow/e2e/explicit_backend_test.py
deleted file mode 100644
index bdcdd79..0000000
--- a/integrations/tensorflow/e2e/explicit_backend_test.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Lint as: python3
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests explicitly specifying a backend in Python."""
-
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class SimpleArithmeticModule(tf.Module):
-
-  @tf.function(input_signature=[
-      tf.TensorSpec([4], tf.float32),
-      tf.TensorSpec([4], tf.float32)
-  ])
-  def simple_mul(self, a, b):
-    return a * b
-
-
-@tf_test_utils.compile_module(SimpleArithmeticModule)
-class ExplicitBackendTest(tf_test_utils.CompiledModuleTestCase):
-
-  def test_explicit(self):
-    a = np.array([1., 2., 3., 4.], dtype=np.float32)
-    b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
-    # Demonstrates simple, one by one invocation of functions against
-    # different explicit backends. Individual backends can be accessed off of
-    # the module by name ('tf', 'iree_vmla' below).
-    tf_c = self.compiled_modules.tf.simple_mul(a, b)
-    print("TF Result:", tf_c)
-    iree_c = self.compiled_modules.iree_vmla.simple_mul(a, b)
-    print("IREE Result:", iree_c)
-    self.assertAllClose(tf_c, iree_c)
-
-  def test_multi(self):
-    a = np.array([1., 2., 3., 4.], dtype=np.float32)
-    b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
-    # Evaluating against multiple backends can be done with the multi() method,
-    # which takes a regex string matching backend names. This also returns a
-    # MultiResults tuple with actual results keyed by backend name. These also
-    # have convenience methods like print() and assert_all_close().
-    vmod = self.compiled_modules.multi("tf|iree")
-    r = vmod.simple_mul(a, b)
-    r.print().assert_all_close()
-
-  def test_get_module(self):
-    a = np.array([1., 2., 3., 4.], dtype=np.float32)
-    b = np.array([400., 5., 6., 7.], dtype=np.float32)
-
-    # Evaluating against all backends can be done with self.get_module(). This
-    # also returns a MultiResults tuple with actual results keyed by backend
-    # name.
-    r = self.get_module().simple_mul(a, b)
-    r.print().assert_all_close()
-
-
-if __name__ == "__main__":
-  if hasattr(tf, "enable_v2_behavior"):
-    tf.enable_v2_behavior()
-  tf.test.main()
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 8d912a4..050eb47 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -31,14 +31,16 @@
 
 
 @tf_test_utils.compile_module(FillModule)
-class FillTest(tf_test_utils.CompiledModuleTestCase):
+class FillTest(tf_test_utils.TracedModuleTestCase):
 
   def test_fill(self):
-    dims = np.array([2, 3], dtype=np.int32)
-    value = np.array(9., dtype=np.float32)
 
-    result = self.get_module().fill(dims, value)
-    result.assert_all_close()
+    def fill(module):
+      dims = np.array([2, 3], dtype=np.int32)
+      value = np.array(9., dtype=np.float32)
+      module.fill(dims, value)
+
+    self.compare_backends(fill)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index 67f5acf..7ae7af6 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -49,31 +50,43 @@
 
 
 @tf_test_utils.compile_module(GatherModule)
-class GatherTest(tf_test_utils.CompiledModuleTestCase):
+class GatherTest(tf_test_utils.TracedModuleTestCase):
 
   def test_gather_axis0_scalar(self):
-    indices = np.array(2, dtype=np.int32)
-    params = np.arange(32, dtype=np.float32).reshape(4, 8)
-    result = self.get_module().gather_axis0_scalar(params, indices)
-    result.print().assert_all_close()
+
+    def gather_axis0_scalar(module):
+      indices = np.array(2, dtype=np.int32)
+      params = tf_utils.ndarange([4, 8])
+      module.gather_axis0_scalar(params, indices)
+
+    self.compare_backends(gather_axis0_scalar)
 
   def test_gather_axis0_batch0(self):
-    indices = np.array([2, 3], dtype=np.int32)
-    params = np.arange(32, dtype=np.float32).reshape(4, 8)
-    result = self.get_module().gather_axis0_batch0(params, indices)
-    result.print().assert_all_close()
+
+    def gather_axis0_batch0(module):
+      indices = np.array([2, 3], dtype=np.int32)
+      params = tf_utils.ndarange([4, 8])
+      module.gather_axis0_batch0(params, indices)
+
+    self.compare_backends(gather_axis0_batch0)
 
   def test_gather_axis1_batch0(self):
-    indices = np.array([2, 3], dtype=np.int32)
-    params = np.arange(4 * 7 * 8, dtype=np.float32).reshape(4, 7, 8)
-    result = self.get_module().gather_axis1_batch0(params, indices)
-    result.print().assert_all_close()
+
+    def gather_axis1_batch0(module):
+      indices = np.array([2, 3], dtype=np.int32)
+      params = tf_utils.ndarange([4, 7, 8])
+      module.gather_axis1_batch0(params, indices)
+
+    self.compare_backends(gather_axis1_batch0)
 
   def test_gather_axis2_batch1(self):
-    indices = np.array([[2], [3], [0], [1]], dtype=np.int32)
-    params = np.arange(4 * 7 * 8 * 2, dtype=np.float32).reshape(4, 7, 8, 2)
-    result = self.get_module().gather_axis2_batch1(params, indices)
-    result.print().assert_all_close()
+
+    def gather_axis2_batch1(module):
+      indices = np.array([[2], [3], [0], [1]], dtype=np.int32)
+      params = tf_utils.ndarange([4, 7, 8, 2])
+      module.gather_axis2_batch1(params, indices)
+
+    self.compare_backends(gather_axis2_batch1)
 
 
 if __name__ == "__main__":
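The gather tests above replace hand-rolled np.arange(...).reshape(...) inputs with tf_utils.ndarange. The helper itself is not part of this diff; the sketch below is only an assumed reconstruction inferred from the call sites (e.g. np.arange(32, dtype=np.float32).reshape(4, 8) becoming tf_utils.ndarange([4, 8])).

import numpy as np


def ndarange(shape, dtype=np.float32):
  # Assumed behavior: a dense 0..N-1 range reshaped to `shape`.
  return np.arange(np.prod(shape), dtype=dtype).reshape(shape)


assert np.array_equal(
    ndarange([4, 8]), np.arange(32, dtype=np.float32).reshape(4, 8))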
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index bd80fbc..60bbff4 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -55,7 +55,8 @@
                 backend,
             )
             args = [
-                "--target_backends={},{}".format(reference_backend, backend),
+                "--reference_backend={}".format(reference_backend),
+                "--target_backends={}".format(backend),
             ]
 
             # TODO(GH-2175): Simplify this after backend names are standardized.
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 16ddde9..7d1f1d3 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -185,11 +185,21 @@
         # interpreter and vulkan backends for these tests.
         {
             "models": [
-                "MobileNet",
                 "MobileNetV2",
             ],
             "datasets": [
                 "cifar10",
+            ],
+            "backends": [
+                "iree_vulkan",
+            ],
+        },
+        {
+            "models": [
+                "MobileNet",
+                "MobileNetV2",
+            ],
+            "datasets": [
                 "imagenet",
             ],
             "backends": [
diff --git a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
index a8d4f29..d286948 100644
--- a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
+++ b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
@@ -107,7 +107,8 @@
                     "--model={}".format(model),
                     "--data={}".format(dataset),
                     "--include_top=1",
-                    "--target_backends={},{}".format(reference_backend, backend),
+                    "--reference_backend={}".format(reference_backend),
+                    "--target_backends={}".format(backend),
                 ]
                 if external_weights:
                     args.append("--url={}".format(external_weights))
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index fb7a58c..64db229 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -21,9 +21,9 @@
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
-NUM_UNITS = 10
-NUM_TIMESTEPS = 24
 NUM_BATCH = 7
+NUM_TIMESTEPS = 24
+NUM_UNITS = 10
 INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
 
 
@@ -43,17 +43,15 @@
 
 
 @tf_test_utils.compile_module(LstmStatic, exported_names=["predict"])
-class LstmTest(tf_test_utils.CompiledModuleTestCase):
+class LstmStaticTest(tf_test_utils.TracedModuleTestCase):
 
   def test_lstm(self):
-    m = self.get_module()
-    m.predict(
-        tf.constant(
-            np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
-                      dtype=np.float32).reshape(
-                          [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]),
-            shape=[NUM_BATCH, NUM_TIMESTEPS,
-                   NUM_UNITS])).print().assert_all_close(1e-5, 1e-5)
+
+    def predict(module):
+      inputs = tf_utils.ndarange(INPUT_SHAPE)
+      module.predict(inputs, rtol=1e-5, atol=1e-5)
+
+    self.compare_backends(predict)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 9409d04..8232d19 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -18,10 +18,11 @@
 from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
-NUM_UNITS = 10
-NUM_TIMESTEPS = 24
 NUM_BATCH = 7
-INPUT_SHAPE = [None, None, NUM_UNITS]
+NUM_TIMESTEPS = 24
+NUM_UNITS = 10
+DYNAMIC_SHAPE = [None, None, NUM_UNITS]
+INPUT_SHAPE = [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]
 
 
 class Lstm(tf.Module):
@@ -29,28 +30,26 @@
   def __init__(self):
     super(Lstm, self).__init__()
     tf_utils.set_random_seed()
-    inputs = tf.keras.layers.Input(batch_size=None, shape=INPUT_SHAPE[1:])
+    inputs = tf.keras.layers.Input(batch_size=None, shape=DYNAMIC_SHAPE[1:])
     outputs = tf.keras.layers.LSTM(
         units=NUM_UNITS, return_sequences=True)(
             inputs)
     self.m = tf.keras.Model(inputs, outputs)
     self.predict = tf.function(
-        input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])(
+        input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(
             self.m.call)
 
 
 @tf_test_utils.compile_module(Lstm, exported_names=["predict"])
-class LstmTest(tf_test_utils.CompiledModuleTestCase):
+class LstmTest(tf_test_utils.TracedModuleTestCase):
 
   def test_lstm(self):
-    m = self.get_module()
-    m.predict(
-        tf.constant(
-            np.arange(NUM_BATCH * NUM_TIMESTEPS * NUM_UNITS,
-                      dtype=np.float32).reshape(
-                          [NUM_BATCH, NUM_TIMESTEPS, NUM_UNITS]),
-            shape=[NUM_BATCH, NUM_TIMESTEPS,
-                   NUM_UNITS])).print().assert_all_close()
+
+    def predict(module):
+      inputs = tf_utils.ndarange(INPUT_SHAPE)
+      module.predict(inputs)
+
+    self.compare_backends(predict)
 
 
 if __name__ == "__main__":
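The dynamic LSTM test above now distinguishes DYNAMIC_SHAPE (used for the tf.function signature, with None batch and time dimensions) from INPUT_SHAPE (the concrete data fed by the trace). As a small illustration, using only stock TensorFlow, of how a concrete input satisfies such a dynamic signature:

import tensorflow.compat.v2 as tf

dynamic_spec = tf.TensorSpec([None, None, 10], tf.float32)
concrete_input = tf.zeros([7, 24, 10])
# None dimensions accept any size; fixed dimensions must match exactly.
assert dynamic_spec.is_compatible_with(concrete_input)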
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index e30bd57..75286ad 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -65,7 +65,7 @@
       tf.TensorSpec(_INPUT_DATA_SHAPE, tf.float32),
       tf.TensorSpec(_OUTPUT_DATA_SHAPE, tf.float32)
   ])
-  def TrainStep(self, inputs, targets):
+  def train_step(self, inputs, targets):
     with tf.GradientTape() as tape:
       predictions = self.model(inputs, training=True)
       loss_value = self.loss(predictions, targets)
@@ -76,8 +76,8 @@
 
 
 @tf_test_utils.compile_module(
-    ModelTrain.CreateModule, exported_names=["TrainStep"])
-class ModelTrainTest(tf_test_utils.CompiledModuleTestCase):
+    ModelTrain.CreateModule, exported_names=["train_step"])
+class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
 
   def generate_regression_data(self, size=8):
     x = np.arange(size) - size // 2
@@ -86,22 +86,25 @@
 
   def test_model_train(self):
 
-    # generate input and output data for regression problem
+    # Generate input and output data for regression problem.
     inputs, targets = self.generate_regression_data()
 
-    # normalize data
+    # Normalize data.
     inputs = inputs / max(inputs)
     targets = targets / max(targets)
 
-    # generate polynomial features
+    # Generate polynomial features.
     inputs = np.expand_dims(inputs, axis=1)
     polynomial = PolynomialFeatures(_DEGREE)  # returns: [1, a, b, a^2, ab, b^2]
     inputs = polynomial.fit_transform(inputs)
 
     targets = np.expand_dims(targets, axis=1)
-    # run one iteration of training step
-    result = self.get_module().TrainStep(inputs, targets)
-    result.print().assert_all_close()
+
+    def train_step(module):
+      # Run one iteration of training step.
+      module.train_step(inputs, targets)
+
+    self.compare_backends(train_step)
 
 
 if __name__ == "__main__":
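For reference on the "Generate polynomial features" step above, here is an illustrative standalone use of sklearn's PolynomialFeatures on a single normalized column. The test's _DEGREE constant is defined earlier in the file and is not shown in this hunk; degree 3 below is an arbitrary value chosen only for illustration.

import numpy as np
from sklearn.preprocessing import PolynomialFeatures

x = np.linspace(-1.0, 1.0, 4).reshape(-1, 1)  # one normalized feature column
features = PolynomialFeatures(3).fit_transform(x)
# For a single input column the expansion is [1, x, x^2, x^3].
print(features.shape)  # (4, 4)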
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 54374ff..7c4b90e 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Test all applications models in Keras."""
+
 import os
 
 from absl import app
@@ -133,11 +134,14 @@
 
 
 @tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
-class AppTest(tf_test_utils.CompiledModuleTestCase):
+class AppTest(tf_test_utils.TracedModuleTestCase):
 
   def test_application(self):
-    input_data = np.random.rand(*get_input_shape()).astype(np.float32)
-    self.get_module().predict(input_data).print().assert_all_close(atol=1e-6)
+
+    def predict(module):
+      module.predict(tf_utils.uniform(get_input_shape()))
+
+    self.compare_backends(predict)
 
 
 def main(argv):
@@ -148,7 +152,7 @@
   if FLAGS.model not in APP_MODELS:
     raise ValueError(f'Unsupported model: {FLAGS.model}')
   # Override VisionModule's __name__ to be more specific.
-  VisionModule.__name__ = FLAGS.model
+  VisionModule.__name__ = os.path.join(FLAGS.model, FLAGS.data)
 
   tf.test.main()
 
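The vision test now samples its input with tf_utils.uniform instead of np.random.rand. As with ndarange earlier, the helper's implementation is not part of this diff; the sketch below is only an assumed equivalent inferred from the call site (a float32 uniform sample of the requested shape; the real helper presumably also honors tf_utils.set_random_seed()).

import numpy as np


def uniform(shape, dtype=np.float32):
  # Assumed behavior: uniform [0, 1) samples of `shape`, cast to float32.
  return np.random.uniform(size=shape).astype(dtype)


sample = uniform([1, 224, 224, 3])  # example shape only
print(sample.shape, sample.dtype)  # (1, 224, 224, 3) float32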
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index aa49f5b..b535021 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -34,14 +34,16 @@
 
 
 @tf_test_utils.compile_module(LinSpaceModule)
-class LinspaceTest(tf_test_utils.CompiledModuleTestCase):
+class LinspaceTest(tf_test_utils.TracedModuleTestCase):
 
   def test_linspace(self):
-    start = np.array(10., dtype=np.float32)
-    stop = np.array(12., dtype=np.float32)
 
-    result = self.get_module().linspace(start, stop)
-    result.assert_all_close()
+    def linspace(module):
+      start = np.array(10., dtype=np.float32)
+      stop = np.array(12., dtype=np.float32)
+      module.linspace(start, stop)
+
+    self.compare_backends(linspace)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index 2b3a8d9..b5a3929 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -91,19 +91,17 @@
 
 
 @tf_test_utils.compile_module(MandelbrotModule)
-class MandelbrotTest(tf_test_utils.CompiledModuleTestCase):
+class MandelbrotTest(tf_test_utils.TracedModuleTestCase):
 
   def test_mandelbrot(self):
-    mandelbrot = self.get_module()
 
-    # Basic view of the entire set.
-    pixels = mandelbrot.calculate(-0.7, 0.0, 3.0, 400, 100)
-    pixels.assert_all_close()
+    def mandelbrot(module):
+      # Basic view of the entire set.
+      module.calculate(-0.7, 0.0, 3.0, 400, 100)
+      # This is a much more detailed view, so more iterations are needed.
+      module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000)
 
-    # This is a much more detailed view, so more iterations are needed.
-    pixels = mandelbrot.calculate(-0.7436447860, 0.1318252536, 0.0000029336,
-                                  400, 3000)
-    pixels.assert_all_close()
+    self.compare_backends(mandelbrot)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index a33ac7c..3f1c8be 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -39,27 +39,35 @@
 
 
 @tf_test_utils.compile_module(MathModule)
-class MathTest(tf_test_utils.CompiledModuleTestCase):
+class MathTest(tf_test_utils.TracedModuleTestCase):
 
   def test_abs(self):
-    a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
-    r = self.get_module().abs(a)
-    r.print().assert_all_close()
+
+    def abs(module):
+      module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
+
+    self.compare_backends(abs)
 
   def test_cos(self):
-    a = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
-    r = self.get_module().cos(a)
-    r.print().assert_all_close()
+
+    def cos(module):
+      module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
+
+    self.compare_backends(cos)
 
   def test_log(self):
-    a = np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32)
-    r = self.get_module().log(a)
-    r.print().assert_all_close()
+
+    def log(module):
+      module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
+
+    self.compare_backends(log)
 
   def test_mod(self):
-    a = np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)
-    r = self.get_module().mod(a)
-    r.print().assert_all_close()
+
+    def mod(module):
+      module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
+
+    self.compare_backends(mod)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/matrix_ops_test.py b/integrations/tensorflow/e2e/matrix_ops_test.py
index d04ce3a..c814397 100644
--- a/integrations/tensorflow/e2e/matrix_ops_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_test.py
@@ -15,6 +15,7 @@
 """Test matrix ops."""
 
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -71,60 +72,78 @@
 
 
 @tf_test_utils.compile_module(MatrixOpsModule)
-class MatrixOpsTest(tf_test_utils.CompiledModuleTestCase):
+class MatrixOpsTest(tf_test_utils.TracedModuleTestCase):
 
   def test_basic_matmul(self):
-    m = self.get_module()
-    dst = m.basic_matmul(tf.random.uniform([4, 2]), tf.random.uniform([2, 4]))
-    dst.assert_all_close()
+
+    def basic_matmul(module):
+      module.basic_matmul(tf_utils.uniform([4, 2]), tf_utils.uniform([2, 4]))
+
+    self.compare_backends(basic_matmul)
 
   def test_matmul_lhs_batch(self):
-    m = self.get_module()
-    dst = m.matmul_lhs_batch(
-        tf.random.uniform([3, 4, 2]), tf.random.uniform([2, 4]))
-    dst.assert_all_close()
+
+    def matmul_lhs_batch(module):
+      module.matmul_lhs_batch(
+          tf_utils.uniform([3, 4, 2]), tf_utils.uniform([2, 4]))
+
+    self.compare_backends(matmul_lhs_batch)
 
   def test_matmul_rhs_batch(self):
-    m = self.get_module()
-    dst = m.matmul_rhs_batch(
-        tf.random.uniform([4, 2]), tf.random.uniform([3, 2, 4]))
-    dst.assert_all_close()
+
+    def matmul_rhs_batch(module):
+      module.matmul_rhs_batch(
+          tf_utils.uniform([4, 2]), tf_utils.uniform([3, 2, 4]))
+
+    self.compare_backends(matmul_rhs_batch)
 
   def test_matmul_broadcast_singleton_dimension(self):
-    m = self.get_module()
-    dst = m.matmul_broadcast_singleton_dimension(
-        tf.random.uniform([1, 4, 2]), tf.random.uniform([3, 2, 4]))
-    dst.assert_all_close()
+
+    def matmul_broadcast_singleton_dimension(module):
+      module.matmul_broadcast_singleton_dimension(
+          tf_utils.uniform([1, 4, 2]), tf_utils.uniform([3, 2, 4]))
+
+    self.compare_backends(matmul_broadcast_singleton_dimension)
 
   def test_matmul_high_rank_batch(self):
-    m = self.get_module()
-    dst = m.matmul_high_rank_batch(
-        tf.random.uniform([1, 7, 4, 2]), tf.random.uniform([7, 1, 2, 4]))
-    dst.assert_all_close()
+
+    def matmul_high_rank_batch(module):
+      module.matmul_high_rank_batch(
+          tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]))
+
+    self.compare_backends(matmul_high_rank_batch)
 
   def test_matmul_dynamic_matching_batch(self):
-    m = self.get_module()
-    dst = m.matmul_dynamic(
-        tf.random.uniform([2, 2, 3]), tf.random.uniform([2, 3, 4]))
-    dst.assert_all_close()
+
+    def matmul_dynamic_matching_batch(module):
+      module.matmul_dynamic(
+          tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]))
+
+    self.compare_backends(matmul_dynamic_matching_batch)
 
   def test_matmul_dynamic_broadcast_lhs(self):
-    m = self.get_module()
-    dst = m.matmul_dynamic(
-        tf.random.uniform([1, 2, 3]), tf.random.uniform([2, 3, 4]))
-    dst.assert_all_close()
+
+    def matmul_dynamic_broadcast_lhs(module):
+      module.matmul_dynamic(
+          tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]))
+
+    self.compare_backends(matmul_dynamic_broadcast_lhs)
 
   def test_matmul_dynamic_broadcast_rhs(self):
-    m = self.get_module()
-    dst = m.matmul_dynamic(
-        tf.random.uniform([2, 2, 3]), tf.random.uniform([1, 3, 4]))
-    dst.assert_all_close()
+
+    def matmul_dynamic_broadcast_rhs(module):
+      module.matmul_dynamic(
+          tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]))
+
+    self.compare_backends(matmul_dynamic_broadcast_rhs)
 
   def test_matmul_dynamic_rank_broadcasting(self):
-    m = self.get_module()
-    dst = m.matmul_dynamic_lhs_batch(
-        tf.random.uniform([7, 2, 3]), tf.random.uniform([3, 4]))
-    dst.assert_all_close()
+
+    def matmul_dynamic_rank_broadcasting(module):
+      module.matmul_dynamic_lhs_batch(
+          tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]))
+
+    self.compare_backends(matmul_dynamic_rank_broadcasting)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index 8daa6cf..dd5ad6d 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -29,11 +29,14 @@
 
 
 @tf_test_utils.compile_module(ResourcesOpsModule)
-class ResourcesOpsTest(tf_test_utils.CompiledModuleTestCase):
+class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
 
   def test_add_assign(self):
-    result = self.get_module().add_assign(np.array(9., dtype=np.float32))
-    result.assert_all_close()
+
+    def add_assign(module):
+      module.add_assign(np.array(9., dtype=np.float32))
+
+    self.compare_backends(add_assign)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 3af1502..8e437ea 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -179,27 +179,26 @@
 
 @tf_test_utils.compile_module(
     StatefulRingBufferModule, exported_names=["predict"])
-class StatefulRingBufferTest(tf_test_utils.CompiledModuleTestCase):
+class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase):
 
   def test_stateful_ringbuffer(self):
-    input1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    result1 = self.get_module().predict(input1)
-    output1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    assert np.allclose(result1, output1)
 
-    # ring buffer is not filled yet,
-    # so data from first cycle will be returned
-    input2 = np.array([[3.0, 4.0]], dtype=np.float32)
-    result2 = self.get_module().predict(input2)
-    output2 = np.array([[1.0, 2.0]], dtype=np.float32)
-    assert np.allclose(result2, output2)
+    def stateful_ringbuffer(module):
+      input1 = np.array([[1.0, 2.0]], dtype=np.float32)
+      module.predict(input1)
+      # output = np.array([[1.0, 2.0]], dtype=np.float32)
 
-    # on 3rd cycle we overwrite oldest data
-    # and return data from 2nd cycle
-    input3 = np.array([[5.0, 6.0]], dtype=np.float32)
-    result3 = self.get_module().predict(input3)
-    output3 = np.array([[3.0, 4.0]], dtype=np.float32)
-    assert np.allclose(result3, output3)
+      # The ring buffer is not full yet, so first-cycle data is returned.
+      input2 = np.array([[3.0, 4.0]], dtype=np.float32)
+      module.predict(input2)
+      # output = np.array([[1.0, 2.0]], dtype=np.float32)
+
+      # On the 3rd cycle we overwrite the oldest data and return 2nd-cycle data.
+      input3 = np.array([[5.0, 6.0]], dtype=np.float32)
+      module.predict(input3)
+      # output = np.array([[3.0, 4.0]], dtype=np.float32)
+
+    self.compare_backends(stateful_ringbuffer)
 
 
 if __name__ == "__main__":
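The commented expectations above describe the module's ring-buffer behavior: until the buffer is full the oldest (first-cycle) data is returned, and once full the oldest entry is overwritten so the next cycle's data is returned. A toy sketch of that behavior with a plain deque, assuming (for the toy only) a capacity of two entries; this is not the module under test:

import collections

buf = collections.deque(maxlen=2)  # toy ring buffer of capacity 2
for step_input in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
  buf.append(step_input)
  print(buf[0])  # the oldest entry currently held
# Prints [1.0, 2.0], [1.0, 2.0], [3.0, 4.0], matching the commented outputs.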
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index ab5ab91..8b43e3a 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -49,28 +49,37 @@
 
 
 @tf_test_utils.compile_module(ScatterUpdateModule)
-class ScatterUpdateTest(tf_test_utils.CompiledModuleTestCase):
+class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase):
 
   def test_scatter_update_1D(self):
-    tensor = tf.ones([8], dtype=tf.int32)
-    indices = tf.constant([[4], [5], [6]])
-    updates = tf.constant([9, 10, 11])
-    result = self.get_module().scatter_update_1D(tensor, indices, updates)
-    result.assert_all_close()
+
+    def scatter_update_1D(module):
+      tensor = np.ones([8], dtype=np.int32)
+      indices = np.array([[4], [5], [6]], dtype=np.int32)
+      updates = np.array([9, 10, 11], dtype=np.int32)
+      module.scatter_update_1D(tensor, indices, updates)
+
+    self.compare_backends(scatter_update_1D)
 
   def test_scatter_update_2D(self):
-    tensor = tf.ones([4, 3], dtype=tf.int32)
-    indices = tf.constant([[1, 0], [2, 1], [3, 2]])
-    updates = tf.constant([2, 5, 8])
-    result = self.get_module().scatter_update_2D(tensor, indices, updates)
-    result.assert_all_close()
+
+    def scatter_update_2D(module):
+      tensor = np.ones([4, 3], dtype=np.int32)
+      indices = np.array([[1, 0], [2, 1], [3, 2]], dtype=np.int32)
+      updates = np.array([2, 5, 8], dtype=np.int32)
+      module.scatter_update_2D(tensor, indices, updates)
+
+    self.compare_backends(scatter_update_2D)
 
   def test_scatter_update_2D_slice(self):
-    tensor = tf.ones([4, 3], dtype=tf.int32)
-    indices = tf.constant([[1]])
-    updates = tf.constant([[2, 3, 4]])
-    result = self.get_module().scatter_update_2D_slice(tensor, indices, updates)
-    result.assert_all_close()
+
+    def scatter_update_2D_slice(module):
+      tensor = np.ones([4, 3], dtype=np.int32)
+      indices = np.array([[1]], dtype=np.int32)
+      updates = np.array([[2, 3, 4]], dtype=np.int32)
+      module.scatter_update_2D_slice(tensor, indices, updates)
+
+    self.compare_backends(scatter_update_2D_slice)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index d3ea327..aaec578 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 
@@ -37,21 +38,27 @@
 
 
 @tf_test_utils.compile_module(SimpleArithmeticModule)
-class SimpleArithmeticTest(tf_test_utils.CompiledModuleTestCase):
+class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
 
   def test_simple_mul(self):
-    a = np.array([1., 2., 3., 4.], dtype=np.float32)
-    b = np.array([400., 5., 6., 7.], dtype=np.float32)
-    r = self.get_module().simple_mul(a, b)
-    r.print().assert_all_close()
+
+    def simple_mul(module):
+      a = np.array([1., 2., 3., 4.], dtype=np.float32)
+      b = np.array([400., 5., 6., 7.], dtype=np.float32)
+      c = module.simple_mul(a, b)
+      module.simple_mul(a, c)
+
+    self.compare_backends(simple_mul)
 
   def test_simple_matmul(self):
-    np.random.seed(12345)
-    # Note: scaling by a small value to increase numerical stability.
-    a = np.random.random((128, 3072)).astype(np.float32) * 1e-3
-    b = np.random.random((3072, 256)).astype(np.float32) * 1e-3
-    r = self.get_module().simple_matmul(a, b)
-    r.print().assert_all_close()
+
+    def simple_matmul(module):
+      # Note: scaling by a small value to increase numerical stability.
+      a = tf_utils.uniform((128, 3072)) * 1e-3
+      b = tf_utils.uniform((3072, 256)) * 1e-3
+      module.simple_matmul(a, b)
+
+    self.compare_backends(simple_matmul)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 24dd23e..eff49a8 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 from pyiree.tf.support import tf_test_utils
 import tensorflow.compat.v2 as tf
 
@@ -33,12 +34,15 @@
 
 
 @tf_test_utils.compile_module(Stateful)
-class StatefulTest(tf_test_utils.CompiledModuleTestCase):
+class StatefulTest(tf_test_utils.TracedModuleTestCase):
 
   def test_stateful(self):
-    m = self.get_module()
-    m.inc_by(tf.constant(1.))
-    m.get_state().print().assert_all_close()
+
+    def get_state(module):
+      module.inc_by(np.array(1., dtype=np.float32))
+      module.get_state()
+
+    self.compare_backends(get_state)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index f206d86..513aa97 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -76,18 +76,20 @@
 
 
 @tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
-class SlidingWindowTest(tf_test_utils.CompiledModuleTestCase):
+class SlidingWindowTest(tf_test_utils.TracedModuleTestCase):
 
-  def test_slidingwindow(self):
-    input1 = np.array([[1.0, 2.0]], dtype=np.float32)
-    result1 = self.get_module().predict(input1)
-    output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32)
-    assert np.allclose(result1, output1)
+  def test_sliding_window(self):
 
-    input2 = np.array([[3.0, 4.0]], dtype=np.float32)
-    result2 = self.get_module().predict(input2)
-    output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
-    assert np.allclose(result2, output2)
+    def sliding_window(module):
+      input1 = np.array([[1.0, 2.0]], dtype=np.float32)
+      result1 = module.predict(input1)
+      # output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32)
+
+      input2 = np.array([[3.0, 4.0]], dtype=np.float32)
+      result2 = module.predict(input2)
+      # output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+
+    self.compare_backends(sliding_window)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index ce0787e..206b33c 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -41,20 +41,27 @@
 
 
 @tf_test_utils.compile_module(StringsModule)
-class StringsTest(tf_test_utils.CompiledModuleTestCase):
+class StringsTest(tf_test_utils.TracedModuleTestCase):
 
   def test_print_ids(self):
-    input_ids = np.asarray(
-        [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
-         [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
-    self.get_module().print_ids(input_ids)
+
+    def print_ids(module):
+      input_ids = np.asarray(
+          [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
+           [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
+      module.print_ids(input_ids)
+
+    self.compare_backends(print_ids)
 
   def test_strings_to_ids(self):
-    input_ids = np.asarray(
-        [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
-         [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
-    result = self.get_module().strings_to_ids(input_ids)
-    result.assert_all_equal()
+
+    def strings_to_ids(module):
+      input_ids = np.asarray(
+          [[12, 10, 29, 28, 94, 15, 24, 27, 94, 25, 21, 10, 34],
+           [13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
+      module.strings_to_ids(input_ids)
+
+    self.compare_backends(strings_to_ids)
 
 
 if __name__ == "__main__":
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 9b1330c..440bd43 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
 import tensorflow.compat.v2 as tf
 
 STATIC_SIZE = 20
@@ -65,36 +67,46 @@
 
 
 @tf_test_utils.compile_module(TensorListModule)
-class TensorListTest(tf_test_utils.CompiledModuleTestCase):
+class TensorListTest(tf_test_utils.TracedModuleTestCase):
 
   def test_identity_through_tensorlist(self):
-    m = self.get_module()
-    result = m.identity_through_tensorlist(tf.constant(42.))
-    result.print().assert_all_close()
+
+    def identity_through_tensorlist(module):
+      module.identity_through_tensorlist(np.array(42., dtype=np.float32))
+
+    self.compare_backends(identity_through_tensorlist)
 
   def test_add_through_tensorlist(self):
-    m = self.get_module()
-    result = m.add_through_tensorlist(tf.constant(42.), tf.constant(43.))
-    result.print().assert_all_close()
+
+    def add_through_tensorlist(module):
+      module.add_through_tensorlist(
+          np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+
+    self.compare_backends(add_through_tensorlist)
 
   def test_slice_first_element_with_from_tensor(self):
-    m = self.get_module()
-    result = m.slice_first_element_with_from_tensor(
-        tf.range(STATIC_SIZE, dtype=tf.float32))
-    result.print().assert_all_close()
+
+    def slice_first_element_with_from_tensor(module):
+      module.slice_first_element_with_from_tensor(
+          np.arange(STATIC_SIZE, dtype=np.float32))
+
+    self.compare_backends(slice_first_element_with_from_tensor)
 
   def test_slice_first_element_with_from_tensor_high_rank(self):
-    m = self.get_module()
-    result = m.slice_first_element_with_from_tensor_high_rank(
-        tf.broadcast_to(
-            tf.range(STATIC_SIZE, dtype=tf.float32),
-            [STATIC_SIZE, STATIC_SIZE]))
-    result.print().assert_all_close()
+
+    def slice_first_element_with_from_tensor_high_rank(module):
+      module.slice_first_element_with_from_tensor_high_rank(
+          tf_utils.ndarange([STATIC_SIZE, STATIC_SIZE]))
+
+    self.compare_backends(slice_first_element_with_from_tensor_high_rank)
 
   def test_concat_with_tensorlist_stack(self):
-    m = self.get_module()
-    result = m.concat_with_tensorlist_stack(tf.constant(42.), tf.constant(43.))
-    result.print().assert_all_close()
+
+    def concat_with_tensorlist_stack(module):
+      module.concat_with_tensorlist_stack(
+          np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+
+    self.compare_backends(concat_with_tensorlist_stack)
 
 
 if __name__ == "__main__":
diff --git a/iree/compiler/Conversion/BUILD b/iree/compiler/Conversion/BUILD
index 35f424e..8a297f8 100644
--- a/iree/compiler/Conversion/BUILD
+++ b/iree/compiler/Conversion/BUILD
@@ -24,6 +24,7 @@
         "init_conversions.h",
     ],
     deps = [
+        "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Conversion/HLOToLinalg",
         "//iree/compiler/Conversion/LinalgToLLVM",
         "//iree/compiler/Conversion/LinalgToSPIRV",
diff --git a/iree/compiler/Conversion/CMakeLists.txt b/iree/compiler/Conversion/CMakeLists.txt
index 6a022f2..5fa1478 100644
--- a/iree/compiler/Conversion/CMakeLists.txt
+++ b/iree/compiler/Conversion/CMakeLists.txt
@@ -20,6 +20,7 @@
   HDRS
     "init_conversions.h"
   DEPS
+    iree::compiler::Conversion::HLOToHLO
     iree::compiler::Conversion::HLOToLinalg
     iree::compiler::Conversion::LinalgToLLVM
     iree::compiler::Conversion::LinalgToSPIRV
diff --git a/iree/compiler/Conversion/HLOToHLO/BUILD b/iree/compiler/Conversion/HLOToHLO/BUILD
new file mode 100644
index 0000000..685e835
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/BUILD
@@ -0,0 +1,38 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "HLOToHLO",
+    srcs = [
+        "DecomposeHLOClamp.cpp",
+        "Passes.cpp",
+    ],
+    hdrs = [
+        "Passes.h",
+    ],
+    deps = [
+        "@llvm-project//mlir:CFGTransforms",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo:legalize_gather_to_torch_index_select",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo:mhlo_to_mhlo_lowering_patterns",
+    ],
+)
diff --git a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
new file mode 100644
index 0000000..f94bbb5
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    HLOToHLO
+  HDRS
+    "Passes.h"
+  SRCS
+    "DecomposeHLOClamp.cpp"
+    "Passes.cpp"
+  DEPS
+    MLIRIR
+    MLIRPass
+    MLIRSCFToStandard
+    tensorflow::mlir_hlo
+  PUBLIC
+)
diff --git a/iree/compiler/Conversion/HLOToLinalg/DecomposeHLOClamp.cpp b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
similarity index 100%
rename from iree/compiler/Conversion/HLOToLinalg/DecomposeHLOClamp.cpp
rename to iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.cpp b/iree/compiler/Conversion/HLOToHLO/Passes.cpp
new file mode 100644
index 0000000..4bb2a8f
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.cpp
@@ -0,0 +1,45 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/PassManager.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+struct ConvertHLOToCompatibleHLOPass
+    : public PassWrapper<ConvertHLOToCompatibleHLOPass, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList greedyPatterns;
+    mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
+    mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &greedyPatterns);
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), greedyPatterns))) {
+      return signalPassFailure();
+    }
+  }
+};
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createHLOToCompatibleHLOPass() {
+  return std::make_unique<ConvertHLOToCompatibleHLOPass>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.h b/iree/compiler/Conversion/HLOToHLO/Passes.h
new file mode 100644
index 0000000..4dc36de
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.h
@@ -0,0 +1,47 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Passes.h - Pass to convert from XLA-HLO to IREE supported XLA-HLO --===//
+//
+// IREE specific passes used to sanitize XLA-HLO to IREE compatible XLA-HLO.
+// Some examples may be raising gather operations to torch index select or
+// decomposing complex operations into real valued equivalents.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_HLOTOHLO_PASSES_H_
+#define IREE_COMPILER_CONVERSION_HLOTOHLO_PASSES_H_
+
+#include <memory>
+
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Creates a pass to decompose XLA-HLO clamp ops into primitive ops.
+std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
+
+/// Creates XLA-HLO to XLA-HLO transformation pass.
+std::unique_ptr<OperationPass<FuncOp>> createHLOToCompatibleHLOPass();
+
+/// Populates the patterns that convert from XLA-HLO to IREE compatible XLA-HLO.
+/// Imports patterns from the Tensorflow XLA passes.
+void populateHLOToCompatibleHLOPatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CONVERSION_HLOTOHLO_PASSES_H_
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index 2e7447b..89e790d 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -21,7 +21,6 @@
 cc_library(
     name = "HLOToLinalg",
     srcs = [
-        "DecomposeHLOClamp.cpp",
         "HLOToLinalgOnBuffers.cpp",
         "HLOToLinalgOnTensors.cpp",
         "Passes.cpp",
@@ -32,6 +31,7 @@
     ],
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
+        "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/HAL/IR:HALDialect",
         "//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
index bdfd4c9..26da6bd 100644
--- a/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToLinalg/CMakeLists.txt
@@ -20,7 +20,6 @@
   HDRS
     "Passes.h"
   SRCS
-    "DecomposeHLOClamp.cpp"
     "HLOToLinalgOnBuffers.cpp"
     "HLOToLinalgOnTensors.cpp"
     "Passes.cpp"
@@ -36,6 +35,7 @@
     MLIRSupport
     MLIRTransforms
     iree::compiler::Conversion::CodegenUtils
+    iree::compiler::Conversion::HLOToHLO
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::HAL::IR::HALDialect
     iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.cpp b/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
index 91432c3..e5638b0 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.cpp
@@ -14,6 +14,7 @@
 
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
@@ -22,6 +23,7 @@
 namespace iree_compiler {
 
 void addHLOToLinalgOnBuffersPasses(OpPassManager &pm) {
+  pm.addPass(createHLOToCompatibleHLOPass());
   pm.addPass(createHLOToLinalgOnTensorsPass());
   pm.addPass(createLinalgFusionOfTensorOpsPass());
   pm.addPass(createLinalgFoldUnitExtentDimsPass());
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.h b/iree/compiler/Conversion/HLOToLinalg/Passes.h
index 437b0a6..445bd51 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.h
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.h
@@ -27,9 +27,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-/// Crates a pass to decompose XLA-HLO clamp ops into primitive ops.
-std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
-
 /// Creates XLA-HLO to Linalg on buffers transformation pass.
 std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnBuffersPass();
 
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index c6cafc1..48d21e3 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -30,6 +30,7 @@
     ],
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
+        "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Conversion/HLOToLinalg",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/HAL/IR:HALDialect",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index fddc144..09e06fa 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -39,6 +39,7 @@
     MLIRVectorToLLVM
     MLIRVectorToSCF
     iree::compiler::Conversion::CodegenUtils
+    iree::compiler::Conversion::HLOToHLO
     iree::compiler::Conversion::HLOToLinalg
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::HAL::IR::HALDialect
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index 777fb02..fc476d0 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -144,8 +144,9 @@
  public:
   explicit ConvertFuncWithHALInterface(MLIRContext *context,
                                        LLVMTypeConverter &typeconverter)
-      : ConvertToLLVMPattern(FuncOp::getOperationName(), context,
-                             typeconverter) {}
+      : ConvertToLLVMPattern(
+            mlir::FuncOp::getOperationName(), context, typeconverter,
+            LowerToLLVMOptions::getDefaultOptions(), 65535 - 1) {}
 
   LogicalResult matchAndRewrite(
       Operation *op, ArrayRef<Value> operands,
@@ -345,8 +346,15 @@
                   RemoveInterfaceOpPattern>(&getContext(), converter);
   LLVMConversionTarget target(getContext());
   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-  if (failed(applyPartialConversion(module, target, patterns)))
+  target.addIllegalOp<IREE::PlaceholderOp>();
+  target.addDynamicallyLegalOp<FuncOp>([](FuncOp funcOp) {
+    bool any = false;
+    funcOp.walk([&](IREE::PlaceholderOp placeholderOp) { any = true; });
+    return any ? false : true;
+  });
+  if (failed(applyPartialConversion(module, target, patterns))) {
     signalPassFailure();
+  }
 }
 
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass() {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 8631cdf..c83a5d8 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -14,6 +14,7 @@
 
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index fc447e6..7a80343 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -41,6 +41,7 @@
     ],
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
+        "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Conversion/HLOToLinalg",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 69c04d9..7bca4c8 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -56,6 +56,7 @@
     MLIRTransforms
     MLIRVector
     iree::compiler::Conversion::CodegenUtils
+    iree::compiler::Conversion::HLOToHLO
     iree::compiler::Conversion::HLOToLinalg
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index ed2ab3a..0991416 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -20,6 +20,7 @@
 
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index b1bba5e..198a019 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -171,8 +171,7 @@
                                             separableOp.index()));
     StringRef newFnName = splitKernels.back();
     builder.setInsertionPointToStart(moduleOp.getBody());
-    auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType(),
-                                        /*attrs=*/ArrayRef<NamedAttribute>());
+    auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType());
 
     // Copy over all attributes except type and name.
     for (const auto &namedAttr : oldFn.getAttrs()) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index c16fd06..b8d7bfa 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -76,8 +76,8 @@
     if (VectorToGPUPattern<StdOp>::cooperativeMatrixAnalysis
             .usesCooperativeMatrixType(operation))
       return failure();
-    Value newOp = rewriter.create<StdOp>(
-        operation.getLoc(), ValueRange(operands), ArrayRef<NamedAttribute>{});
+    Value newOp =
+        rewriter.create<StdOp>(operation.getLoc(), ValueRange(operands));
     rewriter.replaceOp(operation, ValueRange(newOp));
     return success();
   }
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 259e3d5..7450f85 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -15,6 +15,7 @@
 #ifndef IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
 #define IREE_COMPILER_CONVERSION_INIT_CONVERSIONS_H_
 
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
@@ -26,6 +27,8 @@
 // expects all the possible conversions to be made available to the context
 // automatically.
 
+inline void registerHLOToHLOPasses() { createHLOToCompatibleHLOPass(); }
+
 inline void registerHLOToLinalgPasses() {
   createDecomposeHLOClampPass();
   createHLOToLinalgOnBuffersPass();
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
index 4c3f467..40053dd 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -42,8 +42,7 @@
   // provided symbols and it seems like it shouldn't be.
   auto uniqueName = (Twine("__") + variableOp.getName() + "_initializer").str();
   auto initializerFuncOp =
-      rewriter.create<FuncOp>(variableOp.getLoc(), uniqueName, initializerType,
-                              ArrayRef<NamedAttribute>{});
+      rewriter.create<FuncOp>(variableOp.getLoc(), uniqueName, initializerType);
   auto *entryBlock = initializerFuncOp.addEntryBlock();
   rewriter.setInsertionPointToEnd(entryBlock);
 
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
index 96bb5ac..128abd4 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRTarget.cpp
@@ -52,6 +52,9 @@
     // At this moment we are leaving MLIR LLVM dialect land translating module
     // into target independent LLVMIR.
     auto llvmModule = mlir::translateModuleToLLVMIR(targetOp.getInnerModule());
+    if (!llvmModule) {
+      return targetOp.emitError("Failed to translate executable to LLVM IR");
+    }
 
     // Create invocation function and populate entry_points.
     iree::LLVMIRExecutableDefT llvmIrExecutableDef;
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index e605418..1c1c9e4 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -123,8 +123,7 @@
     descriptorSetLayoutCache_.try_emplace(bindingsAttr, variableOp);
 
     auto initializerOp = moduleBuilder.create<FuncOp>(
-        loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}),
-        ArrayRef<NamedAttribute>{});
+        loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}));
     SymbolTable::setSymbolVisibility(initializerOp,
                                      SymbolTable::Visibility::Private);
     auto *block = initializerOp.addEntryBlock();
@@ -168,8 +167,7 @@
     executableLayoutCache_.try_emplace(setLayoutsArrayAttr, variableOp);
 
     auto initializerOp = moduleBuilder.create<FuncOp>(
-        loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}),
-        ArrayRef<NamedAttribute>{});
+        loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}));
     SymbolTable::setSymbolVisibility(initializerOp,
                                      SymbolTable::Visibility::Private);
     auto *block = initializerOp.addEntryBlock();
@@ -212,8 +210,7 @@
 
     auto initializerOp = moduleBuilder.create<FuncOp>(
         loc, initializerName,
-        moduleBuilder.getFunctionType({}, {executableCacheType}),
-        ArrayRef<NamedAttribute>{});
+        moduleBuilder.getFunctionType({}, {executableCacheType}));
     SymbolTable::setSymbolVisibility(initializerOp,
                                      SymbolTable::Visibility::Private);
     auto *block = initializerOp.addEntryBlock();
diff --git a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 83b1884..3e343eb 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -72,8 +72,7 @@
           "_device_match_id_" + std::to_string(matchKey.index());
       auto initializerOp = moduleBuilder.create<FuncOp>(
           fusedLoc, variableName + "_initializer",
-          moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}),
-          ArrayRef<NamedAttribute>{});
+          moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}));
       SymbolTable::setSymbolVisibility(initializerOp,
                                        SymbolTable::Visibility::Private);
       auto variableOp = moduleBuilder.create<IREE::HAL::VariableOp>(
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index b3ceb17..9c7f0c2 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -1270,11 +1270,11 @@
     if (op.getOperation()->getOperand(0).getType().template isa<RefType>()) {
       condValue = rewriter.template createOrFold<CmpRefOp>(
           op.getLoc(), ArrayRef<Type>{condType},
-          op.getOperation()->getOperands(), ArrayRef<NamedAttribute>{});
+          op.getOperation()->getOperands());
     } else {
       condValue = rewriter.template createOrFold<CmpI32Op>(
           op.getLoc(), ArrayRef<Type>{condType},
-          op.getOperation()->getOperands(), ArrayRef<NamedAttribute>{});
+          op.getOperation()->getOperands());
     }
     condValue = rewriter.createOrFold<XorI32Op>(
         op.getLoc(), condType, condValue,
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
index e70cb08..d53f037 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ b/iree/compiler/Dialect/VMLA/Transforms/BUILD
@@ -48,5 +48,6 @@
         "@llvm-project//mlir:Transforms",
         "@org_tensorflow//tensorflow/compiler/mlir/hlo",
         "@org_tensorflow//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
+        "@org_tensorflow//tensorflow/compiler/mlir/hlo:mhlo_to_mhlo_lowering_patterns",
     ],
 )
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index 15d91c6..f22c87f 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -34,6 +34,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -281,6 +282,15 @@
  public:
   void runOnOperation() {
     MLIRContext *context = &getContext();
+
+    // These patterns should be run greedily as they are not dialect
+    // conversions.
+    OwningRewritePatternList greedyPatterns;
+    mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), greedyPatterns))) {
+      return signalPassFailure();
+    }
+
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
     target.addLegalDialect<StandardOpsDialect>();
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
index d21bfda..3473e44 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -31,3 +31,20 @@
   %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[5, 6]> : tensor<2xi64>} : (tensor<3xf32>) -> tensor<5x6x3xf32>
   return %0 : tensor<5x6x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+  // CHECK-NOT: "mhlo.complex"
+  %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
+
+  // CHECK-DAG: [[V1:%.+]] = mhlo.multiply %arg0, %arg0
+  // CHECK-DAG: [[V2:%.+]] = mhlo.multiply %arg1, %arg1
+  // CHECK-DAG: [[V3:%.+]] = mhlo.subtract [[V1]], [[V2]]
+  %1 = "mhlo.multiply"(%0, %0) : (tensor<3xcomplex<f32>>, tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>>
+  %2 = "mhlo.real"(%1) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
+
+  // CHECK: return [[V3]]
+  return %2 : tensor<3xf32>
+}
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD
index 4d9d6d1..b1301e2 100644
--- a/iree/schemas/BUILD
+++ b/iree/schemas/BUILD
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 load("//iree:build_defs.oss.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library")
+load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
 load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
 
 package(
@@ -33,7 +34,16 @@
     "--gen-object-api",
 ]
 
-# TODO(benvanik): also expose as C using flatcc.
+iree_flatbuffer_c_library(
+    name = "bytecode_module_def_c_fbs",
+    srcs = ["bytecode_module_def.fbs"],
+    flatcc_args = [
+        "--reader",
+        "--builder",
+        "--verifier",
+    ],
+)
+
 iree_flatbuffer_cc_library(
     name = "bytecode_module_def_cc_fbs",
     srcs = ["bytecode_module_def.fbs"],
diff --git a/iree/schemas/CMakeLists.txt b/iree/schemas/CMakeLists.txt
index 751ed40..313ffa1 100644
--- a/iree/schemas/CMakeLists.txt
+++ b/iree/schemas/CMakeLists.txt
@@ -14,6 +14,18 @@
 
 iree_add_all_subdirs()
 
+flatbuffer_c_library(
+  NAME
+    bytecode_module_def_c_fbs
+  SRCS
+    "bytecode_module_def.fbs"
+  FLATCC_ARGS
+    "--reader"
+    "--builder"
+    "--verifier"
+  PUBLIC
+)
+
 flatbuffer_cc_library(
   NAME
     bytecode_module_def_cc_fbs
diff --git a/iree/schemas/bytecode_module_def.fbs b/iree/schemas/bytecode_module_def.fbs
index 2394afb..f971093 100644
--- a/iree/schemas/bytecode_module_def.fbs
+++ b/iree/schemas/bytecode_module_def.fbs
@@ -93,7 +93,7 @@
   compression_type:CompressionTypeDef;
 
   // Contents in a format defined by CompressionTypeDef.
-  data:[uint8] (force_align: 16);
+  data:[uint8];
 }
 
 // Read-write data segment.
@@ -117,6 +117,7 @@
   // Offset and length within the larger bytecode data block.
   bytecode_offset:int32;
   bytecode_length:int32;
+  // TODO(benvanik): remove counts and embed directly in bytecode.
   // Total number of i32 registers used by the function.
   i32_register_count:int16;
   // Total number of ref_ptr registers used by the function.
@@ -175,7 +176,7 @@
   function_descriptors:[FunctionDescriptor];
 
   // Bytecode contents. One large buffer containing all of the function op data.
-  bytecode_data:[uint8] (force_align: 4);
+  bytecode_data:[uint8];
 }
 
 root_type BytecodeModuleDef;
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 744ff22..9ccb2d9 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -57,11 +57,8 @@
         # TODO(#1696): Enable after standard dialect can support floor
         # operation. Lowering from XLA -> linalg should be easy fix.
         # "floor.mlir",
-
-        # TODO(#1694): Enable after mhlo.gather can be lowered to linalg.
-        # "gather.mlir",
-        # "gather_concat.mlir",
-        #
+        "gather.mlir",
+        "gather_concat.mlir",
         "iota.mlir",
         "log.mlir",
         "maximum.mlir",
@@ -89,6 +86,8 @@
     target_backend = "vulkan-spirv",
 )
 
+# TODO(ataei): Enable dylib-llvm-aot tests.
+# See: https://github.com/google/iree/issues/2645
 iree_check_single_backend_test_suite(
     name = "check_llvm-ir_llvm",
     srcs = [
@@ -107,6 +106,8 @@
         "divide.mlir",
         "dot.mlir",
         "exponential.mlir",
+        "gather.mlir",
+        "gather_concat.mlir",
         "iota.mlir",
         "log.mlir",
         "maximum.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index e4ae959..51a62cf 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -46,6 +46,8 @@
     "divide.mlir"
     "dot.mlir"
     "exponential.mlir"
+    "gather.mlir"
+    "gather_concat.mlir"
     "iota.mlir"
     "log.mlir"
     "maximum.mlir"
@@ -93,6 +95,8 @@
     "divide.mlir"
     "dot.mlir"
     "exponential.mlir"
+    "gather.mlir"
+    "gather_concat.mlir"
     "iota.mlir"
     "log.mlir"
     "maximum.mlir"
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 46ea8a9..a3b5a9f 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -243,6 +243,7 @@
         "@com_google_absl//absl/types:span",
         "//iree/base:api",
         "//iree/base:api_util",
+        "//iree/base:localfile",
         "//iree/base:source_location",
         "//iree/base:tracing",
         "//iree/compiler/Dialect/Flow/Transforms",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index a277877..b3a4fa9 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -330,6 +330,7 @@
       iree::base::api
       iree::base::api_util
       iree::base::init
+      iree::base::localfile
       iree::base::source_location
       iree::base::status
       iree::base::tracing
diff --git a/iree/vm/BUILD b/iree/vm/BUILD
index 93829d0..ceec083 100644
--- a/iree/vm/BUILD
+++ b/iree/vm/BUILD
@@ -59,11 +59,10 @@
         ":value",
         "//iree/base:alignment",
         "//iree/base:api",
-        "//iree/base:flatbuffer_util",
+        "//iree/base:logging",
         "//iree/base:target_platform",
-        "//iree/schemas:bytecode_module_def_cc_fbs",
-        "@com_github_google_flatbuffers//:flatbuffers",
-        "@com_google_absl//absl/strings",
+        "//iree/schemas:bytecode_module_def_c_fbs",
+        "@com_github_dvidelabs_flatcc//:runtime",
     ],
 )
 
diff --git a/iree/vm/CMakeLists.txt b/iree/vm/CMakeLists.txt
index ae91187..84b8bf5 100644
--- a/iree/vm/CMakeLists.txt
+++ b/iree/vm/CMakeLists.txt
@@ -64,13 +64,12 @@
     ::stack
     ::type_def
     ::value
-    absl::strings
-    flatbuffers
+    flatcc::runtime
     iree::base::alignment
     iree::base::api
-    iree::base::flatbuffer_util
+    iree::base::logging
     iree::base::target_platform
-    iree::schemas::bytecode_module_def_cc_fbs
+    iree::schemas::bytecode_module_def_c_fbs
   PUBLIC
 )
 
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index eab5136..2565c5e 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -741,7 +741,7 @@
         target_function.module = &module->interface;
         target_function.linkage = IREE_VM_FUNCTION_LINKAGE_INTERNAL;
         target_function.ordinal = function_ordinal;
-        const iree_vm_function_descriptor_t* target_descriptor =
+        const iree_vm_FunctionDescriptor_t* target_descriptor =
             &module->function_descriptor_table[function_ordinal];
         target_function.i32_register_count =
             target_descriptor->i32_register_count;
diff --git a/iree/vm/bytecode_module.cc b/iree/vm/bytecode_module.cc
index 9f56a35..cec96fd 100644
--- a/iree/vm/bytecode_module.cc
+++ b/iree/vm/bytecode_module.cc
@@ -14,23 +14,20 @@
 
 #include "iree/vm/bytecode_module.h"
 
-#include <string.h>
-
-#include "absl/strings/match.h"
 #include "iree/base/alignment.h"
 #include "iree/base/api.h"
-#include "iree/base/flatbuffer_util.h"
+#include "iree/base/logging.h"
 #include "iree/vm/bytecode_module_impl.h"
 #include "iree/vm/ref.h"
 #include "iree/vm/stack.h"
 
-// TODO(benvanik): replace with flatcc version so this file can be pure C.
-#include "flatbuffers/flatbuffers.h"
-#include "iree/schemas/bytecode_module_def_generated.h"
-
-#define IREE_VM_GET_MODULE_DEF(module)                 \
-  ::flatbuffers::GetRoot<iree::vm::BytecodeModuleDef>( \
-      module->flatbuffer_data.data)
+// Performs a strcmp-style three-way comparison between a flatbuffers string
+// and an IREE string view.
+static int iree_vm_flatbuffer_strcmp(flatbuffers_string_t lhs,
+                                     iree_string_view_t rhs) {
+  size_t lhs_size = flatbuffers_string_len(lhs);
+  int x = strncmp(lhs, rhs.data, lhs_size < rhs.size ? lhs_size : rhs.size);
+  return x != 0 ? x : lhs_size < rhs.size ? -1 : lhs_size > rhs.size;
+}
 
 // Returns true if the given |type_def| is valid, meaning that the type it was
 // resolved from is registered or known to the system as a builtin.
@@ -41,28 +38,36 @@
 
 // Resolves a type through either builtin rules or the ref registered types.
 static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type(
-    const iree::vm::TypeDef* type_def) {
-  auto full_name = iree::WrapString(type_def->full_name());
+    iree_vm_TypeDef_table_t type_def) {
   iree_vm_type_def_t result;
   memset(&result, 0, sizeof(result));
-  if (full_name == "i8") {
+  flatbuffers_string_t full_name = iree_vm_TypeDef_full_name(type_def);
+  if (!flatbuffers_string_len(full_name)) {
+    return result;
+  } else if (iree_vm_flatbuffer_strcmp(full_name,
+                                       iree_make_cstring_view("i8")) == 0) {
     result.value_type = IREE_VM_VALUE_TYPE_I8;
-  } else if (full_name == "i16") {
+  } else if (iree_vm_flatbuffer_strcmp(full_name,
+                                       iree_make_cstring_view("i16")) == 0) {
     result.value_type = IREE_VM_VALUE_TYPE_I16;
-  } else if (full_name == "i32") {
+  } else if (iree_vm_flatbuffer_strcmp(full_name,
+                                       iree_make_cstring_view("i32")) == 0) {
     result.value_type = IREE_VM_VALUE_TYPE_I32;
-  } else if (full_name == "i64") {
+  } else if (iree_vm_flatbuffer_strcmp(full_name,
+                                       iree_make_cstring_view("i64")) == 0) {
     result.value_type = IREE_VM_VALUE_TYPE_I64;
-  } else if (!full_name.empty() && full_name[0] == '!') {
-    full_name.remove_prefix(1);
-    if (absl::StartsWith(full_name, "vm.list<")) {
+  } else if (full_name[0] == '!') {
+    // Note that we drop the ! prefix:
+    iree_string_view_t type_name = iree_string_view_t{
+        full_name + 1, flatbuffers_string_len(full_name) - 1};
+    if (strncmp(type_name.data, "vm.list<", strlen("vm.list<")) == 0) {
       // This is a !vm.list<...> type. We don't actually care about the type as
-      // we allow list types to be widened.
-      full_name.remove_suffix(full_name.size() - 7);
+      // we allow list types to be widened. Rewrite to just vm.list as that's
+      // all we have registered.
+      type_name = iree_make_cstring_view("vm.list");
     }
     const iree_vm_ref_type_descriptor_t* type_descriptor =
-        iree_vm_ref_lookup_registered_type(
-            iree_string_view_t{full_name.data(), full_name.size()});
+        iree_vm_ref_lookup_registered_type(type_name);
     if (type_descriptor) {
       result.ref_type = type_descriptor->type;
     }
@@ -74,12 +79,13 @@
 // |type_table| can be omitted to just perform verification that all types are
 // registered.
 static iree_status_t iree_vm_bytecode_module_resolve_types(
-    const iree::vm::BytecodeModuleDef* module_def,
-    iree_vm_type_def_t* type_table) {
-  for (int i = 0; i < module_def->types()->size(); ++i) {
-    type_table[i] =
-        iree_vm_bytecode_module_resolve_type(module_def->types()->Get(i));
+    iree_vm_TypeDef_vec_t type_defs, iree_vm_type_def_t* type_table) {
+  for (size_t i = 0; i < iree_vm_TypeDef_vec_len(type_defs); ++i) {
+    iree_vm_TypeDef_table_t type_def = iree_vm_TypeDef_vec_at(type_defs, i);
+    type_table[i] = iree_vm_bytecode_module_resolve_type(type_def);
     if (!iree_vm_type_def_is_valid(type_table[i])) {
+      LOG(ERROR) << "no type registered with name '"
+                 << iree_vm_TypeDef_full_name(type_def) << "'";
       return IREE_STATUS_NOT_FOUND;
     }
   }
@@ -91,121 +97,136 @@
 // names on functions with internal linkage), however we shouldn't need to
 // bounds check anything within the flatbuffer after this succeeds.
 static iree_status_t iree_vm_bytecode_module_flatbuffer_verify(
-    const iree::vm::BytecodeModuleDef* module_def) {
-  if (!module_def->name() || module_def->name()->size() == 0) {
-    LOG(ERROR) << "All modules must have a name.";
+    iree_const_byte_span_t flatbuffer_data) {
+  if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+    LOG(ERROR) << "Flatbuffer data is not present or less than 16 bytes";
     return IREE_STATUS_INVALID_ARGUMENT;
   }
 
-  if (!module_def->types()) {
-    LOG(ERROR) << "Type table is mandatory, though it could be empty (in empty "
-                  "modules).";
+  // Run flatcc generated verification. This ensures all pointers are in-bounds
+  // and that we can safely walk the file, but not that the actual contents of
+  // the flatbuffer meet our expectations.
+  int verify_ret = iree_vm_BytecodeModuleDef_verify_as_root(
+      flatbuffer_data.data, flatbuffer_data.data_length);
+  if (verify_ret != flatcc_verify_ok) {
+    LOG(ERROR) << flatcc_verify_error_string(verify_ret);
     return IREE_STATUS_INVALID_ARGUMENT;
   }
 
-  if (!module_def->exported_functions() ||
-      module_def->exported_functions()->size() == 0) {
-    LOG(ERROR) << "At least one exported function is required.";
+  iree_vm_BytecodeModuleDef_table_t module_def =
+      iree_vm_BytecodeModuleDef_as_root(flatbuffer_data.data);
+
+  flatbuffers_string_t name = iree_vm_BytecodeModuleDef_name(module_def);
+  if (!flatbuffers_string_len(name)) {
+    LOG(ERROR) << "module name missing";
     return IREE_STATUS_INVALID_ARGUMENT;
   }
 
-  if (!module_def->internal_functions() ||
-      module_def->internal_functions()->size() == 0) {
-    LOG(ERROR) << "At least one internal function is required.";
-    return IREE_STATUS_INVALID_ARGUMENT;
-  }
-
-  if (!module_def->function_descriptors() ||
-      module_def->function_descriptors()->size() !=
-          module_def->internal_functions()->size()) {
-    LOG(ERROR)
-        << "All internal functions need a mapping into the bytecode data.";
-    return IREE_STATUS_INVALID_ARGUMENT;
-  }
-
-  if (!module_def->bytecode_data()) {
-    LOG(ERROR) << "Bytecode data is required if we have any functions.";
-    return IREE_STATUS_INVALID_ARGUMENT;
-  }
-
-  for (int i = 0; i < module_def->types()->size(); ++i) {
-    const auto* type_def = module_def->types()->Get(i);
+  iree_vm_TypeDef_vec_t types = iree_vm_BytecodeModuleDef_types(module_def);
+  for (size_t i = 0; i < iree_vm_TypeDef_vec_len(types); ++i) {
+    iree_vm_TypeDef_table_t type_def = iree_vm_TypeDef_vec_at(types, i);
     if (!type_def) {
-      LOG(ERROR) << "All types must be valid.";
-      return IREE_STATUS_INVALID_ARGUMENT;
-    } else if (!type_def->full_name() || type_def->full_name()->size() == 0) {
-      LOG(ERROR) << "All types require a name.";
+      LOG(ERROR) << "type def missing body";
       return IREE_STATUS_INVALID_ARGUMENT;
     }
-    if (!iree_vm_type_def_is_valid(
-            iree_vm_bytecode_module_resolve_type(type_def))) {
-      LOG(ERROR) << "No type registered with name '"
-                 << type_def->full_name()->c_str() << "'.";
+    flatbuffers_string_t full_name = iree_vm_TypeDef_full_name(type_def);
+    if (flatbuffers_string_len(full_name) <= 0) {
+      LOG(ERROR) << "type def missing full_name";
       return IREE_STATUS_INVALID_ARGUMENT;
     }
   }
 
-  if (module_def->imported_functions()) {
-    for (int i = 0; i < module_def->imported_functions()->size(); ++i) {
-      auto* import_def = module_def->imported_functions()->Get(i);
-      if (!import_def) {
-        LOG(ERROR) << "All imports must be valid.";
-        return IREE_STATUS_INVALID_ARGUMENT;
-      } else if (!import_def->full_name() ||
-                 import_def->full_name()->size() == 0) {
-        LOG(ERROR) << "All imports require a name.";
-        return IREE_STATUS_INVALID_ARGUMENT;
-      } else if (!import_def->signature()) {
-        LOG(ERROR) << "All imports require a signature.";
-        return IREE_STATUS_INVALID_ARGUMENT;
-      }
+  iree_vm_ImportFunctionDef_vec_t imported_functions =
+      iree_vm_BytecodeModuleDef_imported_functions(module_def);
+  iree_vm_ExportFunctionDef_vec_t exported_functions =
+      iree_vm_BytecodeModuleDef_exported_functions(module_def);
+  iree_vm_InternalFunctionDef_vec_t internal_functions =
+      iree_vm_BytecodeModuleDef_internal_functions(module_def);
+  iree_vm_FunctionDescriptor_vec_t function_descriptors =
+      iree_vm_BytecodeModuleDef_function_descriptors(module_def);
+
+  if (flatbuffers_vec_len(internal_functions) !=
+      flatbuffers_vec_len(function_descriptors)) {
+    LOG(ERROR)
+        << "mismatched internal_functions and function_descriptors vectors";
+    return IREE_STATUS_INVALID_ARGUMENT;
+  }
+
+  for (size_t i = 0; i < iree_vm_ImportFunctionDef_vec_len(imported_functions);
+       ++i) {
+    iree_vm_ImportFunctionDef_table_t import_def =
+        iree_vm_ImportFunctionDef_vec_at(imported_functions, i);
+    if (!import_def) {
+      LOG(ERROR) << "import def missing body";
+      return IREE_STATUS_INVALID_ARGUMENT;
+    }
+    flatbuffers_string_t full_name =
+        iree_vm_ImportFunctionDef_full_name(import_def);
+    if (!flatbuffers_string_len(full_name)) {
+      LOG(ERROR) << "import def missing full_name";
+      return IREE_STATUS_INVALID_ARGUMENT;
+    }
+    if (!iree_vm_ImportFunctionDef_signature(import_def)) {
+      LOG(ERROR) << "import def missing a function signature";
+      return IREE_STATUS_INVALID_ARGUMENT;
     }
   }
 
-  for (int i = 0; i < module_def->exported_functions()->size(); ++i) {
-    auto* export_def = module_def->exported_functions()->Get(i);
+  for (size_t i = 0; i < iree_vm_ExportFunctionDef_vec_len(exported_functions);
+       ++i) {
+    iree_vm_ExportFunctionDef_table_t export_def =
+        iree_vm_ExportFunctionDef_vec_at(exported_functions, i);
     if (!export_def) {
-      LOG(ERROR) << "All exports must be valid.";
+      LOG(ERROR) << "export def missing body";
       return IREE_STATUS_INVALID_ARGUMENT;
-    } else if (!export_def->local_name() ||
-               export_def->local_name()->size() == 0) {
-      LOG(ERROR) << "All exports require a name.";
+    }
+    flatbuffers_string_t local_name =
+        iree_vm_ExportFunctionDef_local_name(export_def);
+    if (!flatbuffers_string_len(local_name)) {
+      LOG(ERROR) << "export def missing local_name";
       return IREE_STATUS_INVALID_ARGUMENT;
-    } else if (!export_def->signature()) {
-      LOG(ERROR) << "All exports require a signature.";
+    }
+    if (!iree_vm_ExportFunctionDef_signature(export_def)) {
+      LOG(ERROR) << "export def missing a function signature";
       return IREE_STATUS_INVALID_ARGUMENT;
-    } else if (export_def->internal_ordinal() < 0 ||
-               export_def->internal_ordinal() >=
-                   module_def->internal_functions()->size()) {
-      LOG(ERROR)
-          << "Out-of-bounds reference to a function in the internal table.";
+    }
+    int32_t internal_ordinal =
+        iree_vm_ExportFunctionDef_internal_ordinal(export_def);
+    if (internal_ordinal < 0 ||
+        internal_ordinal >=
+            iree_vm_InternalFunctionDef_vec_len(internal_functions)) {
+      LOG(ERROR) << "export def internal_ordinal out of bounds";
       return IREE_STATUS_INVALID_ARGUMENT;
     }
   }
 
-  for (int i = 0; i < module_def->internal_functions()->size(); ++i) {
-    auto* function_def = module_def->internal_functions()->Get(i);
+  flatbuffers_uint8_vec_t bytecode_data =
+      iree_vm_BytecodeModuleDef_bytecode_data(module_def);
+  for (size_t i = 0;
+       i < iree_vm_InternalFunctionDef_vec_len(internal_functions); ++i) {
+    iree_vm_InternalFunctionDef_table_t function_def =
+        iree_vm_InternalFunctionDef_vec_at(internal_functions, i);
     if (!function_def) {
-      LOG(ERROR) << "All functions must be valid.";
+      LOG(ERROR) << "function def missing body";
       return IREE_STATUS_INVALID_ARGUMENT;
-    } else if (!function_def->signature()) {
-      LOG(ERROR) << "All functions require a signature.";
+    }
+    if (!iree_vm_InternalFunctionDef_signature(function_def)) {
+      LOG(ERROR) << "function def missing signature";
       return IREE_STATUS_INVALID_ARGUMENT;
     }
 
-    const auto* function_descriptor =
-        module_def->function_descriptors()->Get(i);
-    if (!function_descriptor || function_descriptor->bytecode_offset() < 0 ||
-        function_descriptor->bytecode_length() < 0 ||
-        function_descriptor->bytecode_offset() +
-                function_descriptor->bytecode_length() >
-            module_def->bytecode_data()->size()) {
-      LOG(ERROR) << "Bytecode span must be a valid range.";
+    iree_vm_FunctionDescriptor_struct_t function_descriptor =
+        iree_vm_FunctionDescriptor_vec_at(function_descriptors, i);
+    if (function_descriptor->bytecode_offset < 0 ||
+        function_descriptor->bytecode_offset +
+                function_descriptor->bytecode_length >
+            flatbuffers_uint8_vec_len(bytecode_data)) {
+      LOG(ERROR) << "function descriptor bytecode span out of range";
       return IREE_STATUS_INVALID_ARGUMENT;
     }
-    if (function_descriptor->i32_register_count() > IREE_I32_REGISTER_COUNT ||
-        function_descriptor->ref_register_count() > IREE_REF_REGISTER_COUNT) {
-      LOG(ERROR) << "Register counts out of range.";
+    if (function_descriptor->i32_register_count > IREE_I32_REGISTER_COUNT ||
+        function_descriptor->ref_register_count > IREE_REF_REGISTER_COUNT) {
+      LOG(ERROR) << "function descriptor register out of range";
       return IREE_STATUS_INVALID_ARGUMENT;
     }
 
@@ -228,22 +249,21 @@
 
 static iree_string_view_t iree_vm_bytecode_module_name(void* self) {
   iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
-  auto* module_def = IREE_VM_GET_MODULE_DEF(module);
-  return iree_string_view_t{module_def->name()->data(),
-                            module_def->name()->size()};
+  flatbuffers_string_t name = iree_vm_BytecodeModuleDef_name(module->def);
+  return iree_string_view_t{name, flatbuffers_string_len(name)};
 }
 
 static iree_vm_module_signature_t iree_vm_bytecode_module_signature(
     void* self) {
   iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
-  auto* module_def = IREE_VM_GET_MODULE_DEF(module);
   iree_vm_module_signature_t signature;
-  signature.import_function_count =
-      module_def->imported_functions()
-          ? module_def->imported_functions()->size()
-          : 0;
-  signature.export_function_count = module_def->exported_functions()->size();
-  signature.internal_function_count = module_def->internal_functions()->size();
+  memset(&signature, 0, sizeof(signature));
+  signature.import_function_count = iree_vm_ImportFunctionDef_vec_len(
+      iree_vm_BytecodeModuleDef_imported_functions(module->def));
+  signature.export_function_count = iree_vm_ExportFunctionDef_vec_len(
+      iree_vm_BytecodeModuleDef_exported_functions(module->def));
+  signature.internal_function_count = iree_vm_InternalFunctionDef_vec_len(
+      iree_vm_BytecodeModuleDef_internal_functions(module->def));
   return signature;
 }
 
@@ -252,66 +272,75 @@
     iree_vm_function_t* out_function, iree_string_view_t* out_name,
     iree_vm_function_signature_t* out_signature) {
   if (out_function) {
-    memset(out_function, 0, sizeof(iree_vm_function_t));
+    memset(out_function, 0, sizeof(*out_function));
   }
   if (out_name) {
-    out_name->data = NULL;
-    out_name->size = 0;
+    memset(out_name, 0, sizeof(*out_name));
   }
   if (out_signature) {
-    memset(out_signature, 0, sizeof(iree_vm_function_signature_t));
+    memset(out_signature, 0, sizeof(*out_signature));
   }
 
   iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
-  auto* module_def = IREE_VM_GET_MODULE_DEF(module);
-
-  const ::flatbuffers::String* name = nullptr;
-  const iree::vm::FunctionSignatureDef* signature = nullptr;
+  flatbuffers_string_t name = NULL;
+  iree_vm_FunctionSignatureDef_table_t signature = NULL;
   if (linkage == IREE_VM_FUNCTION_LINKAGE_IMPORT) {
-    if (!module_def->imported_functions() || ordinal < 0 ||
-        ordinal >= module_def->imported_functions()->size()) {
+    iree_vm_ImportFunctionDef_vec_t imported_functions =
+        iree_vm_BytecodeModuleDef_imported_functions(module->def);
+    if (ordinal < 0 ||
+        ordinal >= iree_vm_ImportFunctionDef_vec_len(imported_functions)) {
       return IREE_STATUS_INVALID_ARGUMENT;
     }
-    auto* import_def = module_def->imported_functions()->Get(ordinal);
-    name = import_def->full_name();
-    signature = import_def->signature();
+    iree_vm_ImportFunctionDef_table_t import_def =
+        iree_vm_ImportFunctionDef_vec_at(imported_functions, ordinal);
+    name = iree_vm_ImportFunctionDef_full_name(import_def);
+    signature = iree_vm_ImportFunctionDef_signature(import_def);
     if (out_function) {
       out_function->module = &module->interface;
       out_function->linkage = linkage;
       out_function->ordinal = ordinal;
     }
   } else if (linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT) {
-    if (ordinal < 0 || ordinal >= module_def->exported_functions()->size()) {
+    iree_vm_ExportFunctionDef_vec_t exported_functions =
+        iree_vm_BytecodeModuleDef_exported_functions(module->def);
+    if (ordinal < 0 ||
+        ordinal >= iree_vm_ExportFunctionDef_vec_len(exported_functions)) {
       return IREE_STATUS_INVALID_ARGUMENT;
     }
-    auto* export_def = module_def->exported_functions()->Get(ordinal);
-    name = export_def->local_name();
-    signature = export_def->signature();
+    iree_vm_ExportFunctionDef_table_t export_def =
+        iree_vm_ExportFunctionDef_vec_at(exported_functions, ordinal);
+    name = iree_vm_ExportFunctionDef_local_name(export_def);
+    signature = iree_vm_ExportFunctionDef_signature(export_def);
     if (out_function) {
       out_function->module = &module->interface;
       out_function->linkage = IREE_VM_FUNCTION_LINKAGE_INTERNAL;
-      out_function->ordinal = export_def->internal_ordinal();
+      out_function->ordinal =
+          iree_vm_ExportFunctionDef_internal_ordinal(export_def);
 
-      const iree_vm_function_descriptor_t* function_descriptor =
-          &module->function_descriptor_table[export_def->internal_ordinal()];
+      const iree_vm_FunctionDescriptor_t* function_descriptor =
+          &module->function_descriptor_table[out_function->ordinal];
       out_function->i32_register_count =
           function_descriptor->i32_register_count;
       out_function->ref_register_count =
           function_descriptor->ref_register_count;
     }
   } else {
-    if (ordinal < 0 || ordinal >= module_def->internal_functions()->size()) {
+    iree_vm_InternalFunctionDef_vec_t internal_functions =
+        iree_vm_BytecodeModuleDef_internal_functions(module->def);
+    if (ordinal < 0 ||
+        ordinal >= iree_vm_InternalFunctionDef_vec_len(internal_functions)) {
       return IREE_STATUS_INVALID_ARGUMENT;
     }
-    auto* function_def = module_def->internal_functions()->Get(ordinal);
-    name = function_def->local_name();
-    signature = function_def->signature();
+    iree_vm_InternalFunctionDef_table_t function_def =
+        iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
+    name = iree_vm_InternalFunctionDef_local_name(function_def);
+    signature = iree_vm_InternalFunctionDef_signature(function_def);
     if (out_function) {
       out_function->module = &module->interface;
       out_function->linkage = IREE_VM_FUNCTION_LINKAGE_INTERNAL;
       out_function->ordinal = ordinal;
 
-      const iree_vm_function_descriptor_t* function_descriptor =
+      const iree_vm_FunctionDescriptor_t* function_descriptor =
           &module->function_descriptor_table[ordinal];
       out_function->i32_register_count =
           function_descriptor->i32_register_count;
@@ -321,14 +350,14 @@
   }
 
   if (out_name && name) {
-    out_name->data = name->c_str();
-    out_name->size = name->size();
+    out_name->data = name;
+    out_name->size = flatbuffers_string_len(name);
   }
   if (out_signature && signature) {
-    out_signature->argument_count =
-        signature->argument_types() ? signature->argument_types()->size() : 0;
-    out_signature->result_count =
-        signature->result_types() ? signature->result_types()->size() : 0;
+    out_signature->argument_count = flatbuffers_int32_vec_len(
+        iree_vm_FunctionSignatureDef_argument_types(signature));
+    out_signature->result_count = flatbuffers_int32_vec_len(
+        iree_vm_FunctionSignatureDef_result_types(signature));
   }
 
   return IREE_STATUS_OK;
@@ -337,9 +366,6 @@
 static iree_status_t iree_vm_bytecode_module_get_function_reflection_attr(
     void* self, iree_vm_function_linkage_t linkage, int32_t ordinal,
     int32_t index, iree_string_view_t* key, iree_string_view_t* value) {
-  iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
-  auto* module_def = IREE_VM_GET_MODULE_DEF(module);
-
   if (linkage != IREE_VM_FUNCTION_LINKAGE_INTERNAL) {
     iree_vm_function_t internal_function;
     iree_vm_bytecode_module_get_function(self, linkage, ordinal,
@@ -348,83 +374,98 @@
     ordinal = internal_function.ordinal;
   }
 
-  if (ordinal < 0 || ordinal >= module_def->internal_functions()->size()) {
+  iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
+  iree_vm_InternalFunctionDef_vec_t internal_functions =
+      iree_vm_BytecodeModuleDef_internal_functions(module->def);
+
+  if (ordinal < 0 ||
+      ordinal >= iree_vm_InternalFunctionDef_vec_len(internal_functions)) {
     return IREE_STATUS_INVALID_ARGUMENT;
   }
 
-  auto* export_def = module_def->internal_functions()->Get(ordinal);
-  const iree::vm::FunctionSignatureDef* signature = export_def->signature();
-  if (!signature->reflection_attrs() || index < 0 ||
-      index >= signature->reflection_attrs()->size()) {
+  iree_vm_InternalFunctionDef_table_t function_def =
+      iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
+  iree_vm_FunctionSignatureDef_table_t signature =
+      iree_vm_InternalFunctionDef_signature(function_def);
+  iree_vm_ReflectionAttrDef_vec_t reflection_attrs =
+      iree_vm_FunctionSignatureDef_reflection_attrs(signature);
+  if (index < 0 ||
+      index >= iree_vm_ReflectionAttrDef_vec_len(reflection_attrs)) {
     return IREE_STATUS_NOT_FOUND;
   }
-  const ::iree::vm::ReflectionAttrDef* attr =
-      signature->reflection_attrs()->Get(index);
-  const ::flatbuffers::String* attr_key = attr->key();
-  const ::flatbuffers::String* attr_value = attr->value();
-  if (!attr_key || !attr_value) {
+  iree_vm_ReflectionAttrDef_table_t attr =
+      iree_vm_ReflectionAttrDef_vec_at(reflection_attrs, index);
+  flatbuffers_string_t attr_key = iree_vm_ReflectionAttrDef_key(attr);
+  flatbuffers_string_t attr_value = iree_vm_ReflectionAttrDef_value(attr);
+  if (!flatbuffers_string_len(attr_key) ||
+      !flatbuffers_string_len(attr_value)) {
     // Because reflection metadata should not impose any overhead for the
     // non-reflection case, we do not eagerly validate it on load -- instead
     // verify it structurally as needed.
     return IREE_STATUS_FAILED_PRECONDITION;
   }
 
-  key->data = attr_key->c_str();
-  key->size = attr_key->size();
-  value->data = attr_value->c_str();
-  value->size = attr_value->size();
+  key->data = attr_key;
+  key->size = flatbuffers_string_len(attr_key);
+  value->data = attr_value;
+  value->size = flatbuffers_string_len(attr_value);
 
   return IREE_STATUS_OK;
 }
 
-static bool iree_vm_bytecode_module_compare_str(const flatbuffers::String* lhs,
-                                                iree_string_view_t rhs) {
-  if (!lhs || lhs->size() != rhs.size) return false;
-  return strncmp(lhs->c_str(), rhs.data, rhs.size) == 0;
-}
-
 static iree_status_t iree_vm_bytecode_module_lookup_function(
     void* self, iree_vm_function_linkage_t linkage, iree_string_view_t name,
     iree_vm_function_t* out_function) {
   if (!out_function) return IREE_STATUS_INVALID_ARGUMENT;
   memset(out_function, 0, sizeof(iree_vm_function_t));
 
-  if (!name.data || !name.size) return IREE_STATUS_INVALID_ARGUMENT;
-
-  iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
-  auto* module_def = IREE_VM_GET_MODULE_DEF(module);
+  if (iree_string_view_is_empty(name)) return IREE_STATUS_INVALID_ARGUMENT;
 
   // NOTE: we could organize imports/exports alphabetically so we could bsearch.
+  iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
   if (linkage == IREE_VM_FUNCTION_LINKAGE_IMPORT) {
-    if (!module_def->imported_functions()) {
-      return IREE_STATUS_NOT_FOUND;
-    }
-    for (int ordinal = 0; ordinal < module_def->imported_functions()->size();
+    iree_vm_ImportFunctionDef_vec_t imported_functions =
+        iree_vm_BytecodeModuleDef_imported_functions(module->def);
+    for (size_t ordinal = 0;
+         ordinal < iree_vm_ImportFunctionDef_vec_len(imported_functions);
          ++ordinal) {
-      auto* import_def = module_def->imported_functions()->Get(ordinal);
-      if (iree_vm_bytecode_module_compare_str(import_def->full_name(), name)) {
+      iree_vm_ImportFunctionDef_table_t import_def =
+          iree_vm_ImportFunctionDef_vec_at(imported_functions, ordinal);
+      if (iree_vm_flatbuffer_strcmp(
+              iree_vm_ImportFunctionDef_full_name(import_def), name) == 0) {
         return iree_vm_bytecode_module_get_function(self, linkage, ordinal,
                                                     out_function, NULL, NULL);
       }
     }
     return IREE_STATUS_NOT_FOUND;
   } else if (linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT) {
-    for (int ordinal = 0; ordinal < module_def->exported_functions()->size();
+    iree_vm_ExportFunctionDef_vec_t exported_functions =
+        iree_vm_BytecodeModuleDef_exported_functions(module->def);
+    for (size_t ordinal = 0;
+         ordinal < iree_vm_ExportFunctionDef_vec_len(exported_functions);
          ++ordinal) {
-      auto* export_def = module_def->exported_functions()->Get(ordinal);
-      if (iree_vm_bytecode_module_compare_str(export_def->local_name(), name)) {
+      iree_vm_ExportFunctionDef_table_t export_def =
+          iree_vm_ExportFunctionDef_vec_at(exported_functions, ordinal);
+      if (iree_vm_flatbuffer_strcmp(
+              iree_vm_ExportFunctionDef_local_name(export_def), name) == 0) {
         return iree_vm_bytecode_module_get_function(
             self, IREE_VM_FUNCTION_LINKAGE_INTERNAL,
-            export_def->internal_ordinal(), out_function, NULL, NULL);
+            iree_vm_ExportFunctionDef_internal_ordinal(export_def),
+            out_function, NULL, NULL);
       }
     }
     return IREE_STATUS_NOT_FOUND;
   } else {
-    for (int ordinal = 0; ordinal < module_def->internal_functions()->size();
+    iree_vm_InternalFunctionDef_vec_t internal_functions =
+        iree_vm_BytecodeModuleDef_internal_functions(module->def);
+    for (size_t ordinal = 0;
+         ordinal < iree_vm_InternalFunctionDef_vec_len(internal_functions);
          ++ordinal) {
-      auto* function_def = module_def->internal_functions()->Get(ordinal);
-      if (iree_vm_bytecode_module_compare_str(function_def->local_name(),
-                                              name)) {
+      iree_vm_InternalFunctionDef_table_t function_def =
+          iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
+      if (iree_vm_flatbuffer_strcmp(
+              iree_vm_InternalFunctionDef_local_name(function_def), name) ==
+          0) {
         return iree_vm_bytecode_module_get_function(
             self, IREE_VM_FUNCTION_LINKAGE_INTERNAL, ordinal, out_function,
             NULL, NULL);
@@ -438,20 +479,21 @@
 // Returns the total size of the structure and all tables with padding applied.
 // |state| may be null if only the structure size is required for allocation.
 static iree_host_size_t iree_vm_bytecode_module_layout_state(
-    const iree::vm::BytecodeModuleDef* module_def,
+    iree_vm_BytecodeModuleDef_table_t module_def,
     iree_vm_bytecode_module_state_t* state) {
-  int rwdata_storage_capacity =
-      module_def->module_state()
-          ? module_def->module_state()->global_bytes_capacity()
-          : 0;
-  int global_ref_count = module_def->module_state()
-                             ? module_def->module_state()->global_ref_count()
-                             : 0;
-  int rodata_ref_count =
-      module_def->rodata_segments() ? module_def->rodata_segments()->size() : 0;
-  int import_function_count = module_def->imported_functions()
-                                  ? module_def->imported_functions()->size()
-                                  : 0;
+  iree_vm_ModuleStateDef_table_t module_state =
+      iree_vm_BytecodeModuleDef_module_state(module_def);
+  int rwdata_storage_capacity = 0;
+  int global_ref_count = 0;
+  if (module_state) {
+    rwdata_storage_capacity =
+        iree_vm_ModuleStateDef_global_bytes_capacity(module_state);
+    global_ref_count = iree_vm_ModuleStateDef_global_ref_count(module_state);
+  }
+  int rodata_ref_count = iree_vm_RodataSegmentDef_vec_len(
+      iree_vm_BytecodeModuleDef_rodata_segments(module_def));
+  int import_function_count = iree_vm_ImportFunctionDef_vec_len(
+      iree_vm_BytecodeModuleDef_imported_functions(module_def));
 
   uint8_t* base_ptr = (uint8_t*)state;
   iree_host_size_t offset =
@@ -491,7 +533,7 @@
   *out_module_state = NULL;
 
   iree_vm_bytecode_module_t* module = (iree_vm_bytecode_module_t*)self;
-  auto* module_def = IREE_VM_GET_MODULE_DEF(module);
+  iree_vm_BytecodeModuleDef_table_t module_def = module->def;
 
   // Compute the total size required (with padding) for the state structure.
   iree_host_size_t total_state_struct_size =
@@ -507,13 +549,16 @@
   iree_vm_bytecode_module_layout_state(module_def, state);
 
   // Setup rodata segments to point directly at the flatbuffer memory.
+  iree_vm_RodataSegmentDef_vec_t rodata_segments =
+      iree_vm_BytecodeModuleDef_rodata_segments(module_def);
   for (int i = 0; i < state->rodata_ref_count; ++i) {
-    const iree::vm::RodataSegmentDef* segment =
-        module_def->rodata_segments()->Get(i);
+    iree_vm_RodataSegmentDef_table_t segment =
+        iree_vm_RodataSegmentDef_vec_at(rodata_segments, i);
     iree_vm_ro_byte_buffer_t* ref = &state->rodata_ref_table[i];
     iree_atomic_store(&ref->ref_object.counter, 1);
-    ref->data.data = segment->data()->Data();
-    ref->data.data_length = segment->data()->size();
+    ref->data.data = iree_vm_RodataSegmentDef_data(segment);
+    ref->data.data_length =
+        flatbuffers_uint8_vec_len(iree_vm_RodataSegmentDef_data(segment));
   }
 
   *out_module_state = (iree_vm_module_state_t*)state;
@@ -594,31 +639,24 @@
     iree_const_byte_span_t flatbuffer_data,
     iree_allocator_t flatbuffer_allocator, iree_allocator_t allocator,
     iree_vm_module_t** out_module) {
-  if (!out_module) {
-    LOG(ERROR) << "Output module argument not set";
-    return IREE_STATUS_INVALID_ARGUMENT;
-  }
+  if (!out_module) return IREE_STATUS_INVALID_ARGUMENT;
   *out_module = NULL;
 
-  if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
-    LOG(ERROR) << "Flatbuffer data is not present or less than 16 bytes";
-    return IREE_STATUS_INVALID_ARGUMENT;
-  } else if (!iree::vm::BytecodeModuleDefBufferHasIdentifier(
-                 flatbuffer_data.data)) {
-    LOG(ERROR) << "Flatbuffer data does not have bytecode module identifier";
-    return IREE_STATUS_INVALID_ARGUMENT;
-  }
+  IREE_RETURN_IF_ERROR(
+      iree_vm_bytecode_module_flatbuffer_verify(flatbuffer_data));
 
-  const iree::vm::BytecodeModuleDef* module_def =
-      ::flatbuffers::GetRoot<iree::vm::BytecodeModuleDef>(flatbuffer_data.data);
+  iree_vm_BytecodeModuleDef_table_t module_def =
+      iree_vm_BytecodeModuleDef_as_root(flatbuffer_data.data);
   if (!module_def) {
-    LOG(ERROR) << "Failed getting root from flatbuffer data";
+    LOG(ERROR) << "failed getting root from flatbuffer data; expected "
+                  "identifier " iree_vm_BytecodeModuleDef_file_identifier
+                  " not found";
     return IREE_STATUS_INVALID_ARGUMENT;
   }
-  IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_flatbuffer_verify(module_def));
 
+  iree_vm_TypeDef_vec_t type_defs = iree_vm_BytecodeModuleDef_types(module_def);
   size_t type_table_size =
-      module_def->types()->size() * sizeof(iree_vm_type_def_t);
+      iree_vm_TypeDef_vec_len(type_defs) * sizeof(iree_vm_type_def_t);
 
   iree_vm_bytecode_module_t* module = NULL;
   IREE_RETURN_IF_ERROR(iree_allocator_malloc(
@@ -626,21 +664,30 @@
       (void**)&module));
   module->allocator = allocator;
 
+  iree_vm_FunctionDescriptor_vec_t function_descriptors =
+      iree_vm_BytecodeModuleDef_function_descriptors(module_def);
   module->function_descriptor_count =
-      module_def->function_descriptors()->size();
-  module->function_descriptor_table =
-      (const iree_vm_function_descriptor_t*)module_def->function_descriptors()
-          ->data();
+      iree_vm_FunctionDescriptor_vec_len(function_descriptors);
+  module->function_descriptor_table = function_descriptors;
+
+  flatbuffers_uint8_vec_t bytecode_data =
+      iree_vm_BytecodeModuleDef_bytecode_data(module_def);
   module->bytecode_data = iree_const_byte_span_t{
-      module_def->bytecode_data()->Data(), module_def->bytecode_data()->size()};
+      bytecode_data, flatbuffers_uint8_vec_len(bytecode_data)};
 
   module->flatbuffer_data = flatbuffer_data;
   module->flatbuffer_allocator = flatbuffer_allocator;
+  module->def = module_def;
 
-  module->type_count = module_def->types()->size();
+  module->type_count = iree_vm_TypeDef_vec_len(type_defs);
   module->type_table = (iree_vm_type_def_t*)((uint8_t*)module +
                                              sizeof(iree_vm_bytecode_module_t));
-  iree_vm_bytecode_module_resolve_types(module_def, module->type_table);
+  iree_status_t resolve_status =
+      iree_vm_bytecode_module_resolve_types(type_defs, module->type_table);
+  if (!iree_status_is_ok(resolve_status)) {
+    iree_allocator_free(allocator, module);
+    return resolve_status;
+  }
 
   iree_vm_module_initialize(&module->interface, module);
   module->interface.destroy = iree_vm_bytecode_module_destroy;
diff --git a/iree/vm/bytecode_module_impl.h b/iree/vm/bytecode_module_impl.h
index da53486..0f4e818 100644
--- a/iree/vm/bytecode_module_impl.h
+++ b/iree/vm/bytecode_module_impl.h
@@ -16,6 +16,16 @@
 #define IREE_VM_BYTECODE_MODULE_IMPL_H_
 
 #include <stdint.h>
+#include <string.h>
+
+#ifdef IREE_PLATFORM_ANDROID
+#include <stdalign.h>
+#else
+// TODO(benvanik): figure out how to make MSVC happy with C11 stdalign.h.
+#ifdef __cplusplus
+#include <cstdalign>
+#endif  // __cplusplus
+#endif
 
 #include "iree/base/api.h"
 #include "iree/vm/builtin_types.h"
@@ -25,18 +35,15 @@
 #include "iree/vm/type_def.h"
 #include "iree/vm/value.h"
 
+// NOTE: include order matters:
+#include "flatcc/reflection/flatbuffers_common_reader.h"
+#include "iree/schemas/bytecode_module_def_reader.h"
+#include "iree/schemas/bytecode_module_def_verifier.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
-// Matches the FunctionDescriptor struct in the flatbuffer.
-typedef struct {
-  int32_t bytecode_offset;
-  int32_t bytecode_length;
-  uint16_t i32_register_count;
-  uint16_t ref_register_count;
-} iree_vm_function_descriptor_t;
-
 // A loaded bytecode module.
 typedef struct {
   // Interface routing to the bytecode module functions.
@@ -48,7 +55,8 @@
   // Mapped 1:1 with internal functions. Each defined bytecode span represents a
   // range of bytes in |bytecode_data|.
   int32_t function_descriptor_count;
-  const iree_vm_function_descriptor_t* function_descriptor_table;
+  const iree_vm_FunctionDescriptor_t* function_descriptor_table;
+
   // A pointer to the bytecode data embedded within the module.
   iree_const_byte_span_t bytecode_data;
 
@@ -58,6 +66,7 @@
   // Underlying FlatBuffer data and allocator (which may be null).
   iree_const_byte_span_t flatbuffer_data;
   iree_allocator_t flatbuffer_allocator;
+  iree_vm_BytecodeModuleDef_table_t def;
 
   // Type table mapping module type IDs to registered VM types.
   int32_t type_count;
diff --git a/packaging/python/README.md b/packaging/python/README.md
index 6fcae98..3d387a2 100644
--- a/packaging/python/README.md
+++ b/packaging/python/README.md
@@ -55,6 +55,11 @@
 functional) version without TensorFlow kernels. This should not be done for
 released binaries but can help while developing.
 
+Note that bazel does not always produce properly named artifacts. Use the tool
+`hack_python_package_from_runfiles.py` to extract and fix up artifacts from a
+bazel-bin directory. If using this mechanism, the environment variable
+`PYIREE_PYTHON_ROOT` should be set to a suitable temp directory.
+
 ```shell
 cd $IREE_SRC
 bazel build -c opt \
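The README text above mentions the extract-and-fixup flow without showing it; a minimal sketch follows, assuming a temp destination of `/tmp/pyiree_pkg` and the checkout's `bazel-bin` (both placeholders, not part of this change):

```shell
# Hypothetical sketch: extract the bazel-built python packages into a temp
# directory, then let common_setup.py pick it up via PYIREE_PYTHON_ROOT.
export PYIREE_PYTHON_ROOT=/tmp/pyiree_pkg
python packaging/python/hack_python_package_from_runfiles.py \
    "$PYIREE_PYTHON_ROOT" bazel-bin
```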
diff --git a/packaging/python/common_setup.py b/packaging/python/common_setup.py
index 149dbdf..89d2639 100644
--- a/packaging/python/common_setup.py
+++ b/packaging/python/common_setup.py
@@ -28,6 +28,11 @@
 
 
 def get_package_dir(prefix=("bindings", "python")):
+  explicit_root = os.environ.get("PYIREE_PYTHON_ROOT")
+  if explicit_root:
+    return explicit_root
+
+  # Use env variables based on build system type.
   cmake_build_root = os.environ.get("PYIREE_CMAKE_BUILD_ROOT")
   bazel_build_root = os.environ.get("PYIREE_BAZEL_BUILD_ROOT")
 
diff --git a/packaging/python/hack_python_package_from_runfiles.py b/packaging/python/hack_python_package_from_runfiles.py
new file mode 100644
index 0000000..61c2d6e
--- /dev/null
+++ b/packaging/python/hack_python_package_from_runfiles.py
@@ -0,0 +1,102 @@
+#!/usr/bin/python
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Given a runfiles directory from a bazel build, does surgery to extract
+# a usable python package directory. In addition to the bazel directory
+# structure being unnecessarily obtuse, it is also really hard to actually
+# name files correctly. This affects python extension modules which must be
+# named with a specific extension suffix. Bazel is extremely inflexible and
+# we patch around it with this script. For the record, there are various ways
+# to write custom rules to do this more natively, but they add needless
+# complexity. We opt for a script that is at least readable by
+# mere mortals and in one place.
+# Usage:
+#   ./this_script <dest_dir> <path to bazel-bin>
+
+import os
+import platform
+import shutil
+import sys
+import sysconfig
+
+FILE_NAME_MAP = {
+    "binding.so": "binding{}".format(sysconfig.get_config_var("EXT_SUFFIX")),
+    "binding.pyd": False,
+    "binding.dylib": False,
+}
+
+
+def get_exe_suffix():
+  if platform.system() == "Windows":
+    return ".exe"
+  else:
+    return ""
+
+
+def copy_prefix(dest_dir, runfiles_dir, prefix):
+  # And finally seek into the corresponding path in the runfiles dir.
+  # Aren't bazel paths fun???
+  # Note that the "iree_core" path segment corresponds to the workspace name.
+  pkg_dir = os.path.join(runfiles_dir, "iree_core", *prefix)
+  if not os.path.exists(pkg_dir):
+    return
+  dest_dir = os.path.join(dest_dir)
+  for root, dirs, files in os.walk(pkg_dir):
+    assert root.startswith(pkg_dir)
+    dest_prefix = root[len(pkg_dir):]
+    if dest_prefix.startswith(os.path.sep):
+      dest_prefix = dest_prefix[1:]
+    local_dest_dir = os.path.join(dest_dir, dest_prefix)
+    os.makedirs(local_dest_dir, exist_ok=True)
+    for file in files:
+      copy_file(os.path.join(root, file), local_dest_dir)
+
+
+def copy_file(src_file, dst_dir):
+  basename = os.path.basename(src_file)
+  dst_file = os.path.join(dst_dir, basename)
+  mapped_name = FILE_NAME_MAP.get(basename)
+  if mapped_name is False:
+    # Skip.
+    return
+  elif mapped_name is not None:
+    dst_file = os.path.join(dst_dir, mapped_name)
+  shutil.copyfile(src_file, dst_file, follow_symlinks=True)
+
+
+def main():
+  # Parse args.
+  dest_dir = sys.argv[1]
+  bazel_bin = sys.argv[2] if len(sys.argv) > 2 else os.path.join(
+      os.path.dirname(__file__), "..", "..", "bazel-bin")
+
+  # Find the path to the runfiles of the built target:
+  #   //bindings/python/packaging:all_pyiree_packages
+  runfiles_dir = os.path.join(
+      bazel_bin, "packaging", "python",
+      "all_pyiree_packages%s.runfiles" % (get_exe_suffix(),))
+  if not os.path.isdir(runfiles_dir):
+    print("ERROR: Could not find build target 'all_pyiree_packages':",
+          runfiles_dir)
+    print("Make sure to build target", "//packaging/python:all_pyiree_packages")
+    sys.exit(1)
+
+  copy_prefix(dest_dir, runfiles_dir, ("bindings", "python"))
+  copy_prefix(dest_dir, runfiles_dir,
+              ("integrations", "tensorflow", "bindings", "python"))
+
+
+if __name__ == "__main__":
+  main()
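As a usage sketch for the new script (the destination directory is a placeholder; the second argument is optional and defaults to the checkout's `bazel-bin`, per `main()` above):

```shell
# Build the target whose runfiles the script walks, then extract the packages.
bazel build //packaging/python:all_pyiree_packages
python packaging/python/hack_python_package_from_runfiles.py /tmp/pyiree_pkg
```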