Merge main -> google (#4048)

* abf54564 Merge google -> main (#4045)
* 1e4349ec Merge branch 'main' into google-to-main
* 1f86f0ef Add new tf.keras.applications models (#3958)
* 41004a91 [spirv] Add a better tiled and vectorized convolution path (#3990)
* c4129446 Bump Tracy to get Ben's dynamic zone colors and to get past (#4042)
* 2a25f5f1 Remove input shape hack from kws tests (#3959)
* 359ffa01 Allow fusing batch_matmul and fill (#4038)
* 35e837ed Merge pull request #4008 from MaheshRavishankar/matmul_fusion
* ffdd9d19 Enable e2e execution of fusion of matmul with producer.
* de116a06 Use tile+fuse on CPU side.
* caf91acb Use fusion of sequence of LinalgOps in IREE SPIR-V codegen path.
* 551f254a Add convert f32 to f16 pass (#4010)
* f1b1e581 Use vector transfer forwarding and dead store elimination in IREE (#4032)
* fac8fd1f Add file to track docker :prod digests (#3955)
* 37f0ae80 Tensorlist refactor to remove shape input and support empty reserve behavior (..
* 7585e532 Add relative_artifacts_dir (#3960)
* da538897 Use subprocess.run and types in manage_images.py (#3947)
* fbc63a4f Upgrade PR action to v3 (#3946)
* 83ea4a3c [llvm] Add a PlanConvLoopOrder pass (#3920)
* 1f1c70a0 Merge pull request #4012
* bab1468c Submodule fix
* bab0308d Merge branch 'main' into google-to-main
* 8df7e5cd Sort add_subdirectory and remove todo given it's fixed now (#4029)
* 1bb131cb Add missing mkdir to fix building and publishing documentation (#4030)
* 46949bad Fix VectorizeMemref out-of-bound access (#4031)
* b32de092 Update Linux docs to include an LLVM AOT example (#3165)
* 40da4547 Update docs to recommend enabling/disabling assertions globally. (#4014)
* fc8da035 Update Tracy to get clean shutdown on Android (#4015)
* 12e20efc support running as root on android. (#4013)
* 1a4f469e Merge branch 'main' into google-to-main
* aff3e9cb Add vectorization support for batch_matmul op (#3997)
* 33d4a6be Changing iree/base/math.h to C and adding some PRNG routines. (#4005)
* 73bdbe7b Fixup EmitC include directories (#3983)
* af7b4c39 Refactor finding python3 to use standard cmake find_package (#3778)
* f92602b7 Always emitting framepointers in generated ELFs (#3987)
* 89ef6328 bump Tracy to get Android fixes (#3988)
* 7f989eb2 Disable MLIR crash reproducer on CI in python tests. (#3943)
* 4db3d08c Adding a demonstration of using the flow dialect post-partitioning. (#3701)
* 589dfa7b Remove no-longer-functional flag (#3961)
* 64543172 Fix MacOS builds after hack-and-slash PR (#3962)
* 8fb887e4 Update links to coverage tables (#3956)
* c93facb4 Adding iree_atomic_slist_t type. (#3917)
* bd082ca6 Merge pull request #3874 from google/benvanik-hack-and-slash
* f4f4ea26 Use UnitTestSpec in tf.keras.layers tests (#3935)
* 7b8e9f75 Reverting flatcc to use our own cmake file for cross-compilation.
* 323e1fde Simplify dylib driver switch.
* 55f3de0d Only register VMLA driver in bindings/java/.
* f429916f Fix warning flag on Windows and HAL CTS driver registration. (#3911)
* 7951e228 Drop IREE_DRIVER_MODULES and iree/base:initializer from ModelBuilder.
* 4e111e2f Disable layering_check in iree/hal/drivers/BUILD.
* 4773736d Add package to iree/base/testing/BUILD.
* 7ca321c1 Skipping dylib driver in simple_embedding_test as a hack.
* 692deb59 Overriding the default MLIR -> LLVM module name.
* 513f40e8 Speculative removing nowindows tags (#3615). If there's something that still d..
* cc47813a Removing the broken forward declarations entirely from some codegen code. http..
* fbcad44d Removing _GNU_SOURCE copt for wait_handle.
* c9b10a01 Fixing bad type in hal/api.h (been there for ages!).
* e0c532ec Changing iree::InitializeEnvironment to iree_flags_parse. Preparation for #3814.
* 8886ac07 Removing iree_api_init from the API.
* 132d747c Removing ALWAYSLINK support from cmake.
* 0ed81f6b Removing iree/base/initializer.h.
* 0135343c Changing to an explicit driver registration mechanism. This is required for ha..
* bf091d3c Removing ALWAYSLINK support from external_cc_library.
* d4bb871d Changing iree-tblgen to not require alwayslink.
* 6bc6f90c Removing IREE_COMMON_INCLUDE_DIRS and uses by LLVM/MLIR.
* 3c082aab Removing IREE_COMMON_INCLUDE_DIRS mhlo pollution.
* 2395cb99 Removing emitc usage of IREE_COMMON_INCLUDE_DIRS for now.
* 036bd966 TODOs on future library layout changes.
* c3a13e62 Rearranging iree/vm/ to reduce a public + cc target.
* 493b0e2b Rearranging iree/base build rules. By moving the dynamic_library_test out of t..
* 8fd38bf9 Replacing uses of some absl utilities now in C++14.
* 67863190 Removing unused absl/types/variant.h include.
* 99bd1af5 Replace absl::exchange with iree::exchange to reduce absl/utility dep.
* c1d0ee10 Removing unused PLATFORM_VULKAN_DEPS. It may be needed internally but it shoul..
* 15437f4b Simplifying iree/hal/dylib/ build config.
* e6984a5a Simplifying iree/hal/ build config.
* 10062814 Simplifying iree/hal/vulkan/ build config.
* 827e51b0 Simplifying iree/hal/llvmjit/ build config.
* 9a72f5d1 Simplifying iree/hal/metal/ build config.
* c7a7d726 Simplifying iree/hal/vmla/ build config.
* 90faf21f Adding IREE_TARGET_GUI_LINKOPTS to remove custom linkopts use.
* e5774c30 Remove unused args from flatbuffers_c_library macro.
* 22d16b4d Adding iree/base/flatcc.h to make flatcc easier to include.
* e44dee56 Switching from -DVK_NO_PROTOTYPES to iree/hal/vulkan/vulkan_headers.h.
* 48ca2fe6 Removing build-config setting of _GNU_SOURCE.
* eeb7dde0 Goodbye flatbuffers (well, the C++ ones anyway).
* 9c676a86 Removing all build config/utils related to flatbuffers.
* 49c61213 byte->ubyte in flatbuffer defs.
* c99000a8 Replacing compiler use of VM bytecode module def flatbuffers->flatcc.
* 1bf1e8d7 Replacing runtime use of metal executable flatbuffers->flatcc. Maybe it works?..
* 48aafb89 Replacing runtime use of spirv executable flatbuffers->flatcc.
* 011e9a2d Replacing runtime use of llvmjit executable flatbuffers->flatcc.
* a021062f Replacing runtime use of dylib executable flatbuffers->flatcc.
* 53a05d73 Replacing runtime use of VMLA executable flatbuffers->flatc.
* 99d30a99 Replacing compiler use of HAL executable flatbuffers->flatc.
* 6ebd1b0c Removing unused tag field in metal/spirv.
* bc685ed7 Adding flatcc json support and making iree-dump-module use it.
* 94b11c35 Adding include for flatcc to flat_c_libraries.
* 1172cf1f Removing unused iree::schemas::reflection_data.
* c86281af Removing unneeded flatbuffers copts.
* 7f3a7e3a Fixing various type warnings. We don't today have these warnings enabled in ou..
* c17659fc Refining MSVC warning set to the minimum required and documenting.
* b7c92bf4 Cleaning up MSVC warnings and syncing with bazel warnings.
* 94356d3b Removing legacy repo_utils.bzl.
* 0f0d9c82 Prevent bazel-to-cmake from running on iree/base/CMakeLists.txt for now.
* 36225a4d Centralizing -ldl/-lpthread linkopts (as they were in bazel already).
* 31c4dbb9 Documenting iree_copts with a nice big warning.
* e4740a57 Pass android libraries as actual linkopts.
* 85cdd868 Fixing cmake style issues - prefer `if(` not `if (` please.
* bf4069e3 Sorting copts/linkopts so we can override things.
* 28040cd8 Simplifying VMA build integration.
* 479ef30f Replacing use of PROJECT_SOURCE_DIR/PROJECT_BINARY_DIR. Those use the previous..

PiperOrigin-RevId: 345252359
diff --git a/.github/workflows/update_llvm_dependent_submodules.yml b/.github/workflows/update_llvm_dependent_submodules.yml
index bd0319a..2d66b0a 100644
--- a/.github/workflows/update_llvm_dependent_submodules.yml
+++ b/.github/workflows/update_llvm_dependent_submodules.yml
@@ -42,7 +42,7 @@
           echo "TF_SHA=$(git submodule status third_party/tensorflow | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
           echo "LLVM_BAZEL_SHA=$(git submodule status third_party/llvm-bazel | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
       - name: Creating Pull Request
-        uses: peter-evans/create-pull-request@v2
+        uses: peter-evans/create-pull-request@v3
         with:
           # Personal token is required to trigger additional automation (e.g. presubmits).
           token: ${{ secrets.GITHUB_WRITE_ACCESS_TOKEN }}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3903d75..b37f0e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -374,18 +374,17 @@
 # Non-LLVM Dependencies
 #-------------------------------------------------------------------------------
 
-# Use the (deprecated) FindPythonInterp/FindPythonLibs functions before
-# any of our dependencies do. See
+# Use the FindPython functions before any of our dependencies do. See
 # https://pybind11.readthedocs.io/en/stable/faq.html#inconsistent-detection-of-python-version-in-cmake-and-pybind11
 # If one dependency finds Python 2 (the default),
 # any others that try to find Python 3 will fail.
 # (Also come on, it's $CURRENT_YEAR - please just use Python 3 already.)
 if(${IREE_BUILD_COMPILER} OR ${IREE_BUILD_PYTHON_BINDINGS})
-  find_package(PythonInterp 3 REQUIRED)
+  find_package(Python3 COMPONENTS Interpreter REQUIRED)
 endif()
 if(${IREE_BUILD_PYTHON_BINDINGS})
   # Note: Optional because python libs can be manually specified.
-  find_package(PythonLibs 3)
+  find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
 endif()
 
 include(external_cc_library)
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 331c7b0..517435a 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -3,7 +3,7 @@
 63b254577ed77a8004a9be6ac707f3dccc4e1fd9 third_party/cpuinfo
 4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
-f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
+b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
 ff0c2096ee35d3a14eaf425e6bbd3a36ab3a5a4a third_party/llvm-bazel
 c266c56d545dfecf767b312771f716b394c5d5eb third_party/llvm-project
 55801f03f9cc69abfcf8b508a873f702c11b3b5f third_party/mlir-emitc
@@ -14,6 +14,6 @@
 685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
 431fc9d043c52eb0990d7efdb19546fa31c1bf1d third_party/tensorflow
-d7059eca6351546d1f51e248fc75e49dfeee709e third_party/tracy
+9c3dac3ed2bd647b8d63f197fed058fee97a7e1e third_party/tracy
 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
 3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/build_tools/cmake/build_docs.sh b/build_tools/cmake/build_docs.sh
index 57b4aaa..24d3f22 100755
--- a/build_tools/cmake/build_docs.sh
+++ b/build_tools/cmake/build_docs.sh
@@ -49,6 +49,7 @@
 ninja iree-doc iree_tools_iree-opt
 
 cd ${ROOT_DIR}
+mkdir -p ${BUILD_DIR}/doc/
 # Copy docs in source tree over
 cp README.md ${BUILD_DIR}/doc/index.md
 cp -rf docs/* ${BUILD_DIR}/doc/
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index f773925..1ac1028 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -22,7 +22,7 @@
   # Configure the defaults.
   # Note that this is using the pybind11 configuration vars, which creates
   # a fragile dependency. It would be better to derive these locally.
-  if(PYTHONLIBS_FOUND)
+  if(Python3_FOUND)
     set(IREE_MULTIPY_DEFAULT_EXECUTABLE "${PYTHON_EXECUTABLE}" CACHE INTERNAL "Python executable" )
     set(IREE_MULTIPY_DEFAULT_INCLUDE_DIRS "${PYTHON_INCLUDE_DIRS}" CACHE INTERNAL "Python include dirs" )
     set(IREE_MULTIPY_DEFAULT_LIBRARIES "${PYTHON_LIBRARIES}" CACHE INTERNAL "Python libraries")
diff --git a/build_tools/docker/README.md b/build_tools/docker/README.md
index 3172c86..3fd42d6 100644
--- a/build_tools/docker/README.md
+++ b/build_tools/docker/README.md
@@ -82,7 +82,7 @@
 2. Build the image, push the image to GCR and update all references to the image
    with the new GCR digest:
 
-   ```shell
+    ```shell
     python3 build_tools/docker/manage_images.py \
       --image "${IMAGE?}" --build \
       --tag latest \
@@ -99,7 +99,7 @@
    digest references to test the new images.
 
 5. Merge your PR after it is approved and all CI tests pass. **Please remember to
-   complete the rest of the steps below**.
+   complete the step below**.
 
 ### Part 3. Updating the `:prod` tag
 
@@ -108,38 +108,13 @@
 in GCR. This also makes development significantly easier for others who need to
 modify the `docker` images.
 
-6. On the `main` branch, build (but don't push) the images and locally tag them
-   with the `:prod` tag:
-
-   ```shell
-    python3 build_tools/docker/manage_images.py \
-      --image "${IMAGE?}" --build \
-      --tag prod \
-      --update_references
-    ```
-
-    This build should be entirely cache hits.
-7. We include `--update_references` in the command above so that we can check
-   that none of the images or references to them have been changed. Check that
-   the following command produces no output before continuing:
-
-   ```shell
-   git status --porcelain
-   ```
-
-   If the output is not empty then you'll need to find the source of the
-   discrepancy (e.g. a locally modified `Dockerfile`) and remove it, and repeat
-   steps 5 and 6 before continuing. (This relies on you keeping your local copy
-   of the Docker images. If you didn't, you'll have to manually pull the missing
-   images by their digest).
-8. Now that we've confirmed that none of the images were changed, we can push
-   them to GCR with the `:prod` tag.
+6. We use `build_tools/docker/prod_digests.txt` as a source of truth for which
+   versions of the images on GCR should have the `:prod` tag. The following
+   command will ensure that you are at upstream HEAD on the `main` branch before
+   it updates the tags.
 
     ```shell
-    python3 build_tools/docker/manage_images.py \
-      --image "${IMAGE?}" \
-      --tag prod \
-      --push
+    python3 build_tools/docker/manage_prod.py
     ```
 
 ## Debugging
@@ -150,13 +125,13 @@
 GCR).
 
 ```shell
-# Pull all :prod images
-python3 build_tools/docker/manage_images.py --images all --pull --tag prod
+# Pull all images that should have :prod tags. (They won't if someone ignores
+# step 6 above, but the images that this command pulls are correct regardless).
+python3 build_tools/docker/manage_prod.py --pull_only
+
 # Update the :latest images to match the :prod images.
-# If you have a clean workspace this _shouldn't_ require building anything as
-# everything should be cache hits from the :prod images downloaded above, but if
-# the :prod images are behind then that will not be the case and this may take
-# several hours (depending on your machine).
+# If you have a clean workspace this shouldn't require building anything as
+# everything should be cache hits from the :prod images downloaded above.
 python3 build_tools/docker/manage_images.py \
   --images all --build \
   --tag latest \
diff --git a/build_tools/docker/__init__.py b/build_tools/docker/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/build_tools/docker/__init__.py
diff --git a/build_tools/docker/manage_images.py b/build_tools/docker/manage_images.py
index ad5c3a5..4093c68 100755
--- a/build_tools/docker/manage_images.py
+++ b/build_tools/docker/manage_images.py
@@ -29,9 +29,6 @@
 transitively depend on it, but don't take side-effecting actions:
   python3 build_tools/docker/manage_images.py --build --image cmake --dry-run
 
-Push all `prod` images to GCR:
-  python3 build_tools/docker/manage_images.py --push --tag prod --images all
-
 Rebuild and push all images and update references to them in the repository:
   python3 build_tools/docker/manage_images.py --push --images all
   --update-references
@@ -44,9 +41,13 @@
 import re
 import subprocess
 import sys
+from typing import List, Sequence, Union
+
+import utils
 
 IREE_GCR_URL = 'gcr.io/iree-oss/'
-DOCKER_DIR = 'build_tools/docker/'
+DIGEST_REGEX = r'sha256:[a-zA-Z0-9]+'
+DOCKER_DIR = 'build_tools/docker/'.replace('/', os.sep)
 
 # Map from image names to images that they depend on.
 IMAGES_TO_DEPENDENCIES = {
@@ -88,13 +89,6 @@
                       required=True,
                       action='append',
                       help=f'Name of the image to build: {IMAGES_HELP}.')
-  parser.add_argument(
-      '--tag',
-      type=str,
-      default='latest',
-      help='Tag for the images to build. Defaults to `latest` (which is good '
-      'for testing changes in a PR). Use `prod` to update the images that the '
-      'CI caches.')
   parser.add_argument('--pull',
                       action='store_true',
                       help='Pull the specified image before building.')
@@ -130,144 +124,148 @@
   return args
 
 
-def get_ordered_images_to_process(images):
-  unmarked_images = list(images)
-  # Python doesn't have a builtin OrderedSet
-  marked_images = set()
-  order = []
+def get_ordered_images_to_process(images: Sequence[str]) -> List[str]:
+  # Python doesn't have a builtin OrderedSet, so we mimic one to the extent
+  # that we need by using 'in' before adding any elements.
+  processing_order = []
 
-  def visit(image):
-    if image in marked_images:
-      return
-    for dependent_images in IMAGES_TO_DEPENDENT_IMAGES[image]:
-      visit(dependent_images)
-    marked_images.add(image)
-    order.append(image)
+  def add_dependent_images(image: str):
+    if image not in processing_order:
+      for dependent_image in IMAGES_TO_DEPENDENT_IMAGES[image]:
+        add_dependent_images(dependent_image)
+      processing_order.append(image)
 
-  while unmarked_images:
-    visit(unmarked_images.pop())
+  for image in images:
+    add_dependent_images(image)
 
-  order.reverse()
-  return order
+  processing_order.reverse()
+  return processing_order
 
 
-def stream_command(command, dry_run=False):
+def run_command(command: Sequence[str],
+                dry_run: bool = False,
+                check: bool = True,
+                capture_output: bool = False,
+                universal_newlines: bool = True,
+                **run_kwargs) -> subprocess.CompletedProcess:
+  """Thin wrapper around subprocess.run"""
   print(f'Running: `{" ".join(command)}`')
   if dry_run:
-    return 0
-  process = subprocess.Popen(command,
-                             bufsize=1,
-                             stderr=subprocess.STDOUT,
-                             stdout=subprocess.PIPE,
-                             universal_newlines=True)
-  for line in process.stdout:
-    print(line, end='')
+    # Dummy CompletedProcess with a successful returncode.
+    return subprocess.CompletedProcess(command, returncode=0)
 
-  if process.poll() is None:
-    raise RuntimeError('Unexpected end of output while process is not finished')
-  return process.poll()
+  if capture_output:
+    # capture_output requires Python 3.7; set the pipes directly for <= 3.6.
+    run_kwargs['stdout'] = subprocess.PIPE
+    run_kwargs['stderr'] = subprocess.PIPE
+  return subprocess.run(command,
+                        universal_newlines=universal_newlines,
+                        check=check,
+                        **run_kwargs)
 
 
-def check_stream_command(command, dry_run=False):
-  exit_code = stream_command(command, dry_run=dry_run)
-  if exit_code != 0:
-    print(f'Command failed with exit code {exit_code}: `{" ".join(command)}`')
-    sys.exit(exit_code)
-
-
-def get_repo_digest(image):
+def get_repo_digest(tagged_image_url: str) -> str:
   inspect_command = [
       'docker',
       'image',
       'inspect',
-      f'{image}',
+      tagged_image_url,
       '-f',
       '{{index .RepoDigests 0}}',
   ]
-  inspect_process = subprocess.run(inspect_command,
-                                   universal_newlines=True,
-                                   stdout=subprocess.PIPE,
-                                   stderr=subprocess.PIPE,
-                                   timeout=10)
-  if inspect_process.returncode != 0:
-    print(f'Computing the repository digest for {image} failed.'
-          ' Has it been pushed to GCR?')
-    print(f'Output from `{" ".join(inspect_command)}`:')
-    print(inspect_process.stdout, end='')
-    print(inspect_process.stderr, end='')
-    sys.exit(inspect_process.returncode)
-  _, repo_digest = inspect_process.stdout.strip().split('@')
+  try:
+    completed_process = utils.run_command(
+        inspect_command,
+        dry_run=False,  # Run even if --dry_run is True.
+        capture_output=True,
+        timeout=10)
+  except subprocess.CalledProcessError as error:
+    raise RuntimeError(f'Computing the repository digest for {tagged_image_url}'
+                       ' failed. Has it been pushed to GCR?') from error
+  _, repo_digest = completed_process.stdout.strip().split('@')
   return repo_digest
 
 
-def update_rbe_reference(digest, dry_run=False):
+def update_rbe_reference(digest: str, dry_run: bool = False):
   print('Updating WORKSPACE file for rbe-toolchain')
+  digest_updates = 0
   for line in fileinput.input(files=['WORKSPACE'], inplace=(not dry_run)):
     if line.strip().startswith('digest ='):
-      print(re.sub('sha256:[a-zA-Z0-9]+', digest, line), end='')
+      digest_updates += 1
+      print(re.sub(DIGEST_REGEX, digest, line), end='')
     else:
       print(line, end='')
 
+  if digest_updates > 1:
+    raise RuntimeError(
+        "There is more than one instance of 'digest =' in the WORKSPACE file. "
+        "This means that more than just the 'rbe_toolchain' digest was "
+        "overwritten, and the file should be restored.")
 
-def update_references(image_name, digest, dry_run=False):
-  print(f'Updating references to {image_name}')
 
-  grep_command = ['git', 'grep', '-l', f'{image_name}@sha256']
-  grep_process = subprocess.run(grep_command,
-                                stdout=subprocess.PIPE,
-                                stderr=subprocess.PIPE,
-                                timeout=5,
-                                universal_newlines=True)
-  if grep_process.returncode > 1:
-    print(f'{" ".join(grep_command)} '
-          f'failed with exit code {grep_process.returncode}')
-    sys.exit(grep_process.returncode)
-  if grep_process.returncode == 1:
-    print(f'Found no references to {image_name}')
-    return
+def update_references(image_url: str, digest: str, dry_run: bool = False):
+  """Updates all references to 'image_url' with a sha256 digest."""
+  print(f'Updating references to {image_url}')
 
-  files = grep_process.stdout.split()
+  grep_command = ['git', 'grep', '-l', f'{image_url}@sha256']
+  try:
+    completed_process = run_command(grep_command,
+                                    capture_output=True,
+                                    timeout=5)
+  except subprocess.CalledProcessError as error:
+    if error.returncode == 1:
+      print(f'Found no references to {image_url}')
+      return
+    raise error
+
+  # Update references in all grepped files.
+  files = completed_process.stdout.split()
   print(f'Updating references in {len(files)} files: {files}')
   for line in fileinput.input(files=files, inplace=(not dry_run)):
-    print(re.sub(f'{image_name}@sha256:[a-zA-Z0-9]+', f'{image_name}@{digest}',
-                 line),
+    print(re.sub(f'{image_url}@{DIGEST_REGEX}', f'{image_url}@{digest}', line),
           end='')
 
 
 if __name__ == '__main__':
   args = parse_arguments()
 
-  # Ensure the user has the correct authorization if they try to push to GCR.
   if args.push:
-    if stream_command(['which', 'gcloud']) != 0:
-      print('gcloud not found.'
-            ' See https://cloud.google.com/sdk/install for installation.')
-      sys.exit(1)
-    check_stream_command(['gcloud', 'auth', 'configure-docker'],
-                         dry_run=args.dry_run)
+    # Ensure the user has the correct authorization if they try to push to GCR.
+    utils.check_gcloud_auth(dry_run=args.dry_run)
 
   images_to_process = get_ordered_images_to_process(args.images)
   print(f'Also processing dependent images. Will process: {images_to_process}')
 
   for image in images_to_process:
     print(f'Processing image {image}')
-    image_name = posixpath.join(IREE_GCR_URL, image)
-    image_tag = f'{image_name}:{args.tag}'
+    image_url = posixpath.join(IREE_GCR_URL, image)
+    tagged_image_url = f'{image_url}:latest'
     image_path = os.path.join(DOCKER_DIR, image)
 
     if args.pull:
-      check_stream_command(['docker', 'pull', image_tag], dry_run=args.dry_run)
+      utils.run_command(['docker', 'pull', tagged_image_url], args.dry_run)
 
     if args.build:
-      check_stream_command(['docker', 'build', '--tag', image_tag, image_path],
-                           dry_run=args.dry_run)
+      utils.run_command(
+          ['docker', 'build', '--tag', tagged_image_url, image_path],
+          args.dry_run)
 
     if args.push:
-      check_stream_command(['docker', 'push', image_tag], dry_run=args.dry_run)
+      utils.run_command(['docker', 'push', tagged_image_url], args.dry_run)
 
     if args.update_references:
-      digest = get_repo_digest(image_tag)
+      digest = get_repo_digest(tagged_image_url)
+
+      # Check that the image is in 'prod_digests.txt' and append it to the list
+      # in the file if it isn't. We know that the GCR digest exists at this
+      # point because 'get_repo_digest' confirms that the image has been pushed.
+      with open(utils.PROD_DIGESTS_PATH, 'r') as f:
+        in_prod_digests = f'{image_url}@' in f.read()
+      if not in_prod_digests:
+        with open(utils.PROD_DIGESTS_PATH, 'a') as f:
+          f.write(f'{image_url}@{digest}\n')
+
       # Just hardcode this oddity
       if image == 'rbe-toolchain':
         update_rbe_reference(digest, dry_run=args.dry_run)
-      update_references(image_name, digest, dry_run=args.dry_run)
+      update_references(image_url, digest, dry_run=args.dry_run)
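
For reference, a minimal standalone sketch of the dependency-ordering logic added to `manage_images.py` above, assuming a toy dependency map; `TOY_IMAGES_TO_DEPENDENT_IMAGES` and `ordered_images` are illustrative names only and are not part of the change.

```python
# Standalone sketch (illustrative): images that depend on a requested image are
# visited first and the list is reversed, so each image is processed only after
# everything it builds on.
from typing import Dict, List, Sequence

TOY_IMAGES_TO_DEPENDENT_IMAGES: Dict[str, List[str]] = {
    'base': ['cmake'],
    'cmake': ['cmake-python'],
    'cmake-python': [],
}


def ordered_images(images: Sequence[str]) -> List[str]:
  processing_order: List[str] = []

  def add_dependent_images(image: str):
    if image not in processing_order:
      for dependent_image in TOY_IMAGES_TO_DEPENDENT_IMAGES[image]:
        add_dependent_images(dependent_image)
      processing_order.append(image)

  for image in images:
    add_dependent_images(image)
  processing_order.reverse()
  return processing_order


print(ordered_images(['base']))  # ['base', 'cmake', 'cmake-python']
```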
diff --git a/build_tools/docker/manage_prod.py b/build_tools/docker/manage_prod.py
new file mode 100644
index 0000000..3f26b00
--- /dev/null
+++ b/build_tools/docker/manage_prod.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Uses prod_digests.txt to update GCR's :prod tags.
+
+Usage:
+  Pull all images that should have :prod tags, tag them with :prod and push
+  them to GCR. This will make sure that you are at upstream HEAD on the main
+  branch before pushing:
+    python3 build_tools/docker/manage_prod.py
+
+  Pull all images that should have :prod tags:
+    python3 build_tools/docker/manage_prod.py --pull_only
+"""
+
+import argparse
+import os
+import utils
+
+
+def parse_arguments():
+  """Parses command-line options."""
+  parser = argparse.ArgumentParser(
+      description="Pull and push the images in prod_digests.txt to GCR.")
+  parser.add_argument("--pull_only",
+                      "--pull-only",
+                      action="store_true",
+                      help="Pull but do not tag or push the images.")
+  return parser.parse_args()
+
+
+if __name__ == "__main__":
+  args = parse_arguments()
+
+  if not args.pull_only:
+    # Ensure the user has the correct authorization if they try to push to GCR.
+    utils.check_gcloud_auth()
+
+    # Only allow the :prod tag to be pushed from the version of
+    # `prod_digests.txt` at upstream HEAD on the main branch.
+    utils.run_command([os.path.normpath("scripts/git/git_update.sh"), "main"])
+
+  with open(utils.PROD_DIGESTS_PATH, "r") as f:
+    images_with_digests = [line.strip() for line in f.readlines()]
+
+  for image_with_digest in images_with_digests:
+    image_url, _ = image_with_digest.split("@")
+    tagged_image_url = f"{image_url}:prod"
+
+    utils.run_command(["docker", "pull", image_with_digest])
+    if not args.pull_only:
+      utils.run_command(["docker", "tag", image_with_digest, tagged_image_url])
+      utils.run_command(["docker", "push", tagged_image_url])
diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt
new file mode 100644
index 0000000..4ee200a
--- /dev/null
+++ b/build_tools/docker/prod_digests.txt
@@ -0,0 +1,17 @@
+gcr.io/iree-oss/base@sha256:392b2f865f000c6fb558d01a372446f3ab81120db34249f03efa999669647230
+gcr.io/iree-oss/util@sha256:ec9198493cea4f5d9ac7097e8a64b94b7a43628cb995b91e6e89a95cff4a1982
+gcr.io/iree-oss/cmake@sha256:ceaff365ca0cd3d770daf5fad370e29783e30b654f56780761a6d0a040da45e5
+gcr.io/iree-oss/swiftshader@sha256:3ed32e7c74da71b6db1904b583827e760ea845d5d2876b38c036cf72ca6e5623
+gcr.io/iree-oss/cmake-python@sha256:5fa42743c458a7df680175542269067fd89d2072b776b43e48169a7d0b43ebc3
+gcr.io/iree-oss/cmake-android@sha256:7accda0b84e2ae337740f2ee71801ee30f2155900abf1cf7b73ea47c15dc694f
+gcr.io/iree-oss/bazel@sha256:59da17e5cc8176890a6e1bda369b1f3d398e27af3d47e02e1ffd5b76729c215b
+gcr.io/iree-oss/bazel-python@sha256:473b7e294136bc38abc1941042f0c0404199de5827f141520f0b6757305b7a95
+gcr.io/iree-oss/bazel-tensorflow@sha256:6ec501edcbaaf817941c5be5060cafc47616ca4e2a875bbb62944ffbc396ceb0
+gcr.io/iree-oss/vulkan@sha256:c2e21657a231f3e39c50c01c3cbae3355f5b03ff52033b41ad322a0c792099dd
+gcr.io/iree-oss/rbe-toolchain@sha256:d6d895294076b5289e81489f664656211c41656cffe7c448ecb5c6f54f045974
+gcr.io/iree-oss/cmake-python-vulkan@sha256:63db8f65485e73af8a16603729bf39b4e616b6cb90216a1589ba2919489a6483
+gcr.io/iree-oss/cmake-python-swiftshader@sha256:3e3d3427f3a58b32fa3ed578b610e411e0b81fd0e1984ac9b0fceae8bf8343dc
+gcr.io/iree-oss/cmake-python-nvidia@sha256:310e3b399717905bb2b485f3ebed32222915c7dc4dc075aa4e1b8551101fe607
+gcr.io/iree-oss/bazel-tensorflow-vulkan@sha256:61522fcfcd11cd9c067e991b72419d6decf70dae35b8ee3efa71e55ca31b8866
+gcr.io/iree-oss/bazel-tensorflow-swiftshader@sha256:39c0e43c503bddfacd69758a50f02450ad2322d35324e2f56997aebb33a1b20a
+gcr.io/iree-oss/bazel-tensorflow-nvidia@sha256:e5e96ec1709e83355ee2264c97c26fa5c3d40f749a62734f4787b17a83f2c3b8
diff --git a/build_tools/docker/utils.py b/build_tools/docker/utils.py
new file mode 100644
index 0000000..f9a041b
--- /dev/null
+++ b/build_tools/docker/utils.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+from typing import Sequence
+
+PROD_DIGESTS_PATH = "build_tools/docker/prod_digests.txt".replace("/", os.sep)
+
+
+def run_command(command: Sequence[str],
+                dry_run: bool = False,
+                check: bool = True,
+                capture_output: bool = False,
+                universal_newlines: bool = True,
+                **run_kwargs) -> subprocess.CompletedProcess:
+  """Thin wrapper around subprocess.run"""
+  print(f"Running: `{' '.join(command)}`")
+  if not dry_run:
+    if capture_output:
+      # capture_output requires Python 3.7; set the pipes directly for <= 3.6.
+      run_kwargs["stdout"] = subprocess.PIPE
+      run_kwargs["stderr"] = subprocess.PIPE
+
+    completed_process = subprocess.run(command,
+                                       universal_newlines=universal_newlines,
+                                       check=check,
+                                       **run_kwargs)
+    return completed_process
+  # Dummy CompletedProcess with a successful returncode.
+  return subprocess.CompletedProcess(command, returncode=0)
+
+
+def check_gcloud_auth(dry_run: bool = False):
+  # Ensure the user has the correct authorization if they try to push to GCR.
+  try:
+    run_command(['which', 'gcloud'])
+  except subprocess.CalledProcessError as error:
+    raise RuntimeError(
+        'gcloud not found. See https://cloud.google.com/sdk/install for '
+        'installation.') from error
+  run_command(["gcloud", "auth", "configure-docker"], dry_run)
diff --git a/docs/get_started/getting_started_linux_bazel.md b/docs/get_started/getting_started_linux_bazel.md
index a05ebb6..77f7c6c 100644
--- a/docs/get_started/getting_started_linux_bazel.md
+++ b/docs/get_started/getting_started_linux_bazel.md
@@ -90,8 +90,9 @@
 # and with assertions enabled.
 build:debug --config=asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-O0' --strip=never
 
-# Use --config=asserts to enable assertions in IREE and LLVM.
-build:asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-UNDEBUG'
+# Use --config=asserts to enable assertions. This has to be done globally:
+# Code compiled with and without assertions can't be linked together (ODR violation).
+build:asserts --compilation_mode=opt '--copt=-UNDEBUG'
 ```
 
 ## What's next?
diff --git a/docs/get_started/getting_started_linux_cmake.md b/docs/get_started/getting_started_linux_cmake.md
index 04f499f..83d0fb2 100644
--- a/docs/get_started/getting_started_linux_cmake.md
+++ b/docs/get_started/getting_started_linux_cmake.md
@@ -107,6 +107,39 @@
   -function-input="i32=-2" -iree-hal-target-backends=vmla -print-mlir
 ```
 
+### LLVM Ahead-of-Time (AOT) backend
+
+To compile an IREE module with the LLVM AOT backend (instead of JIT), we need to set the AOT linker path environment variable:
+
+```shell
+$ export IREE_LLVMAOT_LINKER_PATH=ld.lld-10
+```
+
+Translate a source MLIR into an IREE module:
+
+```shell
+# Assuming you are in the IREE source root
+$ ./build/iree/tools/iree-translate \
+    -iree-mlir-to-vm-bytecode-module \
+    -iree-llvm-target-triple=x86_64-linux-gnu \
+    -iree-hal-target-backends=dylib-llvm-aot \
+    iree/tools/test/simple.mlir \
+    -o /tmp/simple-llvm_aot.vmfb
+```
+
+Then run the compiled module using the `dylib` HAL driver:
+
+```shell
+$ ./build/iree/tools/iree-run-module -driver=dylib \
+          -input_file=/tmp/simple-llvm_aot.vmfb \
+          -entry_function=abs \
+          -inputs="i32=-5"
+
+EXEC @abs
+i32=5
+```
+
+
 ### Further Reading
 
 *   For an introduction to IREE's project structure and developer tools, see
diff --git a/docs/get_started/getting_started_macos_bazel.md b/docs/get_started/getting_started_macos_bazel.md
index 71b6818..0087166 100644
--- a/docs/get_started/getting_started_macos_bazel.md
+++ b/docs/get_started/getting_started_macos_bazel.md
@@ -93,8 +93,9 @@
 # and with assertions enabled.
 build:debug --config=asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-O0' --strip=never
 
-# Use --config=asserts to enable assertions in IREE and LLVM.
-build:asserts --compilation_mode=opt '--per_file_copt=iree|llvm@-UNDEBUG'
+# Use --config=asserts to enable assertions. This has to be done globally:
+# Code compiled with and without assertions can't be linked together (ODR violation).
+build:asserts --compilation_mode=opt '--copt=-UNDEBUG'
 ```
 
 ## What's next?
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 313be88..1961a9a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -55,7 +55,7 @@
 DEFAULT_INPUT_GENERATOR = tf_utils.uniform
 
 
-def _setup_artifacts_dir(module_name: str) -> str:
+def _setup_artifacts_dir(relative_artifacts_dir: str) -> str:
   parent_dirs = [
       FLAGS.artifacts_dir,
       os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
@@ -65,7 +65,7 @@
   # Use the most preferred path in parent_dirs that isn't None.
   parent_dir = next(parent for parent in parent_dirs if parent is not None)
 
-  artifacts_dir = os.path.join(parent_dir, module_name)
+  artifacts_dir = os.path.join(parent_dir, relative_artifacts_dir)
   logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir)
   os.makedirs(artifacts_dir, exist_ok=True)
   return artifacts_dir
@@ -125,9 +125,9 @@
 _global_modules = None
 
 
-def compile_tf_module(
-    module_class: Type[tf.Module],
-    exported_names: Sequence[str] = ()) -> Modules:
+def compile_tf_module(module_class: Type[tf.Module],
+                      exported_names: Sequence[str] = (),
+                      relative_artifacts_dir: str = None) -> Modules:
   """Compiles module_class to each backend that we test.
 
   Args:
@@ -135,6 +135,9 @@
     exported_names: optional iterable of strings representing which of
       module_class's functions to compile. If exported_names is empty all
       functions will be compiled.
+    relative_artifacts_dir: optional string specifying where to save compilation
+      artifacts within the artifacts_dir. If it is not specified then
+      module_class.__name__ will be used.
 
   Returns:
     A 'Modules' namedtuple containing the reference module, target modules and
@@ -145,7 +148,9 @@
     return _global_modules
 
   # Setup the directory for saving compilation artifacts and traces.
-  artifacts_dir = _setup_artifacts_dir(module_class.__name__)
+  if relative_artifacts_dir is None:
+    relative_artifacts_dir = module_class.__name__
+  artifacts_dir = _setup_artifacts_dir(relative_artifacts_dir)
 
   # Get the backend information for this test.
   ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
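
A hedged usage sketch of the new `relative_artifacts_dir` argument added above; `SimpleModule`, its `abs` function, and the artifact subdirectory are hypothetical, and the import path assumes the usual `pyiree.tf.support` packaging.

```python
# Hypothetical test snippet showing the new keyword argument.
import os
import tensorflow as tf
from pyiree.tf.support import tf_test_utils


class SimpleModule(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
  def abs(self, x):
    return tf.abs(x)


# Artifacts land under <artifacts_dir>/SimpleModule/variant_a instead of the
# default <artifacts_dir>/SimpleModule (module_class.__name__).
modules = tf_test_utils.compile_tf_module(
    SimpleModule,
    exported_names=('abs',),
    relative_artifacts_dir=os.path.join('SimpleModule', 'variant_a'))
```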
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD
index 8b74864..1a957b7 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/BUILD
@@ -36,6 +36,7 @@
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
+        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
     ],
     alwayslink = 1,
 )
@@ -70,6 +71,8 @@
     deps = [
         "//integrations/tensorflow/compiler/dialect/tf_tensorlist/ir",
         "//iree/compiler/Dialect/HAL/Conversion",
+        "//iree/compiler/Dialect/HAL/IR",
+        "//iree/compiler/Dialect/HAL/Utils",
         "//iree/compiler/Dialect/Modules/TensorList/IR",
         "//iree/compiler/Dialect/Modules/TensorList/IR:TensorListDialect",
         "@llvm-project//mlir:IR",
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc
index 501fbe2..fa38d04 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_flow_to_hal.cc
@@ -16,11 +16,123 @@
 
 #include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.h"
 #include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
 #include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.h"
+#include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
 namespace iree_compiler {
 
+namespace {
+
+Value getBufferView(Operation *srcOp, Value srcOperand, Value dstOperand,
+                    ConversionPatternRewriter &rewriter) {
+  auto operand = IREE::HAL::TensorRewriteAdaptor::getChecked(
+      srcOp->getLoc(), srcOperand, dstOperand, rewriter);
+  if (!operand.hasValue()) {
+    srcOp->emitOpError() << "unable to create adaptor for operand";
+    return nullptr;
+  }
+  auto bufferView = operand->getBufferView();
+  if (!bufferView) {
+    srcOp->emitOpError() << "unable to get buffer view for operand";
+    return nullptr;
+  }
+
+  return bufferView;
+}
+
+class ReserveOpConversion : public OpConversionPattern<tf_tensorlist::Reserve> {
+ public:
+  ReserveOpConversion(MLIRContext *ctx, TypeConverter &converter)
+      : OpConversionPattern(ctx) {}
+
+  LogicalResult matchAndRewrite(
+      tf_tensorlist::Reserve reserveOp, llvm::ArrayRef<Value> newOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto elementTy = reserveOp.element_type();
+    auto element_value = IREE::HAL::getElementTypeValue(elementTy).getValue();
+
+    auto operand0 = getBufferView(reserveOp, reserveOp.getOperand(0),
+                                  newOperands[0], rewriter);
+    auto operand1 = getBufferView(reserveOp, reserveOp.getOperand(1),
+                                  newOperands[1], rewriter);
+
+    if (!operand0 || !operand1) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<IREE::TensorList::Reserve>(
+        reserveOp,
+        IREE::TensorList::TensorListType::get(reserveOp.getContext()), operand0,
+        operand1, rewriter.getI32IntegerAttr(element_value));
+    return success();
+  }
+};
+
+class ConcatOpConversion : public OpConversionPattern<tf_tensorlist::Concat> {
+ public:
+  ConcatOpConversion(MLIRContext *ctx, TypeConverter &converter)
+      : OpConversionPattern(ctx) {}
+
+  LogicalResult matchAndRewrite(
+      tf_tensorlist::Concat concatOp, llvm::ArrayRef<Value> newOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto device =
+        rewriter.createOrFold<IREE::HAL::ExSharedDeviceOp>(concatOp.getLoc());
+    auto allocator =
+        rewriter.create<IREE::HAL::DeviceAllocatorOp>(concatOp.getLoc(), device)
+            .getResult();
+
+    auto newConcatOp = rewriter.createOrFold<IREE::TensorList::Concat>(
+        concatOp.getLoc(),
+        IREE::HAL::BufferViewType::get(rewriter.getContext()), allocator,
+        newOperands[0]);
+
+    auto bufferOp = rewriter.createOrFold<IREE::HAL::BufferViewBufferOp>(
+        newConcatOp.getLoc(), newConcatOp);
+
+    rewriter.replaceOp(concatOp, bufferOp);
+    return success();
+  }
+};
+
+class StackOpConversion : public OpConversionPattern<tf_tensorlist::Stack> {
+ public:
+  StackOpConversion(MLIRContext *ctx, TypeConverter &converter)
+      : OpConversionPattern(ctx) {}
+
+  LogicalResult matchAndRewrite(
+      tf_tensorlist::Stack stackOp, llvm::ArrayRef<Value> newOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto device =
+        rewriter.createOrFold<IREE::HAL::ExSharedDeviceOp>(stackOp.getLoc());
+    auto allocator =
+        rewriter.create<IREE::HAL::DeviceAllocatorOp>(stackOp.getLoc(), device)
+            .getResult();
+
+    auto operand1 =
+        getBufferView(stackOp, stackOp.getOperand(1), newOperands[1], rewriter);
+    if (!operand1) return failure();
+
+    auto newStackOp = rewriter.createOrFold<IREE::TensorList::Stack>(
+        stackOp.getLoc(), IREE::HAL::BufferViewType::get(rewriter.getContext()),
+        allocator, newOperands[0], operand1);
+
+    auto bufferOp = rewriter.createOrFold<IREE::HAL::BufferViewBufferOp>(
+        stackOp.getLoc(), newStackOp);
+
+    rewriter.replaceOp(stackOp, bufferOp);
+    return success();
+  }
+};
+
+}  // namespace
+
 void populateTensorListToHALPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns,
                                      TypeConverter &typeConverter) {
@@ -29,9 +141,6 @@
   // verification or have a specific use case (such as a place where only the
   // buffer is required and the shape is not) we could add our own.
   patterns.insert<
-      HALOpConversion<tf_tensorlist::Reserve, IREE::TensorList::Reserve>>(
-      context, typeConverter);
-  patterns.insert<
       HALOpConversion<tf_tensorlist::GetItem, IREE::TensorList::GetItem>>(
       context, typeConverter);
   patterns.insert<
@@ -40,12 +149,10 @@
   patterns.insert<
       HALOpConversion<tf_tensorlist::FromTensor, IREE::TensorList::FromTensor>>(
       context, typeConverter);
-  patterns
-      .insert<HALOpConversion<tf_tensorlist::Concat, IREE::TensorList::Concat>>(
-          context, typeConverter);
-  patterns
-      .insert<HALOpConversion<tf_tensorlist::Stack, IREE::TensorList::Stack>>(
-          context, typeConverter);
+
+  patterns.insert<ConcatOpConversion>(context, typeConverter);
+  patterns.insert<ReserveOpConversion>(context, typeConverter);
+  patterns.insert<StackOpConversion>(context, typeConverter);
 }
 
 }  // namespace iree_compiler
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
index 71e42f6..2570e2c 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
@@ -19,10 +19,26 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
 namespace tf_tensorlist {
 
+namespace {
+TypeAttr GetVariantElementTypeAttr(Type type) {
+  if (auto variantTy = type.dyn_cast<TF::VariantType>()) {
+    return GetVariantElementTypeAttr(variantTy.getSubtypes().front());
+  }
+
+  if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+    return GetVariantElementTypeAttr(shapedTy.getElementType());
+  }
+
+  return TypeAttr::get(type);
+}
+
+}  // namespace
+
 #include "integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.inc"
 
 class ConvertTfToTfTensorList
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td
index 0887605..3968c00 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.td
@@ -17,26 +17,27 @@
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 include "integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td"
 
-def : Pat<(TF_TensorListReserveOp $element_shape, $num_elements),
-          (TfTensorList_Reserve $element_shape, $num_elements)>;
+// GetElementTypeAttr
+def GetVariantElementTypeAttr : NativeCodeCall<
+  "GetVariantElementTypeAttr($0.getType())">;
+
+def : Pat<(TF_TensorListReserveOp:$result $element_shape, $num_elements),
+          (TfTensorList_Reserve $element_shape, $num_elements,
+           (GetVariantElementTypeAttr $result))>;
 
 def : Pat<(TF_TensorListGetItemOp $input_handle, $index, $element_shape),
-          (TfTensorList_GetItem
-            $input_handle,
-            $index,
-            $element_shape)>;
+          (TfTensorList_GetItem $input_handle, $index)>;
 
 def : Pat<(TF_TensorListSetItemOp $input_handle, $index, $item),
           (TfTensorList_SetItem $input_handle, $index, $item)>;
 
 def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
-          (TfTensorList_FromTensor $tensor, $element_shape)>;
+          (TfTensorList_FromTensor $tensor)>;
 
 def WrapScalarI64InTensor : NativeCodeCall<
     "DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getIntegerType(64)), {$_self.getValue()})">;
 def : Pat<(TF_TensorListStackOp $input_handle, $element_shape, $num_elements),
           (TfTensorList_Stack
             $input_handle,
-            $element_shape,
             (ConstantOp WrapScalarI64InTensor:$num_elements))>;
 
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
index a1fc6ad..4cbf904 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_flow_to_hal.mlir
@@ -5,7 +5,7 @@
 // CHECK:         [[VIEW0:%.+]] = hal.buffer_view.create %arg0{{.*}}
 // CHECK:         [[VIEW1:%.+]] = hal.buffer_view.create %arg1{{.*}}
 // CHECK:         "tensorlist.Reserve"([[VIEW0]], [[VIEW1]])
-  %0 = "tf_tensorlist.Reserve"(%arg0, %arg1) : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
+  %0 = "tf_tensorlist.Reserve"(%arg0, %arg1) {element_type = f32} : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
   return %0 : !tf_tensorlist.list
 }
 
@@ -18,20 +18,26 @@
   return %0 : !tf_tensorlist.list
 }
 
-// CHECK-LABEL: func @GetItem(%arg0: !tensorlist.list, %arg1: !hal.buffer, %arg2: !hal.buffer) -> !hal.buffer {
-func @GetItem(%arg0: !tf_tensorlist.list, %arg1: tensor<i32>, %arg2: tensor<0xi32>) -> tensor<f32> {
+// CHECK-LABEL: func @GetItem(%arg0: !tensorlist.list, %arg1: !hal.buffer) -> !hal.buffer {
+func @GetItem(%arg0: !tf_tensorlist.list, %arg1: tensor<i32>) -> tensor<f32> {
 // CHECK:         [[VIEW1:%.+]] = hal.buffer_view.create %arg1{{.*}}
-// CHECK:         [[VIEW2:%.+]] = hal.buffer_view.create %arg2{{.*}}
-// CHECK:         "tensorlist.GetItem"(%arg0, [[VIEW1]], [[VIEW2]])
-  %0 = "tf_tensorlist.GetItem"(%arg0, %arg1, %arg2) : (!tf_tensorlist.list, tensor<i32>, tensor<0xi32>) -> tensor<f32>
+// CHECK:         "tensorlist.GetItem"(%arg0, [[VIEW1]])
+  %0 = "tf_tensorlist.GetItem"(%arg0, %arg1) : (!tf_tensorlist.list, tensor<i32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
 
-// CHECK-LABEL: func @Stack(%arg0: !tensorlist.list, %arg1: !hal.buffer, %arg2: !hal.buffer) -> !hal.buffer {
-func @Stack(%arg0: !tf_tensorlist.list, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<1xf32> {
+// CHECK-LABEL: func @Stack(%arg0: !tensorlist.list, %arg1: !hal.buffer) -> !hal.buffer {
+func @Stack(%arg0: !tf_tensorlist.list, %arg1: tensor<i32>) -> tensor<1xf32> {
 // CHECK:         [[VIEW1:%.+]] = hal.buffer_view.create %arg1{{.*}}
-// CHECK:         [[VIEW2:%.+]] = hal.buffer_view.create %arg2{{.*}}
-// CHECK:         "tensorlist.Stack"(%arg0, [[VIEW1]], [[VIEW2]])
-  %0 = "tf_tensorlist.Stack"(%arg0, %arg1, %arg2) : (!tf_tensorlist.list, tensor<1xi32>, tensor<i32>) -> tensor<1xf32>
+// CHECK:         "tensorlist.Stack"(%allocator, %arg0, [[VIEW1]])
+  %0 = "tf_tensorlist.Stack"(%arg0, %arg1) : (!tf_tensorlist.list, tensor<i32>) -> tensor<1xf32>
   return %0 : tensor<1xf32>
 }
+
+// CHECK-LABEL: func @Concat(%arg0: !tensorlist.list) -> !hal.buffer {
+func @Concat(%arg0: !tf_tensorlist.list) -> (tensor<1xf32>) {
+// CHECK:         "tensorlist.Concat"(%allocator, %arg0)
+  %0 = "tf_tensorlist.Concat"(%arg0) : (!tf_tensorlist.list) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
index 1f08ff4..2d96974 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
@@ -4,9 +4,9 @@
 
 // CHECK-LABEL: func @basic
 func @basic(%arg0: tensor<f32>, %num_elements: tensor<i32>, %element_shape: tensor<0xi32>, %index: tensor<i32>, %item: tensor<f32>) -> tensor<f32> {
-  // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.Reserve"(%arg2, %arg1) : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
+  // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.Reserve"(%arg2, %arg1) {element_type = f32} : (tensor<0xi32>, tensor<i32>) -> !tf_tensorlist.list
   // CHECK-NEXT: [[LIST1:%.+]] = "tf_tensorlist.SetItem"([[LIST0]], %arg3, %arg4) : (!tf_tensorlist.list, tensor<i32>, tensor<f32>) -> !tf_tensorlist.list
-  // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.GetItem"([[LIST1]], %arg3, %arg2) : (!tf_tensorlist.list, tensor<i32>, tensor<0xi32>) -> tensor<f32>
+  // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.GetItem"([[LIST1]], %arg3) : (!tf_tensorlist.list, tensor<i32>) -> tensor<f32>
   // CHECK-NEXT: return [[T]] : tensor<f32>
   %list0 = "tf.TensorListReserve"(%element_shape, %num_elements) : (tensor<0xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<f32>>>
   %list1 = "tf.TensorListSetItem"(%list0, %index, %item) : (tensor<!tf.variant<tensor<f32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
@@ -16,9 +16,9 @@
 
 // CHECK-LABEL: func @stack
 func @stack(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>) -> tensor<?xf32> {
-  // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> !tf_tensorlist.list
+  // CHECK-NEXT: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0) : (tensor<?xf32>) -> !tf_tensorlist.list
   // CHECK-NEXT: [[CONST:%.+]] = constant dense<-1> : tensor<i64>
-  // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.Stack"([[LIST0]], %arg1, [[CONST]]) : (!tf_tensorlist.list, tensor<0xi32>, tensor<i64>) -> tensor<?xf32>
+  // CHECK-NEXT: [[T:%.+]] = "tf_tensorlist.Stack"([[LIST0]], [[CONST]]) : (!tf_tensorlist.list, tensor<i64>) -> tensor<?xf32>
   // CHECK-NEXT: return [[T]]
 
   %list0 = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
@@ -28,7 +28,7 @@
 
 // CHECK-LABEL: func @concat
 func @concat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
-  // CHECK-DAG: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0, %arg1)
+  // CHECK-DAG: [[LIST0:%.+]] = "tf_tensorlist.FromTensor"(%arg0)
   // CHECK-DAG: [[T:%.+]] = "tf_tensorlist.Concat"([[LIST0]])
   // CHECK-DAG: [[L:%.+]] = "tf_tensorlist.GetDim0"([[LIST0]])
   // CHECK: return [[T]], [[L]]
@@ -38,6 +38,7 @@
   return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
 }
 
+
 // CHECK-LABEL: func @control_flow_simple
 func @control_flow_simple(%arg0: tensor<f32>, %num_elements: tensor<i32>, %element_shape: tensor<0xi32>, %index: tensor<i32>, %item: tensor<f32>) {
   // CHECK-NEXT: tf_tensorlist.Reserve
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir
index 3bf2630..9e2c13d 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/test/ops.mlir
@@ -9,13 +9,13 @@
   %item: tensor<?xf32>
 ) {
   // CHECK: tf_tensorlist.Reserve
-  %2 = "tf_tensorlist.Reserve"(%element_shape, %num_elements) : (tensor<1xi32>, tensor<i32>) -> !tf_tensorlist.list
+  %2 = "tf_tensorlist.Reserve"(%element_shape, %num_elements) { element_type = f32 } : (tensor<1xi32>, tensor<i32>) -> !tf_tensorlist.list
   // CHECK: tf_tensorlist.GetItem
-  %3 = "tf_tensorlist.GetItem"(%list, %index, %element_shape) : (!tf_tensorlist.list, tensor<i32>, tensor<1xi32>) -> tensor<?xf32>
+  %3 = "tf_tensorlist.GetItem"(%list, %index) : (!tf_tensorlist.list, tensor<i32>) -> tensor<?xf32>
   // CHECK: tf_tensorlist.SetItem
   %4 = "tf_tensorlist.SetItem"(%list, %index, %item) : (!tf_tensorlist.list, tensor<i32>, tensor<?xf32>) -> !tf_tensorlist.list
   // CHECK: tf_tensorlist.Stack
-  %5 = "tf_tensorlist.Stack"(%list, %element_shape, %index) : (!tf_tensorlist.list, tensor<1xi32>, tensor<i32>) -> tensor<1x2xf32>
+  %5 = "tf_tensorlist.Stack"(%list, %index) : (!tf_tensorlist.list, tensor<i32>) -> tensor<1x2xf32>
   // CHECK: tf_tensorlist.Concat
   %6 = "tf_tensorlist.Concat"(%list) : (!tf_tensorlist.list) -> tensor<1x2xf32>
   // CHECK: tf_tensorlist.GetDim0
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td
index f125e7a..3f35bff 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_ops.td
@@ -34,7 +34,8 @@
 
   let arguments = (ins
     TF_I32OrI64Tensor:$element_shape,
-    I32Tensor:$num_elements
+    I32Tensor:$num_elements,
+    TypeAttr:$element_type
   );
 
   let results = (outs
@@ -62,8 +63,7 @@
 
   let arguments = (ins
     TfTensorList_TensorList:$list,
-    I32Tensor:$index,
-    I32Tensor:$element_shape
+    I32Tensor:$index
   );
 
   let results = (outs
@@ -103,8 +103,7 @@
   }];
 
   let arguments = (ins
-    TF_Tensor:$tensor,
-    I32Tensor:$element_shape
+    TF_Tensor:$tensor
   );
 
   let results = (outs
@@ -125,7 +124,6 @@
 
   let arguments = (ins
     TfTensorList_TensorList:$list,
-    I32Tensor:$element_shape,
     // TODO(silvasean): Properly handle IREE's blind truncation to 32-bit.
     // This is logically `index` type, but coming from TensorFlow it
     // comes in as i64. IREE then proceeds to blindly truncate it to I32
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 6ecd6b5..89f92ad 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -28,10 +28,6 @@
     "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
     "iree_e2e_cartesian_product_test_suite",
 )
-load(
-    "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
-    "iree_e2e_test_suite",
-)
 
 package(
     default_visibility = ["//visibility:public"],
@@ -39,33 +35,6 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-# @unused
-DOC = """
-vision_model_test_manual is for manual testing of all keras vision models.
-Test will run only manually with all parameters specified manually, for example:
-bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
---target_backends=tf,iree_vmla \
---data=imagenet \
---url=https://storage.googleapis.com/iree_models/ \
---model=ResNet50
-
-Command arguments description:
---target_backends: can be combination of these: tf,iree_vmla
---data: can be 'imagenet' or 'cifar10'.
-    imagenet - input image size (1, 224, 224, 3)
-    cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
-            and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
---include_top: Whether or not to include the final (top) layers of the model.
---url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
-       imagenet pretrained weights url is specified by keras
---model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
-    ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
-    InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
-    DenseNet201, NASNetMobile, NASNetLarge
-    All above models works with 'imagenet' data sets.
-    ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
-"""
-
 [
     iree_py_binary(
         name = src.replace(".py", "_manual"),
@@ -82,308 +51,6 @@
     )
 ]
 
-SPECIAL_CASES = [
-    "keyword_spotting_streaming_test.py",
-    "vision_model_test.py",
-]
-
-TFLITE_FAILING = []
-
-VMLA_FAILING = []
-
-LLVM_FAILING = []
-
-VULKAN_FAILING = []
-
-TF_PASSING = glob(
-    ["*_test.py"],
-    exclude = SPECIAL_CASES,
-)
-
-TFLITE_PASSING = glob(
-    ["*_test.py"],
-    exclude = TFLITE_FAILING + SPECIAL_CASES,
-)
-
-VMLA_PASSING = glob(
-    ["*_test.py"],
-    exclude = VMLA_FAILING + SPECIAL_CASES,
-)
-
-LLVM_PASSING = glob(
-    ["*_test.py"],
-    exclude = LLVM_FAILING + SPECIAL_CASES,
-)
-
-VULKAN_PASSING = glob(
-    ["*_test.py"],
-    exclude = VULKAN_FAILING + SPECIAL_CASES,
-)
-
-iree_e2e_test_suite(
-    name = "keras_tests",
-    backends_to_srcs = {
-        "tf": TF_PASSING,
-        "tflite": TFLITE_PASSING,
-        "iree_vmla": VMLA_PASSING,
-        "iree_llvmjit": LLVM_PASSING,
-        "iree_vulkan": VULKAN_PASSING,
-    },
-    reference_backend = "tf",
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_test_suite(
-    name = "keras_tests_failing",
-    backends_to_srcs = {
-        "tflite": TFLITE_FAILING,
-        "iree_vmla": VMLA_FAILING,
-        "iree_llvmjit": LLVM_FAILING,
-        "iree_vulkan": VULKAN_FAILING,
-    },
-    reference_backend = "tf",
-    tags = [
-        "failing",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_cartesian_product_test_suite(
-    name = "large_cifar10_tests",
-    size = "large",
-    srcs = ["vision_model_test.py"],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "cifar10",
-        "model": [
-            # All models with runtime shorter than ResNet50.
-            "MobileNet",  # Max: Vulkan 61.0s
-            "MobileNetV2",  # Max: LLVM 96.3s
-            "ResNet50",  # Max: LLVM 145.6s
-            "VGG16",  # Max: LLVM 89.5s
-            "VGG19",  # Max: LLVM 94.7s
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = ["manual"],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-iree_e2e_cartesian_product_test_suite(
-    name = "enormous_cifar10_tests",
-    size = "enormous",
-    srcs = ["vision_model_test.py"],
-    failing_configurations = [
-        {
-            # Failing on vmla with negative inputs.
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-            ],
-            "target_backends": "iree_vmla",
-        },
-        {
-            # Failing on llvm and vulkan:
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-                "ResNet50V2",
-                "ResNet101V2",
-                "ResNet152V2",
-            ],
-            "target_backends": [
-                "iree_llvmjit",
-                "iree_vulkan",
-            ],
-        },
-    ],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "cifar10",
-        "model": [
-            "DenseNet121",
-            "DenseNet169",
-            "DenseNet201",
-            "NASNetLarge",
-            "NASNetMobile",
-            "ResNet50V2",
-            "ResNet101",
-            "ResNet101V2",
-            "ResNet152",
-            "ResNet152V2",
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = [
-        "guitar",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
-    name = "cifar10_non_hermetic_tests",
-    size = "large",
-    srcs = ["vision_model_test.py"],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "cifar10",
-        "url": "https://storage.googleapis.com/iree_models/",
-        "use_external_weights": True,
-        "model": [
-            "MobileNet",
-            "MobileNetV2",
-            "ResNet50",
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = [
-        "external",
-        "guitar",
-        "manual",
-        "no-remote",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# 'non_hermetic' tests use real model weights to test numerical correctness.
-iree_e2e_cartesian_product_test_suite(
-    name = "imagenet_non_hermetic_tests",
-    size = "enormous",
-    srcs = ["vision_model_test.py"],
-    failing_configurations = [
-        {
-            # Failing on vmla with negative inputs.
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-            ],
-            "target_backends": "iree_vmla",
-        },
-        {
-            # Failing vulkan:
-            "model": [
-                "InceptionResNetV2",
-                "InceptionV3",
-            ],
-            "target_backends": [
-                "iree_vulkan",
-            ],
-        },
-        {
-            # Failing llvm and vulkan:
-            "model": [
-                "NASNetLarge",
-                "NASNetMobile",
-                "ResNet50V2",
-                "ResNet101V2",
-                "ResNet152V2",
-                "Xception",
-            ],
-            "target_backends": [
-                "iree_llvmjit",
-                "iree_vulkan",
-            ],
-        },
-    ],
-    flags_to_values = {
-        "reference_backend": "tf",
-        "data": "imagenet",
-        "use_external_weights": True,
-        "model": [
-            "DenseNet121",
-            "DenseNet169",
-            "DenseNet201",
-            "InceptionResNetV2",
-            "InceptionV3",
-            "MobileNet",
-            "MobileNetV2",
-            "NASNetLarge",
-            "NASNetMobile",
-            "ResNet50",
-            "ResNet50V2",
-            "ResNet101",
-            "ResNet101V2",
-            "ResNet152",
-            "ResNet152V2",
-            "VGG16",
-            "VGG19",
-            "Xception",
-        ],
-        "target_backends": [
-            "tf",
-            "tflite",
-            "iree_vmla",
-            "iree_llvmjit",
-            "iree_vulkan",
-        ],
-    },
-    main = "vision_model_test.py",
-    tags = [
-        "external",
-        "guitar",
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
-# It is used to produce weights for keras vision models with input image size
-# 32x32. These models are not optimized for accuracy or latency (they are for
-# debugging only). They have the same neural net topology with keras vision
-# models trained on imagenet data sets
-iree_py_binary(
-    name = "train_vision_models_on_cifar",
-    srcs = ["train_vision_models_on_cifar.py"],
-    python_version = "PY3",
-    srcs_version = "PY2AND3",
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
-
 # Keyword Spotting Tests:
 KEYWORD_SPOTTING_MODELS = [
     "svdf",
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
new file mode 100644
index 0000000..a9df330
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/BUILD
@@ -0,0 +1,318 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. Coverage tables generated from this file can be viewed here:
+#   https://google.github.io/iree/tensorflow-coverage/vision-coverage
+# Updates made to test suite names should also be reflected here:
+#   https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+    "//bindings/python:build_defs.oss.bzl",
+    "INTREE_TENSORFLOW_PY_DEPS",
+    "NUMPY_DEPS",
+    "iree_py_binary",
+)
+load(
+    "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+    "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# @unused
+DOC = """
+applications_test_manual is for manual testing of all keras vision models.
+The test only runs manually, with all parameters specified explicitly, for example:
+bazel run -c opt integrations/tensorflow/e2e/keras/applications:applications_test_manual -- \
+--target_backends=tf,iree_vmla \
+--data=imagenet \
+--url=https://storage.googleapis.com/iree_models/ \
+--model=ResNet50
+
+Command argument descriptions:
+--target_backends: any combination of: tf,iree_vmla
+--data: can be 'imagenet' or 'cifar10'.
+    imagenet - input image size (1, 224, 224, 3)
+    cifar10 - input image size (1, 32, 32, 3); used for quick tests and
+            requires pretrained weights (we provide pretrained ResNet50, MobileNet and MobileNetV2)
+--include_top: Whether or not to include the final (top) layers of the model.
+--url: only needed for cifar10 models, to load weights from https://storage.googleapis.com/iree_models/
+       (the imagenet pretrained weights url is specified by keras)
+--model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
+    ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
+    InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
+    DenseNet201, NASNetMobile, NASNetLarge
+    All of the above models work with the 'imagenet' data set.
+    ResNet50, MobileNet and MobileNetV2 work with both the 'imagenet' and 'cifar10' data sets.
+"""
+
+[
+    iree_py_binary(
+        name = src.replace(".py", "_manual"),
+        srcs = [src],
+        main = src,
+        python_version = "PY3",
+        deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+            "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+        ],
+    )
+    for src in glob(
+        ["*_test.py"],
+    )
+]
+
+KERAS_APPLICATIONS_MODELS = [
+    "DenseNet121",
+    "DenseNet169",
+    "DenseNet201",
+    "EfficientNetB0",
+    "EfficientNetB1",
+    "EfficientNetB2",
+    "EfficientNetB3",
+    "EfficientNetB4",
+    "EfficientNetB5",
+    "EfficientNetB6",
+    "EfficientNetB7",
+    "InceptionResNetV2",
+    "InceptionV3",
+    "MobileNet",
+    "MobileNetV2",
+    "MobileNetV3Large",
+    "MobileNetV3Small",
+    "NASNetLarge",
+    "NASNetMobile",
+    "ResNet101",
+    "ResNet101V2",
+    "ResNet152",
+    "ResNet152V2",
+    "ResNet50",
+    "ResNet50V2",
+    "VGG16",
+    "VGG19",
+]
+
+iree_e2e_cartesian_product_test_suite(
+    name = "large_cifar10_tests",
+    size = "large",
+    srcs = ["applications_test.py"],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "model": [
+            # All models with runtime shorter than ResNet50.
+            "MobileNet",  # Max: Vulkan 61.0s
+            "MobileNetV2",  # Max: LLVM 96.3s
+            "ResNet50",  # Max: LLVM 145.6s
+            "VGG16",  # Max: LLVM 89.5s
+            "VGG19",  # Max: LLVM 94.7s
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = ["manual"],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+iree_e2e_cartesian_product_test_suite(
+    name = "enormous_cifar10_tests",
+    size = "enormous",
+    srcs = ["applications_test.py"],
+    failing_configurations = [
+        {
+            # Failing on vmla with negative inputs.
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+            ],
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing on llvm and vulkan:
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+                "ResNet50V2",
+                "ResNet101V2",
+                "ResNet152V2",
+            ],
+            "target_backends": [
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "model": [
+            "DenseNet121",
+            "DenseNet169",
+            "DenseNet201",
+            "NASNetLarge",
+            "NASNetMobile",
+            "ResNet50V2",
+            "ResNet101",
+            "ResNet101V2",
+            "ResNet152",
+            "ResNet152V2",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = [
+        "guitar",
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+    name = "cifar10_non_hermetic_tests",
+    size = "large",
+    srcs = ["applications_test.py"],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "url": "https://storage.googleapis.com/iree_models/",
+        "use_external_weights": True,
+        "model": [
+            "MobileNet",
+            "MobileNetV2",
+            "ResNet50",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = [
+        "external",
+        "guitar",
+        "manual",
+        "no-remote",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# 'non_hermetic' tests use real model weights to test numerical correctness.
+iree_e2e_cartesian_product_test_suite(
+    name = "imagenet_non_hermetic_tests",
+    size = "enormous",
+    srcs = ["applications_test.py"],
+    failing_configurations = [
+        {
+            # Failing on vmla with negative inputs.
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+            ],
+            "target_backends": "iree_vmla",
+        },
+        {
+            # Failing vulkan:
+            "model": [
+                "InceptionResNetV2",
+                "InceptionV3",
+            ],
+            "target_backends": [
+                "iree_vulkan",
+            ],
+        },
+        {
+            # Failing llvm and vulkan:
+            "model": [
+                "NASNetLarge",
+                "NASNetMobile",
+                "ResNet50V2",
+                "ResNet101V2",
+                "ResNet152V2",
+                "Xception",
+            ],
+            "target_backends": [
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
+    ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "imagenet",
+        "use_external_weights": True,
+        "model": KERAS_APPLICATIONS_MODELS,
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "applications_test.py",
+    tags = [
+        "external",
+        "guitar",
+        "manual",
+        "nokokoro",
+        "notap",
+    ],
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
+
+# Used to produce weights for keras vision models with input image size
+# 32x32. These models are not optimized for accuracy or latency (they are for
+# debugging only). They have the same neural net topology as the keras vision
+# models trained on the imagenet data set.
+iree_py_binary(
+    name = "train_vision_models_on_cifar",
+    srcs = ["train_vision_models_on_cifar.py"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+    ],
+)
diff --git a/integrations/tensorflow/e2e/keras/applications/applications_test.py b/integrations/tensorflow/e2e/keras/applications/applications_test.py
new file mode 100644
index 0000000..8c0e77c
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/applications/applications_test.py
@@ -0,0 +1,121 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test all models in tf.keras.applications."""
+
+import os
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# Testing all applications models automatically can take a long time,
+# so we test them one at a time via the --model flag, e.g. --model=MobileNet.
+flags.DEFINE_string("model", "ResNet50", "model name")
+flags.DEFINE_string(
+    "url", "", "url with model weights "
+    "for example https://storage.googleapis.com/iree_models/")
+flags.DEFINE_bool("use_external_weights", False,
+                  "Whether or not to load external weights from the web")
+flags.DEFINE_enum("data", "cifar10", ["cifar10", "imagenet"],
+                  "data sets on which model was trained: imagenet, cifar10")
+flags.DEFINE_bool(
+    "include_top", True,
+    "Whether or not to include the final (top) layers of the model.")
+
+BATCH_SIZE = 1
+IMAGE_DIM = 224
+
+
+def load_cifar10_weights(model):
+  file_name = "cifar10" + FLAGS.model
+  # get_file will download the model weights from a publicly available folder,
+  # save them to cache_dir=~/.keras/models/ and return a path to them.
+  url = os.path.join(
+      FLAGS.url, f"cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5")
+  weights_path = tf.keras.utils.get_file(file_name, url)
+  model.load_weights(weights_path)
+  return model
+
+
+def initialize_model():
+  # If weights == "imagenet", the model will load the appropriate weights from
+  # an external tf.keras URL.
+  weights = None
+  if FLAGS.use_external_weights and FLAGS.data == "imagenet":
+    weights = "imagenet"
+
+  model_class = getattr(tf.keras.applications, FLAGS.model)
+  model = model_class(weights=weights, include_top=FLAGS.include_top)
+
+  if FLAGS.use_external_weights and FLAGS.data == "cifar10":
+    if not FLAGS.url:
+      raise ValueError(
+          "cifar10 weights cannot be loaded without the `--url` flag.")
+    model = load_cifar10_weights(model)
+  return model
+
+
+class ApplicationsModule(tf_test_utils.TestModule):
+
+  def __init__(self):
+    super().__init__()
+    self.m = initialize_model()
+
+    input_shape = list([BATCH_SIZE] + self.m.inputs[0].shape[1:])
+
+    # Some models accept dynamic image dimensions by default, so we use
+    # IMAGE_DIM as a stand-in.
+    for i, dim in enumerate(input_shape):
+      if dim is None:
+        input_shape[i] = IMAGE_DIM
+
+    # Specify input shape with a static batch size.
+    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
+    self.call = tf_test_utils.tf_function_unit_test(
+        input_signature=[tf.TensorSpec(input_shape)],
+        name="call",
+        rtol=1e-5,
+        atol=1e-5)(lambda x: self.m(x, training=False))
+
+
+class ApplicationsTest(tf_test_utils.TracedModuleTestCase):
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(
+        ApplicationsModule,
+        exported_names=ApplicationsModule.get_tf_function_unit_tests(),
+        relative_artifacts_dir=os.path.join(FLAGS.model, FLAGS.data))
+
+
+def main(argv):
+  del argv  # Unused.
+  if hasattr(tf, "enable_v2_behavior"):
+    tf.enable_v2_behavior()
+
+  if not hasattr(tf.keras.applications, FLAGS.model):
+    raise ValueError(f"Unsupported model: {FLAGS.model}")
+
+  ApplicationsTest.generate_unit_tests(ApplicationsModule)
+  tf.test.main()
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
similarity index 63%
rename from integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
rename to integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
index 6cfa854..28cd296 100644
--- a/integrations/tensorflow/e2e/keras/train_vision_models_on_cifar.py
+++ b/integrations/tensorflow/e2e/keras/applications/train_vision_models_on_cifar.py
@@ -28,50 +28,12 @@
     'include_top', True,
     'Whether or not to include the final (top) layers of the model.')
 
-APP_MODELS = {
-    'ResNet50':
-        tf.keras.applications.resnet.ResNet50,
-    'ResNet101':
-        tf.keras.applications.resnet.ResNet101,
-    'ResNet152':
-        tf.keras.applications.resnet.ResNet152,
-    'ResNet50V2':
-        tf.keras.applications.resnet_v2.ResNet50V2,
-    'ResNet101V2':
-        tf.keras.applications.resnet_v2.ResNet101V2,
-    'ResNet152V2':
-        tf.keras.applications.resnet_v2.ResNet152V2,
-    'VGG16':
-        tf.keras.applications.vgg16.VGG16,
-    'VGG19':
-        tf.keras.applications.vgg19.VGG19,
-    'Xception':
-        tf.keras.applications.xception.Xception,
-    'InceptionV3':
-        tf.keras.applications.inception_v3.InceptionV3,
-    'InceptionResNetV2':
-        tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
-    'MobileNet':
-        tf.keras.applications.mobilenet.MobileNet,
-    'MobileNetV2':
-        tf.keras.applications.mobilenet_v2.MobileNetV2,
-    'DenseNet121':
-        tf.keras.applications.densenet.DenseNet121,
-    'DenseNet169':
-        tf.keras.applications.densenet.DenseNet169,
-    'DenseNet201':
-        tf.keras.applications.densenet.DenseNet201,
-    'NASNetMobile':
-        tf.keras.applications.nasnet.NASNetMobile,
-    'NASNetLarge':
-        tf.keras.applications.nasnet.NASNetLarge,
-}
-
 # minimum size for keras vision models
 INPUT_SHAPE = [1, 32, 32, 3]
 
 
-def main(_):
+def main(argv):
+  del argv  # Unused.
 
   # prepare training and testing data
   (train_images,
@@ -89,10 +51,10 @@
   train_labels = train_labels[:4000]
 
   # It is a toy model for debugging (not optimized for accuracy or speed).
-
-  model = APP_MODELS[FLAGS.model](weights=None,
-                                  include_top=FLAGS.include_top,
-                                  input_shape=INPUT_SHAPE[1:])
+  model_class = getattr(tf.keras.applications, FLAGS.model)
+  model = model_class(weights=None,
+                      include_top=FLAGS.include_top,
+                      input_shape=INPUT_SHAPE[1:])
   model.summary()
   model.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
diff --git a/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
index a099f5b..e2da0c2 100644
--- a/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
+++ b/integrations/tensorflow/e2e/keras/keyword_spotting_streaming_test.py
@@ -49,42 +49,28 @@
 }
 
 
-class KeywordSpottingModule(tf.Module):
+class KeywordSpottingModule(tf_test_utils.TestModule):
 
   def __init__(self):
     super().__init__()
     self.m = utils.get_model_with_default_params(FLAGS.model,
                                                  MODE_ENUM_TO_MODE[FLAGS.mode])
-    self.write_input_shapes_to_cls(self.m)
-    self.m.predict = lambda x: self.m.call(x, training=False)
-    input_signature = [tf.TensorSpec(shape) for shape in self.input_shapes]
-    self.predict = tf.function(input_signature=input_signature)(self.m.predict)
 
-  @classmethod
-  def write_input_shapes_to_cls(cls, model):
-    # We store the input shapes on the cls because we need access to them to
-    # generate the random inputs. Lists are not valid exported names, so we
-    # cannot access them from the module instance itself, and instead we store
-    # the input shapes on the test case below.
-    cls.input_shapes = [tensor.shape for tensor in model.inputs]
+    call = lambda *args: self.m(*args, training=False)
+    input_signature = [tf.TensorSpec(tensor.shape) for tensor in self.m.inputs]
+    self.call = tf_test_utils.tf_function_unit_test(
+        input_signature=input_signature, name="call", atol=1e-5)(call)
 
 
 class KeywordSpottingTest(tf_test_utils.TracedModuleTestCase):
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(KeywordSpottingModule,
-                                                    exported_names=['predict'])
-    self._input_shapes = KeywordSpottingModule.input_shapes
-
-  def test_predict(self):
-
-    def predict(module):
-      inputs = [tf_utils.uniform(shape) for shape in self._input_shapes]
-      inputs = inputs[0] if len(inputs) == 1 else inputs
-      module.predict(inputs, atol=1e-5)
-
-    self.compare_backends(predict, self._modules)
+    self._modules = tf_test_utils.compile_tf_module(
+        KeywordSpottingModule,
+        exported_names=['call'],
+        relative_artifacts_dir=os.path.join('kws_streaming', FLAGS.model,
+                                            FLAGS.mode))
 
 
 def main(argv):
@@ -95,9 +81,8 @@
   if FLAGS.model not in ALL_MODELS:
     raise ValueError(f'Unsupported model: {FLAGS.model}.\n'
                      f'Expected one of {MODELS_HELP}.')
-  KeywordSpottingModule.__name__ = os.path.join('keyword_spotting', FLAGS.model,
-                                                FLAGS.mode)
 
+  KeywordSpottingTest.generate_unit_tests(KeywordSpottingModule)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index 45a7d95..1734d68 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -480,11 +480,6 @@
   return inputs[0] if len(inputs) == 1 else inputs
 
 
-def keras_arg_wrapper(*args):
-  """Wrapper to convert multiple positional args into a list of values."""
-  return list(args) if isinstance(args, tuple) else args
-
-
 def create_wrapped_keras_layer(
     layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.keras.Model:
   """Wraps a keras layer in a model for compilation."""
@@ -522,7 +517,7 @@
     static_signature = [static_signature]
     dynamic_signature = [dynamic_signature]
 
-  call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
+  call = lambda *args: model(*args, training=FLAGS.training)
   return tf_test_utils.tf_function_unit_test(
       input_signature=dynamic_signature,
       static_signature=static_signature,
@@ -557,13 +552,22 @@
       setattr(self, unit_test_spec.unit_test_name, layer_unit_test)
 
 
+def get_relative_artifacts_dir() -> str:
+  dynamic_str = "dynamic" if FLAGS.dynamic_dims else "static"
+  training_str = "training" if FLAGS.training else "non_training"
+  full_api_str = "default_api" if FLAGS.test_default_kwargs_only else "full_api"
+  settings_str = f"{full_api_str}_{dynamic_str}_{training_str}"
+  return os.path.join("tf", "keras", "layers", FLAGS.layer, settings_str)
+
+
 class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(
         KerasLayersModule,
-        exported_names=KerasLayersModule.get_tf_function_unit_tests())
+        exported_names=KerasLayersModule.get_tf_function_unit_tests(),
+        relative_artifacts_dir=get_relative_artifacts_dir())
 
 
 def main(argv):
@@ -580,17 +584,6 @@
   if FLAGS.layer not in LAYERS_TO_UNIT_TEST_SPECS:
     raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'")
 
-  # Set up name for saving artifacts.
-  dynamic_str = "dynamic" if FLAGS.dynamic_dims else "static"
-  training_str = "training" if FLAGS.training else "non_training"
-  full_api_str = "default_api" if FLAGS.test_default_kwargs_only else "full_api"
-  settings_str = f"{full_api_str}_{dynamic_str}_{training_str}"
-  relative_artifacts_dir = os.path.join("tf", "keras", "layers", FLAGS.layer,
-                                        settings_str)
-  # The relative artifacts directory path is calculated from the module name
-  # TODO(meadowlark): provide a better way of overridding this default.
-  KerasLayersModule.__name__ = relative_artifacts_dir
-
   unit_tests = KerasLayersModule.get_tf_function_unit_tests()
   logging.info("Testing the following %s functions: %s", len(unit_tests),
                unit_tests)
diff --git a/integrations/tensorflow/e2e/keras/train/classification_training_test.py b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
index 0b3027a..a73205b 100644
--- a/integrations/tensorflow/e2e/keras/train/classification_training_test.py
+++ b/integrations/tensorflow/e2e/keras/train/classification_training_test.py
@@ -94,7 +94,10 @@
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(
-        ClassificationTrainingModule, exported_names=["train_on_batch"])
+        ClassificationTrainingModule,
+        exported_names=["train_on_batch"],
+        relative_artifacts_dir=os.path.join(
+            ClassificationTrainingModule.__name__, FLAGS.optimizer))
 
   def test_train_on_batch(self):
 
@@ -109,8 +112,6 @@
   del argv  # Unused
   if hasattr(tf, "enable_v2_behavior"):
     tf.enable_v2_behavior()
-  ClassificationTrainingModule.__name__ = os.path.join(
-      ClassificationTrainingModule.__name__, FLAGS.optimizer)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/keras/train/regression_training_test.py b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
index cd064d3..0f9aa72 100644
--- a/integrations/tensorflow/e2e/keras/train/regression_training_test.py
+++ b/integrations/tensorflow/e2e/keras/train/regression_training_test.py
@@ -73,7 +73,10 @@
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(
-        RegressionTrainingModule, exported_names=["train_on_batch"])
+        RegressionTrainingModule,
+        exported_names=["train_on_batch"],
+        relative_artifacts_dir=os.path.join(RegressionTrainingModule.__name__,
+                                            FLAGS.optimizer))
 
   def test_train_on_batch(self):
 
@@ -88,8 +91,6 @@
   del argv  # Unused
   if hasattr(tf, "enable_v2_behavior"):
     tf.enable_v2_behavior()
-  RegressionTrainingModule.__name__ = os.path.join(
-      RegressionTrainingModule.__name__, FLAGS.optimizer)
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
deleted file mode 100644
index 114d407..0000000
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test all applications models in Keras."""
-
-import os
-
-from absl import app
-from absl import flags
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-from pyiree.tf.support import tf_utils
-import tensorflow.compat.v2 as tf
-
-FLAGS = flags.FLAGS
-
-# Testing all applications models automatically can take time
-# so we test it one by one, with argument --model=MobileNet
-flags.DEFINE_string('model', 'ResNet50', 'model name')
-flags.DEFINE_string(
-    'url', '', 'url with model weights '
-    'for example https://storage.googleapis.com/iree_models/')
-flags.DEFINE_bool('use_external_weights', False,
-                  'Whether or not to load external weights from the web')
-flags.DEFINE_enum('data', 'cifar10', ['cifar10', 'imagenet'],
-                  'data sets on which model was trained: imagenet, cifar10')
-flags.DEFINE_bool(
-    'include_top', True,
-    'Whether or not to include the final (top) layers of the model.')
-
-APP_MODELS = {
-    'ResNet50':
-        tf.keras.applications.resnet.ResNet50,
-    'ResNet101':
-        tf.keras.applications.resnet.ResNet101,
-    'ResNet152':
-        tf.keras.applications.resnet.ResNet152,
-    'ResNet50V2':
-        tf.keras.applications.resnet_v2.ResNet50V2,
-    'ResNet101V2':
-        tf.keras.applications.resnet_v2.ResNet101V2,
-    'ResNet152V2':
-        tf.keras.applications.resnet_v2.ResNet152V2,
-    'VGG16':
-        tf.keras.applications.vgg16.VGG16,
-    'VGG19':
-        tf.keras.applications.vgg19.VGG19,
-    'Xception':
-        tf.keras.applications.xception.Xception,
-    'InceptionV3':
-        tf.keras.applications.inception_v3.InceptionV3,
-    'InceptionResNetV2':
-        tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
-    'MobileNet':
-        tf.keras.applications.mobilenet.MobileNet,
-    'MobileNetV2':
-        tf.keras.applications.mobilenet_v2.MobileNetV2,
-    'DenseNet121':
-        tf.keras.applications.densenet.DenseNet121,
-    'DenseNet169':
-        tf.keras.applications.densenet.DenseNet169,
-    'DenseNet201':
-        tf.keras.applications.densenet.DenseNet201,
-    'NASNetMobile':
-        tf.keras.applications.nasnet.NASNetMobile,
-    'NASNetLarge':
-        tf.keras.applications.nasnet.NASNetLarge,
-}
-
-
-def get_input_shape():
-  if FLAGS.data == 'imagenet':
-    if FLAGS.model in ['InceptionV3', 'Xception', 'InceptionResNetV2']:
-      return (1, 299, 299, 3)
-    elif FLAGS.model == 'NASNetLarge':
-      return (1, 331, 331, 3)
-    else:
-      return (1, 224, 224, 3)
-  elif FLAGS.data == 'cifar10':
-    return (1, 32, 32, 3)
-  else:
-    raise ValueError(f'Data not supported: {FLAGS.data}')
-
-
-def load_cifar10_weights(model):
-  file_name = 'cifar10' + FLAGS.model
-  # get_file will download the model weights from a publicly available folder,
-  # save them to cache_dir=~/.keras/models/ and return a path to them.
-  url = os.path.join(
-      FLAGS.url, f'cifar10_include_top_{FLAGS.include_top:d}_{FLAGS.model}.h5')
-  weights_path = tf.keras.utils.get_file(file_name, url)
-  model.load_weights(weights_path)
-  return model
-
-
-def initialize_model():
-  tf_utils.set_random_seed()
-
-  # Keras applications models receive input shapes without a batch dimension, as
-  # the batch size is dynamic by default. This selects just the image size.
-  input_shape = get_input_shape()[1:]
-
-  # If weights == 'imagenet', the model will load the appropriate weights from
-  # an external tf.keras URL.
-  weights = None
-  if FLAGS.use_external_weights and FLAGS.data == 'imagenet':
-    weights = 'imagenet'
-
-  model = APP_MODELS[FLAGS.model](weights=weights,
-                                  include_top=FLAGS.include_top,
-                                  input_shape=input_shape)
-
-  if FLAGS.use_external_weights and FLAGS.data == 'cifar10':
-    if not FLAGS.url:
-      raise ValueError(
-          'cifar10 weights cannot be loaded without the `--url` flag.')
-    model = load_cifar10_weights(model)
-  return model
-
-
-class VisionModule(tf.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.m = initialize_model()
-    self.m.predict = lambda x: self.m.call(x, training=False)
-    # Specify input shape with a static batch size.
-    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
-    # Replace input_shape with m.input_shape to make the batch size dynamic.
-    self.predict = tf.function(
-        input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
-
-
-class AppTest(tf_test_utils.TracedModuleTestCase):
-
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(VisionModule,
-                                                    exported_names=['predict'])
-
-  def test_predict(self):
-
-    def predict(module):
-      module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5, rtol=1e-5)
-
-    self.compare_backends(predict, self._modules)
-
-
-def main(argv):
-  del argv  # Unused
-  if hasattr(tf, 'enable_v2_behavior'):
-    tf.enable_v2_behavior()
-
-  if FLAGS.model not in APP_MODELS:
-    raise ValueError(f'Unsupported model: {FLAGS.model}')
-  # Override VisionModule's __name__ to be more specific.
-  VisionModule.__name__ = os.path.join(FLAGS.model, FLAGS.data)
-
-  tf.test.main()
-
-
-if __name__ == '__main__':
-  app.run(main)
diff --git a/integrations/tensorflow/e2e/math/math_test.py b/integrations/tensorflow/e2e/math/math_test.py
index 00020d0..6ec57b4 100644
--- a/integrations/tensorflow/e2e/math/math_test.py
+++ b/integrations/tensorflow/e2e/math/math_test.py
@@ -696,12 +696,34 @@
         setattr(self, unit_test_spec.unit_test_name, function_unit_test)
 
 
+def get_relative_artifacts_dir() -> str:
+  if len(FLAGS.functions) > 1:
+    # We only allow testing multiple functions with a single target backend
+    # so that we can store the artifacts under:
+    #   'artifacts_dir/multiple_functions__backend/...'
+    # We specialize the 'multiple_functions' dir by backend to avoid overwriting
+    # tf_input.mlir and iree_input.mlir. These are typically identical across
+    # backends, but are not when the functions to compile change per-backend.
+    if len(FLAGS.target_backends) != 1:
+      raise flags.IllegalFlagValueError(
+          "Expected len(target_backends) == 1 when len(functions) > 1, but got "
+          f"the following values for target_backends: {FLAGS.target_backends}.")
+    function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
+  else:
+    function_str = FLAGS.functions[0]
+  dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
+  complex_str = "complex" if FLAGS.test_complex else "non_complex"
+  return os.path.join("tf", "math", function_str, f"{dim_str}_{complex_str}")
+
+
 class TfMathTest(tf_test_utils.TracedModuleTestCase):
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self._modules = tf_test_utils.compile_tf_module(
-        TfMathModule, exported_names=TfMathModule.get_tf_function_unit_tests())
+        TfMathModule,
+        exported_names=TfMathModule.get_tf_function_unit_tests(),
+        relative_artifacts_dir=get_relative_artifacts_dir())
 
 
 def main(argv):
@@ -721,26 +743,6 @@
         "'--functions' must be specified if "
         "'--list_functions_with_complex_tests' isn't")
 
-  if len(FLAGS.functions) > 1:
-    # We only allow testing multiple functions with a single target backend
-    # so that we can store the artifacts under:
-    #   'artifacts_dir/multiple_functions__backend/...'
-    # We specialize the 'multiple_functions' dir by backend to avoid overwriting
-    # tf_input.mlir and iree_input.mlir. These are typically identical across
-    # backends, but are not when the functions to compile change per-backend.
-    if len(FLAGS.target_backends) != 1:
-      raise flags.IllegalFlagValueError(
-          "Expected len(target_backends) == 1 when len(functions) > 1, but got "
-          f"the following values for target_backends: {FLAGS.target_backends}.")
-    function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
-  else:
-    function_str = FLAGS.functions[0]
-  dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
-  settings_str = os.path.join(function_str, dim_str)
-  # The relative artifacts directory path is calculated from the module name
-  # TODO(meadowlark): provide a better way of overridding this default.
-  TfMathModule.__name__ = os.path.join("tf", "math", settings_str)
-
   TfMathTest.generate_unit_tests(TfMathModule)
   tf.test.main()
 
diff --git a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
index 0580eaf..1d2fb5f 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
+++ b/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py
@@ -77,8 +77,10 @@
 
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
-    self._modules = tf_test_utils.compile_tf_module(SlimVisionModule,
-                                                    exported_names=['predict'])
+    self._modules = tf_test_utils.compile_tf_module(
+        SlimVisionModule,
+        exported_names=['predict'],
+        relative_artifacts_dir=FLAGS.model)
 
   def test_predict(self):
 
@@ -94,8 +96,6 @@
   del argv  # Unused.
   if hasattr(tf, 'enable_v2_behavior'):
     tf.enable_v2_behavior()
-
-  SlimVisionModule.__name__ = FLAGS.model
   tf.test.main()
 
 
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index e76790e..104ede9 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -67,6 +67,12 @@
     ta = ta.write(1, b)
     return ta.stack()
 
+  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
+  def partially_empty_stack(self, x):
+    ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[])
+    ta = ta.write(0, x)
+    return ta.stack()
+
 
 class TensorListTest(tf_test_utils.TracedModuleTestCase):
 
@@ -104,6 +110,11 @@
       module.concat_with_tensorlist_stack(np.array(42., dtype=np.float32),
                                           np.array(43., dtype=np.float32))
     self.compare_backends(concat_with_tensorlist_stack, self._modules)
+
+  def test_partially_empty_stack(self):
+    def partially_empty_stack(module):
+      module.partially_empty_stack(np.array(42., dtype=np.float32))
+    self.compare_backends(partially_empty_stack, self._modules)
   # yapf: enable
 
 
diff --git a/iree/base/BUILD b/iree/base/BUILD
index e6e634d..e90ea64 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -67,6 +67,16 @@
     ],
 )
 
+cc_test(
+    name = "math_test",
+    srcs = ["math_test.cc"],
+    deps = [
+        ":core_headers",
+        "//iree/testing:gtest",
+        "//iree/testing:gtest_main",
+    ],
+)
+
 #===------------------------------------------------------------------------===#
 # Internal IREE C++ wrappers and utilities
 #===------------------------------------------------------------------------===#
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index c45233e..e68b8e0 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -309,6 +309,17 @@
   PUBLIC
 )
 
+iree_cc_test(
+  NAME
+    math_test
+  SRCS
+    "math_test.cc"
+  DEPS
+    ::core_headers
+    iree::testing::gtest
+    iree::testing::gtest_main
+)
+
 iree_cc_library(
   NAME
     ref_ptr
diff --git a/iree/base/internal/file_io_posix.cc b/iree/base/internal/file_io_posix.cc
index c67403e..fd1a718 100644
--- a/iree/base/internal/file_io_posix.cc
+++ b/iree/base/internal/file_io_posix.cc
@@ -126,7 +126,13 @@
     return tmpdir;
   }
 
+#ifdef __ANDROID__
+  // Support running Android command-line programs both as regular shell user
+  // and as root. For the latter, TMPDIR is not defined by default.
+  return "/data/local/tmp";
+#else
   return "/tmp";
+#endif
 }
 
 StatusOr<std::string> GetTempFile(absl::string_view base_name) {
diff --git a/iree/base/math.h b/iree/base/math.h
index 3e27781..7e64db1 100644
--- a/iree/base/math.h
+++ b/iree/base/math.h
@@ -15,9 +15,11 @@
 #ifndef IREE_BASE_MATH_H_
 #define IREE_BASE_MATH_H_
 
-#include <cstdint>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
 
-#include "absl/base/attributes.h"
+#include "iree/base/target_platform.h"
 
 // Haswell or later, gcc compile time option: -mlzcnt
 #if defined(__LZCNT__)
@@ -26,88 +28,98 @@
 
 // Clang on Windows has __builtin_clzll; otherwise we need to use the
 // windows intrinsic functions.
-#if defined(_MSC_VER)
+#if defined(IREE_COMPILER_MSVC)
 #include <intrin.h>
-#if defined(_M_X64)
+#if defined(IREE_ARCH_ARM_64) || defined(IREE_ARCH_X86_64)
 #pragma intrinsic(_BitScanReverse64)
 #pragma intrinsic(_BitScanForward64)
 #endif
 #pragma intrinsic(_BitScanReverse)
 #pragma intrinsic(_BitScanForward)
-#endif
+#endif  // IREE_COMPILER_MSVC
 
-#if defined(_MSC_VER)
-// We can achieve something similar to attribute((always_inline)) with MSVC by
-// using the __forceinline keyword, however this is not perfect. MSVC is
-// much less aggressive about inlining, and even with the __forceinline keyword.
-#define IREE_FORCEINLINE __forceinline
-#else
-// Use default attribute inline.
-#define IREE_FORCEINLINE inline ABSL_ATTRIBUTE_ALWAYS_INLINE
-#endif
+//==============================================================================
+// Bitwise rotation (aka circular shifts)
+//==============================================================================
 
-namespace iree {
+// Unsigned rotate-left a 64-bit integer.
+// https://en.cppreference.com/w/cpp/numeric/rotl
+//
+//
+// about; do not modify: https://godbolt.org/z/xzof9d
+static inline uint64_t iree_math_rotl_u64(const uint64_t n, uint32_t c) {
+  if (!c) return n;
+  const uint32_t mask = 8 * sizeof(n) - 1;
+  c &= mask;
+  return (n << c) | (n >> (64 - c));
+}
 
-IREE_FORCEINLINE int CountLeadingZeros32(uint32_t n) {
-#if defined(_MSC_VER)
+// Unsigned rotate-right a 64-bit integer.
+// https://en.cppreference.com/w/cpp/numeric/rotr
+//
+// NOTE: this exact form is confirmed to be recognized by the compilers we care
+// about **except MSVC**; do not modify: https://godbolt.org/z/xzof9d
+static inline uint64_t iree_math_rotr_u64(const uint64_t n, uint32_t c) {
+  if (!c) return n;
+  const uint32_t mask = 8 * sizeof(n) - 1;
+  c &= mask;
+  return (n >> c) | (n << ((-c) & mask));
+}
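+
+// For illustration (assuming rotation counts stay in [0, 63], as the masking
+// above expects), the helpers give, e.g.:
+//   iree_math_rotl_u64(0x8000000000000001ull, 1) == 0x0000000000000003ull
+//   iree_math_rotr_u64(0x0000000000000001ull, 1) == 0x8000000000000000ull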
+
+//==============================================================================
+// Bit scanning/counting
+//==============================================================================
+
+static inline int iree_math_count_leading_zeros_u32(const uint32_t n) {
+#if defined(IREE_COMPILER_MSVC)
   unsigned long result = 0;  // NOLINT(runtime/int)
   if (_BitScanReverse(&result, n)) {
-    return 31 - result;
+    return (int)(31 - result);
   }
   return 32;
-#elif defined(__GNUC__)
-  // Use __builtin_clz, which uses the following instructions:
-  //  x86: bsr
-  //  ARM64: clz
-  //  PPC: cntlzd
-  static_assert(sizeof(int) == sizeof(n),
-                "__builtin_clz does not take 32-bit arg");
-
+#elif defined(IREE_COMPILER_GCC_COMPAT)
 #if defined(__LCZNT__)
   // NOTE: LZCNT is a risky instruction; it is not supported on architectures
   // before Haswell, yet it is encoded as 'rep bsr', which typically ignores
   // invalid rep prefixes, and interprets it as the 'bsr' instruction, which
   // returns the index of the value rather than the count, resulting in
   // incorrect code.
-  return __lzcnt32(n);
+  return (int)__lzcnt32(n);
 #endif  // defined(__LCZNT__)
 
   // Handle 0 as a special case because __builtin_clz(0) is undefined.
-  if (n == 0) {
-    return 32;
-  }
-  return __builtin_clz(n);
-#else
-#error No clz for this arch.
-#endif
-}
-
-IREE_FORCEINLINE int CountLeadingZeros64(uint64_t n) {
-#if defined(_MSC_VER) && defined(_M_X64)
-  // MSVC does not have __buitin_clzll. Use _BitScanReverse64.
-  unsigned long result = 0;  // NOLINT(runtime/int)
-  if (_BitScanReverse64(&result, n)) {
-    return 63 - result;
-  }
-  return 64;
-#elif defined(_MSC_VER)
-  // MSVC does not have __buitin_clzll. Compose two calls to _BitScanReverse
-  unsigned long result = 0;  // NOLINT(runtime/int)
-  if ((n >> 32) && _BitScanReverse(&result, n >> 32)) {
-    return 31 - result;
-  }
-  if (_BitScanReverse(&result, n)) {
-    return 63 - result;
-  }
-  return 64;
-#elif defined(__GNUC__)
-  // Use __builtin_clzll, which uses the following instructions:
+  if (n == 0) return 32;
+  // Use __builtin_clz, which uses the following instructions:
   //  x86: bsr
   //  ARM64: clz
   //  PPC: cntlzd
-  static_assert(sizeof(unsigned long long) == sizeof(n),  // NOLINT(runtime/int)
-                "__builtin_clzll does not take 64-bit arg");
+  return (int)__builtin_clz(n);
+#else
+#error No clz for this arch.
+#endif  // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
+}
 
+static inline int iree_math_count_leading_zeros_u64(uint64_t n) {
+#if defined(IREE_COMPILER_MSVC) && \
+    (defined(IREE_ARCH_ARM_64) || defined(IREE_ARCH_X86_64))
+  // MSVC does not have __builtin_clzll. Use _BitScanReverse64.
+  unsigned long result = 0;  // NOLINT(runtime/int)
+  if (_BitScanReverse64(&result, n)) {
+    return (int)(63 - result);
+  }
+  return 64;
+#elif defined(IREE_COMPILER_MSVC)
+  // MSVC does not have __builtin_clzll. Compose two calls to _BitScanReverse.
+  unsigned long result = 0;  // NOLINT(runtime/int)
+  if ((n >> 32) && _BitScanReverse(&result, n >> 32)) {
+    return (int)(31 - result);
+  }
+  if (_BitScanReverse(&result, n)) {
+    return (int)(63 - result);
+  }
+  return 64;
+#elif defined(IREE_COMPILER_GCC_COMPAT)
 #if defined(__LCZNT__)
   // NOTE: LZCNT is a risky instruction; it is not supported on architectures
   // before Haswell, yet it is encoded as 'rep bsr', which typically ignores
@@ -117,99 +129,267 @@
   return __lzcnt64(n);
 #elif defined(__aarch64__) || defined(__powerpc64__)
   // Empirically verified that __builtin_clzll(0) works as expected.
-  return __builtin_clzll(n);
+  return (int)__builtin_clzll(n);
 #endif
-
   // Handle 0 as a special case because __builtin_clzll(0) is undefined.
-  if (n == 0) {
-    return 64;
-  }
-  return __builtin_clzll(n);
+  if (!n) return 64;
+  // Use __builtin_clzll, which uses the following instructions:
+  //    x86: bsr
+  //    PPC: cntlzd
+  //   WASM: i64.clz
+  // RISC-V: __clzdi2 in GCC, splat out in clang
+  return (int)__builtin_clzll(n);
 #else
 #error No clz for this arch.
-#endif
+#endif  // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
 }
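+
+// As a quick sanity check of the helper above: for nonzero x,
+// 63 - iree_math_count_leading_zeros_u64(x) equals floor(log2(x)), e.g.
+// iree_math_count_leading_zeros_u64(1) == 63 and
+// iree_math_count_leading_zeros_u64(UINT64_C(1) << 63) == 0.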
 
-IREE_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) {
-#if defined(_MSC_VER)
+static inline int iree_math_count_trailing_zeros_u32(uint32_t n) {
+#if defined(IREE_COMPILER_MSVC)
   unsigned long result = 0;  // NOLINT(runtime/int)
   _BitScanForward(&result, n);
-  return result;
-#elif defined(__GNUC__)
-  static_assert(sizeof(int) == sizeof(n),
-                "__builtin_ctz does not take 32-bit arg");
-  return __builtin_ctz(n);
+  return (int)result;
+#elif defined(IREE_COMPILER_GCC_COMPAT)
+  return (int)__builtin_ctz(n);
 #else
   int c = 31;
   n &= ~n + 1;
-  if (n & 0x0000FFFF) c -= 16;
-  if (n & 0x00FF00FF) c -= 8;
-  if (n & 0x0F0F0F0F) c -= 4;
-  if (n & 0x33333333) c -= 2;
-  if (n & 0x55555555) c -= 1;
+  if (n & 0x0000FFFFu) c -= 16;
+  if (n & 0x00FF00FFu) c -= 8;
+  if (n & 0x0F0F0F0Fu) c -= 4;
+  if (n & 0x33333333u) c -= 2;
+  if (n & 0x55555555u) c -= 1;
   return c;
-#endif
+#endif  // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
 }
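+
+// Worked example for the portable fallback path above:
+// iree_math_count_trailing_zeros_u32(40):
+//   40 = 0b101000; n &= ~n + 1 isolates the lowest set bit -> 8 (0b1000);
+//   8 & 0x0000FFFFu != 0 -> c -= 16; 8 & 0x00FF00FFu != 0 -> c -= 8;
+//   8 & 0x0F0F0F0Fu != 0 -> c -= 4; the remaining masks leave c unchanged,
+//   giving c == 3, which matches ctz(40).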
 
-IREE_FORCEINLINE int CountTrailingZerosNonZero64(uint64_t n) {
-#if defined(_MSC_VER) && defined(_M_X64)
+static inline int iree_math_count_trailing_zeros_u64(uint64_t n) {
+#if defined(IREE_COMPILER_MSVC) && defined(IREE_PTR_SIZE_64)
   unsigned long result = 0;  // NOLINT(runtime/int)
   _BitScanForward64(&result, n);
-  return result;
-#elif defined(_MSC_VER)
+  return (int)result;
+#elif defined(IREE_COMPILER_MSVC) && defined(IREE_PTR_SIZE_32)
   unsigned long result = 0;  // NOLINT(runtime/int)
-  if (static_cast<uint32_t>(n) == 0) {
+  if ((uint32_t)(n) == 0) {
     _BitScanForward(&result, n >> 32);
     return result + 32;
   }
   _BitScanForward(&result, n);
-  return result;
-#elif defined(__GNUC__)
-  static_assert(sizeof(unsigned long long) == sizeof(n),  // NOLINT(runtime/int)
-                "__builtin_ctzll does not take 64-bit arg");
+  return (int)result;
+#elif defined(IREE_COMPILER_GCC_COMPAT)
+  // Use __builtin_ctzll, which uses the following instructions:
+  //    x86: bsf
+  //   WASM: i64.ctz
+  // RISC-V: __ctzdi2 in GCC, splat out in clang
   return __builtin_ctzll(n);
 #else
   int c = 63;
   n &= ~n + 1;
-  if (n & 0x00000000FFFFFFFF) c -= 32;
-  if (n & 0x0000FFFF0000FFFF) c -= 16;
-  if (n & 0x00FF00FF00FF00FF) c -= 8;
-  if (n & 0x0F0F0F0F0F0F0F0F) c -= 4;
-  if (n & 0x3333333333333333) c -= 2;
-  if (n & 0x5555555555555555) c -= 1;
+  if (n & 0x00000000FFFFFFFFull) c -= 32;
+  if (n & 0x0000FFFF0000FFFFull) c -= 16;
+  if (n & 0x00FF00FF00FF00FFull) c -= 8;
+  if (n & 0x0F0F0F0F0F0F0F0Full) c -= 4;
+  if (n & 0x3333333333333333ull) c -= 2;
+  if (n & 0x5555555555555555ull) c -= 1;
   return c;
-#endif
+#endif  // IREE_COMPILER_MSVC / IREE_COMPILER_GCC_COMPAT
 }
 
-template <typename T>
-IREE_FORCEINLINE int TrailingZeros(T x) {
-  return sizeof(T) == 8 ? CountTrailingZerosNonZero64(static_cast<uint64_t>(x))
-                        : CountTrailingZerosNonZero32(static_cast<uint32_t>(x));
-}
-
-template <typename T>
-IREE_FORCEINLINE int LeadingZeros(T x) {
-  return sizeof(T) == 8 ? CountLeadingZeros64(static_cast<uint64_t>(x))
-                        : CountLeadingZeros32(static_cast<uint32_t>(x));
-}
+//==============================================================================
+// Population count
+//==============================================================================
 
 // Returns the number of 1 bits in a 32 bit value.
-IREE_FORCEINLINE int CountOnes32(uint32_t n) {
-  n -= ((n >> 1) & 0x55555555);
-  n = ((n >> 2) & 0x33333333) + (n & 0x33333333);
-  return static_cast<int>((((n + (n >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24);
+static inline int iree_math_count_ones_u32(uint32_t n) {
+  n -= ((n >> 1) & 0x55555555u);
+  n = ((n >> 2) & 0x33333333u) + (n & 0x33333333u);
+  return (int)((((n + (n >> 4)) & 0x0F0F0F0Fu) * 0x01010101u) >> 24);
 }
 
 // Returns the number of 1 bits in a 64 bit value.
-IREE_FORCEINLINE int CountOnes64(uint64_t n) {
-  return CountOnes32(n >> 32) + CountOnes32(n & 0xffffffff);
+static inline int iree_math_count_ones_u64(uint64_t n) {
+  return iree_math_count_ones_u32(n >> 32) +
+         iree_math_count_ones_u32(n & 0xFFFFFFFFu);
+}
+
+//==============================================================================
+// Rounding and alignment
+//==============================================================================
+// There are certain platforms - mostly those with poorer quality compilers or
+// more restricted instruction sets - where we want to avoid the clz path as
+// it is emulated and instead we use some bit-twiddling hacks. On other
+// platforms it's the opposite - they may emulate clz but doing so saves
+// dozens of bytes that otherwise would have been the shift/or tree.
+//
+// Which to choose is entirely determined by fiddling on godbolt for the
+// target platform: https://godbolt.org/z/h4vPzo
+
+// Rounds up the value to the nearest power of 2 (if not already a power of 2).
+// For 32-bit numbers this only supports values <= 2^31; higher will wrap.
+static inline uint32_t iree_math_round_up_to_pow2_u32(uint32_t n) {
+#if 0    // golf required; can be bloated
+  const uint32_t i = (n != 1);
+  return (1 + i) << ((iree_math_count_leading_zeros_u32(n - i) ^ 31));
+#elif 0  // golf required; can be bloated
+  return n == 1 ? 1u : 2u << ((iree_math_count_leading_zeros_u32(n - 1) ^ 31));
+#else
+  // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+  n--;
+  n |= n >> 1;
+  n |= n >> 2;
+  n |= n >> 4;
+  n |= n >> 8;
+  n |= n >> 16;
+  return n + 1;
+#endif  // 1
 }
 
 // Rounds up the value to the nearest power of 2 (if not already a power of 2).
-IREE_FORCEINLINE int RoundUpToNearestPow2(int n) {
-  return n ? ~0u >> LeadingZeros(n) : 1;
+// For 64-bit numbers this only supports values <= 2^63; higher will wrap.
+static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) {
+#if 0    // golf required; can be bloated
+  const uint64_t i = (n != 1);
+  return (1 + i) << ((iree_math_count_leading_zeros_u64(n - i) ^ 63));
+#elif 0  // golf required; can be bloated
+  return n == 1 ? 1ull
+                : 2ull << ((iree_math_count_leading_zeros_u64(n - 1) ^ 63));
+#else
+  // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+  n--;
+  n |= n >> 1;
+  n |= n >> 2;
+  n |= n >> 4;
+  n |= n >> 8;
+  n |= n >> 16;
+  n |= n >> 32;
+  return n + 1;
+#endif  // 1
 }
 
-}  // namespace iree
+//==============================================================================
+// Pseudo-random number generators (PRNGs): **NOT CRYPTOGRAPHICALLY SECURE**
+//==============================================================================
+
+// A fixed-increment version of Java 8's SplittableRandom generator
+// See http://dx.doi.org/10.1145/2714064.2660195 and
+// http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html
+//
+// SplitMix64 as recommended for use with xoroshiro by the authors:
+// http://prng.di.unimi.it/splitmix64.c
+// http://rosettacode.org/wiki/Pseudo-random_numbers/Splitmix64
+typedef uint64_t iree_prng_splitmix64_state_t;
+
+// Initializes a SplitMix64 PRNG state vector; |out_state| is overwritten.
+// |seed| may be any 64-bit value.
+static inline void iree_prng_splitmix64_initialize(
+    uint64_t seed, iree_prng_splitmix64_state_t* out_state) {
+  *out_state = seed;
+}
+
+// Steps a SplitMix64 PRNG state vector and yields a value for use.
+static inline uint64_t iree_prng_splitmix64_next(
+    iree_prng_splitmix64_state_t* state) {
+  uint64_t z = (*state += 0x9E3779B97F4A7C15ull);
+  z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
+  z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
+  return z ^ (z >> 31);
+}
+
+// A small **pseudorandom** number generator (named after the operations used).
+// http://prng.di.unimi.it/
+typedef struct {
+  uint64_t value[2];
+} iree_prng_xoroshiro128_state_t;
+
+// Initializes a xoroshiro128+ PRNG state vector; |out_state| is overwritten.
+// |seed| may be any 64-bit value.
+static inline void iree_prng_xoroshiro128_initialize(
+    uint64_t seed, iree_prng_xoroshiro128_state_t* out_state) {
+  // The authors recommend using SplitMix64 to go from a single int seed
+  // into the two state values we need. It's critical that we don't use a
+  // xoroshiro128 for this as seeding a PRNG with the results of itself is...
+  // unsound.
+  iree_prng_splitmix64_state_t init_state;
+  iree_prng_splitmix64_initialize(seed, &init_state);
+  out_state->value[0] = iree_prng_splitmix64_next(&init_state);
+  out_state->value[1] = iree_prng_splitmix64_next(&init_state);
+
+  // An all-zero state would only ever produce zeros, so ensure that doesn't
+  // happen; after running SplitMix64 on the seed this is vanishingly unlikely
+  // anyway.
+  if (!out_state->value[0] && !out_state->value[1]) {
+    out_state->value[0] = 1;
+  }
+}
+
+// Steps a xoroshiro128 state vector and yields a value for use.
+// xoroshiro128+ variant: produces a single 64-bit value.
+// This is the fastest variant but the lower 4 bits of the returned value may
+// not be sufficiently well-distributed. This is fine if the usage requires
+// fewer than 60 bits such as when sampling bools or array indices.
+// Note also that this works great for floating-point numbers where only 23 or
+// 53 bits are required to populate a mantissa and an additional step can be
+// used to generate the sign/exponent when required.
+//
+//   footprint: 128-bits
+//      period: 2^128 - 1
+//  ns/64-bits: 0.72
+// cycles/byte: 0.29
+//
+// http://prng.di.unimi.it/xoroshiro128plus.c
+static inline uint64_t iree_prng_xoroshiro128plus_next_uint60(
+    iree_prng_xoroshiro128_state_t* state) {
+  uint64_t s0 = state->value[0];
+  uint64_t s1 = state->value[1];
+  const uint64_t result = s0 + s1;
+  s1 ^= s0;
+  state->value[0] = iree_math_rotl_u64(s0, 24) ^ s1 ^ (s1 << 16);  // a, b
+  state->value[1] = iree_math_rotl_u64(s1, 37);                    // c
+  return result;
+}
+
+// Steps a xoroshiro128 state vector and yields a single boolean value for use.
+// See iree_prng_xoroshiro128plus_next_uint60 for details.
+static inline bool iree_prng_xoroshiro128plus_next_bool(
+    iree_prng_xoroshiro128_state_t* state) {
+  return (bool)(iree_prng_xoroshiro128plus_next_uint60(state) >> (64 - 1));
+}
+
+// Steps a xoroshiro128 state vector and yields a single uint8_t value for use.
+// See iree_prng_xoroshiro128plus_next_uint60 for details.
+static inline uint8_t iree_prng_xoroshiro128plus_next_uint8(
+    iree_prng_xoroshiro128_state_t* state) {
+  return (uint8_t)(iree_prng_xoroshiro128plus_next_uint60(state) >> (64 - 8));
+}
+
+// Steps a xoroshiro128 state vector and yields a single uint32_t value for use.
+// See iree_prng_xoroshiro128plus_next_uint60 for details.
+static inline uint32_t iree_prng_xoroshiro128plus_next_uint32(
+    iree_prng_xoroshiro128_state_t* state) {
+  return (uint32_t)(iree_prng_xoroshiro128plus_next_uint60(state) >> (64 - 32));
+}
+
+// Steps a xoroshiro128 state vector and yields a value for use.
+// xoroshiro128** variant: produces a single well-distributed 64-bit value.
+// Prefer this to xoroshiro128+ when good distribution over the integer range
+// is required; see xoroshiro128+ for details of its issues.
+//
+//   footprint: 128-bits
+//      period: 2^128 - 1
+//  ns/64-bits: 0.93
+// cycles/byte: 0.42
+//
+// http://prng.di.unimi.it/xoroshiro128starstar.c
+static inline uint64_t iree_prng_xoroshiro128starstar_next_uint64(
+    iree_prng_xoroshiro128_state_t* state) {
+  uint64_t s0 = state->value[0];
+  uint64_t s1 = state->value[1];
+  const uint64_t result = iree_math_rotl_u64(s0 * 5, 7) * 9;
+  s1 ^= s0;
+  state->value[0] = iree_math_rotl_u64(s0, 24) ^ s1 ^ (s1 << 16);  // a, b
+  state->value[1] = iree_math_rotl_u64(s1, 37);                    // c
+  return result;
+}
 
 #endif  // IREE_BASE_MATH_H_
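An illustrative usage sketch of the new C API above (not part of this patch; the tiny `main` harness and the chosen seed are made up for the example, but the function names and the commented return values match the declarations and tests in this change):

  #include <inttypes.h>
  #include <stdbool.h>
  #include <stdio.h>

  #include "iree/base/math.h"

  int main(void) {
    // Bit scans and rounding.
    printf("%d\n", iree_math_count_leading_zeros_u32(1u));        // 31
    printf("%d\n", iree_math_count_trailing_zeros_u64(8ull));     // 3
    printf("%" PRIu32 "\n", iree_math_round_up_to_pow2_u32(9u));  // 16

    // Seed a xoroshiro128 generator from a single 64-bit value and draw from it.
    iree_prng_xoroshiro128_state_t prng;
    iree_prng_xoroshiro128_initialize(/*seed=*/42ull, &prng);
    uint64_t raw = iree_prng_xoroshiro128plus_next_uint60(&prng);
    uint32_t narrowed = iree_prng_xoroshiro128plus_next_uint32(&prng);
    bool coin = iree_prng_xoroshiro128plus_next_bool(&prng);
    printf("%" PRIu64 " %" PRIu32 " %d\n", raw, narrowed, (int)coin);
    return 0;
  }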
diff --git a/iree/base/math_test.cc b/iree/base/math_test.cc
new file mode 100644
index 0000000..7dad343
--- /dev/null
+++ b/iree/base/math_test.cc
@@ -0,0 +1,226 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/math.h"
+
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace {
+
+//==============================================================================
+// Bitwise rotation (aka circular shifts)
+//==============================================================================
+
+TEST(BitwiseRotationTest, ROTL64) {
+  EXPECT_EQ(0ull, iree_math_rotl_u64(0ull, 0u));
+  EXPECT_EQ(0ull, iree_math_rotl_u64(0ull, 0u));
+  EXPECT_EQ(1ull, iree_math_rotl_u64(1ull, 0u));
+  EXPECT_EQ(1ull, iree_math_rotl_u64(1ull, 0u));
+
+  EXPECT_EQ(2ull, iree_math_rotl_u64(1ull, 1u));
+  EXPECT_EQ(2ull, iree_math_rotl_u64(1ull, 1u));
+  EXPECT_EQ(UINT64_MAX, iree_math_rotl_u64(UINT64_MAX, 63u));
+  EXPECT_EQ(UINT64_MAX, iree_math_rotl_u64(UINT64_MAX, 64u));
+}
+
+TEST(BitwiseRotationTest, ROTR64) {
+  EXPECT_EQ(0ull, iree_math_rotr_u64(0ull, 0u));
+  EXPECT_EQ(0ull, iree_math_rotr_u64(0ull, 0u));
+  EXPECT_EQ(1ull, iree_math_rotr_u64(1ull, 0u));
+  EXPECT_EQ(1ull, iree_math_rotr_u64(1ull, 0u));
+
+  EXPECT_EQ(1ull, iree_math_rotr_u64(2ull, 1u));
+  EXPECT_EQ(0x8000000000000000ull, iree_math_rotr_u64(2ull, 2u));
+  EXPECT_EQ(0x8000000000000000ull, iree_math_rotr_u64(1ull, 1u));
+  EXPECT_EQ(0x4000000000000000ull, iree_math_rotr_u64(1ull, 2u));
+}
+
+//==============================================================================
+// Bit scanning/counting
+//==============================================================================
+
+TEST(BitwiseScansTest, CLZ32) {
+  EXPECT_EQ(32, iree_math_count_leading_zeros_u32(uint32_t{}));
+  EXPECT_EQ(0, iree_math_count_leading_zeros_u32(~uint32_t{}));
+  for (int index = 0; index < 32; index++) {
+    uint32_t x = 1u << index;
+    const int cnt = 31 - index;
+    ASSERT_EQ(cnt, iree_math_count_leading_zeros_u32(x)) << index;
+    ASSERT_EQ(cnt, iree_math_count_leading_zeros_u32(x + x - 1)) << index;
+  }
+}
+
+TEST(BitwiseScansTest, CLZ64) {
+  EXPECT_EQ(64, iree_math_count_leading_zeros_u64(uint64_t{}));
+  EXPECT_EQ(0, iree_math_count_leading_zeros_u64(~uint64_t{}));
+  for (int index = 0; index < 64; index++) {
+    uint64_t x = 1ull << index;
+    const int cnt = 63 - index;
+    ASSERT_EQ(cnt, iree_math_count_leading_zeros_u64(x)) << index;
+    ASSERT_EQ(cnt, iree_math_count_leading_zeros_u64(x + x - 1)) << index;
+  }
+}
+
+TEST(BitwiseScansTest, CTZ32) {
+  EXPECT_EQ(0, iree_math_count_trailing_zeros_u32(~uint32_t{}));
+  for (int index = 0; index < 32; index++) {
+    uint32_t x = static_cast<uint32_t>(1) << index;
+    const int cnt = index;
+    ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u32(x)) << index;
+    ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u32(~(x - 1))) << index;
+  }
+}
+
+TEST(BitwiseScansTest, CTZ64) {
+  // iree_math_count_trailing_zeros_u64
+  EXPECT_EQ(0, iree_math_count_trailing_zeros_u64(~uint64_t{}));
+  for (int index = 0; index < 64; index++) {
+    uint64_t x = static_cast<uint64_t>(1) << index;
+    const int cnt = index;
+    ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u64(x)) << index;
+    ASSERT_EQ(cnt, iree_math_count_trailing_zeros_u64(~(x - 1))) << index;
+  }
+}
+
+//==============================================================================
+// Population count
+//==============================================================================
+
+TEST(PopulationCountTest, Ones32) {
+  EXPECT_EQ(0, iree_math_count_ones_u32(0u));
+  EXPECT_EQ(1, iree_math_count_ones_u32(1u));
+  EXPECT_EQ(29, iree_math_count_ones_u32(-15u));
+  EXPECT_EQ(5, iree_math_count_ones_u32(341u));
+  EXPECT_EQ(32, iree_math_count_ones_u32(UINT32_MAX));
+  EXPECT_EQ(31, iree_math_count_ones_u32(UINT32_MAX - 1));
+}
+
+TEST(PopulationCountTest, Ones64) {
+  EXPECT_EQ(0, iree_math_count_ones_u64(0ull));
+  EXPECT_EQ(1, iree_math_count_ones_u64(1ull));
+  EXPECT_EQ(61, iree_math_count_ones_u64(-15ull));
+  EXPECT_EQ(5, iree_math_count_ones_u64(341ull));
+  EXPECT_EQ(64, iree_math_count_ones_u64(UINT64_MAX));
+  EXPECT_EQ(63, iree_math_count_ones_u64(UINT64_MAX - 1ull));
+}
+
+//==============================================================================
+// Rounding and alignment
+//==============================================================================
+
+TEST(RoundingTest, UpToNextPow232) {
+  constexpr uint32_t kUint16Max = UINT16_MAX;
+  constexpr uint32_t kUint32Max = UINT32_MAX;
+  EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(0u));
+  EXPECT_EQ(1u, iree_math_round_up_to_pow2_u32(1u));
+  EXPECT_EQ(2u, iree_math_round_up_to_pow2_u32(2u));
+  EXPECT_EQ(4u, iree_math_round_up_to_pow2_u32(3u));
+  EXPECT_EQ(8u, iree_math_round_up_to_pow2_u32(8u));
+  EXPECT_EQ(16u, iree_math_round_up_to_pow2_u32(9u));
+  EXPECT_EQ(kUint16Max + 1u, iree_math_round_up_to_pow2_u32(kUint16Max - 1u));
+  EXPECT_EQ(kUint16Max + 1u, iree_math_round_up_to_pow2_u32(kUint16Max));
+  EXPECT_EQ(kUint16Max + 1u, iree_math_round_up_to_pow2_u32(kUint16Max + 1u));
+  EXPECT_EQ(131072u, iree_math_round_up_to_pow2_u32(kUint16Max + 2u));
+  EXPECT_EQ(262144u, iree_math_round_up_to_pow2_u32(262144u - 1u));
+  EXPECT_EQ(0x80000000u, iree_math_round_up_to_pow2_u32(0x7FFFFFFFu));
+  EXPECT_EQ(0x80000000u, iree_math_round_up_to_pow2_u32(0x80000000u));
+
+  // NOTE: wrap to 0.
+  EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(0x80000001u));
+  EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(kUint32Max - 1u));
+  EXPECT_EQ(0u, iree_math_round_up_to_pow2_u32(kUint32Max));
+}
+
+TEST(RoundingTest, UpToNextPow264) {
+  constexpr uint64_t kUint16Max = UINT16_MAX;
+  constexpr uint64_t kUint64Max = UINT64_MAX;
+  EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(0ull));
+  EXPECT_EQ(1ull, iree_math_round_up_to_pow2_u64(1ull));
+  EXPECT_EQ(2ull, iree_math_round_up_to_pow2_u64(2ull));
+  EXPECT_EQ(4ull, iree_math_round_up_to_pow2_u64(3ull));
+  EXPECT_EQ(8ull, iree_math_round_up_to_pow2_u64(8ull));
+  EXPECT_EQ(16ull, iree_math_round_up_to_pow2_u64(9ull));
+  EXPECT_EQ(kUint16Max + 1ull,
+            iree_math_round_up_to_pow2_u64(kUint16Max - 1ull));
+  EXPECT_EQ(kUint16Max + 1ull, iree_math_round_up_to_pow2_u64(kUint16Max));
+  EXPECT_EQ(kUint16Max + 1ull,
+            iree_math_round_up_to_pow2_u64(kUint16Max + 1ull));
+  EXPECT_EQ(131072ull, iree_math_round_up_to_pow2_u64(kUint16Max + 2ull));
+  EXPECT_EQ(0x100000000ull, iree_math_round_up_to_pow2_u64(0xFFFFFFFEull));
+  EXPECT_EQ(0x100000000ull, iree_math_round_up_to_pow2_u64(0xFFFFFFFFull));
+  EXPECT_EQ(0x80000000ull, iree_math_round_up_to_pow2_u64(0x7FFFFFFFull));
+  EXPECT_EQ(0x80000000ull, iree_math_round_up_to_pow2_u64(0x80000000ull));
+  EXPECT_EQ(0x100000000ull, iree_math_round_up_to_pow2_u64(0x80000001ull));
+
+  // NOTE: wrap to 0.
+  EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(0x8000000000000001ull));
+  EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(kUint64Max - 1ull));
+  EXPECT_EQ(0ull, iree_math_round_up_to_pow2_u64(kUint64Max));
+}
+
+//==============================================================================
+// Pseudo-random number generators (PRNGs): **NOT CRYPTOGRAPHICALLY SECURE**
+//==============================================================================
+// NOTE: we leave the real testing to the authors; this just ensures we aren't
+// `return 4;`ing it or ignoring the seed.
+
+TEST(PRNG, SplitMix64) {
+  iree_prng_splitmix64_state_t state;
+
+  iree_prng_splitmix64_initialize(/*seed=*/0ull, &state);
+  EXPECT_EQ(16294208416658607535ull, iree_prng_splitmix64_next(&state));
+  EXPECT_EQ(7960286522194355700ull, iree_prng_splitmix64_next(&state));
+
+  iree_prng_splitmix64_initialize(/*seed=*/1ull, &state);
+  EXPECT_EQ(10451216379200822465ull, iree_prng_splitmix64_next(&state));
+  EXPECT_EQ(13757245211066428519ull, iree_prng_splitmix64_next(&state));
+
+  iree_prng_splitmix64_initialize(/*seed=*/UINT64_MAX, &state);
+  EXPECT_EQ(16490336266968443936ull, iree_prng_splitmix64_next(&state));
+  EXPECT_EQ(16834447057089888969ull, iree_prng_splitmix64_next(&state));
+}
+
+TEST(PRNG, Xoroshiro128) {
+  iree_prng_xoroshiro128_state_t state;
+
+  iree_prng_xoroshiro128_initialize(/*seed=*/0ull, &state);
+  EXPECT_EQ(5807750865143411619ull,
+            iree_prng_xoroshiro128plus_next_uint60(&state));
+  EXPECT_TRUE(iree_prng_xoroshiro128plus_next_bool(&state));
+  EXPECT_EQ(218u, iree_prng_xoroshiro128plus_next_uint8(&state));
+  EXPECT_EQ(1647201753u, iree_prng_xoroshiro128plus_next_uint32(&state));
+  EXPECT_EQ(7260361800523965311ull,
+            iree_prng_xoroshiro128starstar_next_uint64(&state));
+
+  iree_prng_xoroshiro128_initialize(/*seed=*/1ull, &state);
+  EXPECT_EQ(5761717516557699368ull,
+            iree_prng_xoroshiro128plus_next_uint60(&state));
+  EXPECT_TRUE(iree_prng_xoroshiro128plus_next_bool(&state));
+  EXPECT_EQ(103u, iree_prng_xoroshiro128plus_next_uint8(&state));
+  EXPECT_EQ(2242241045u, iree_prng_xoroshiro128plus_next_uint32(&state));
+  EXPECT_EQ(661144386810419178ull,
+            iree_prng_xoroshiro128starstar_next_uint64(&state));
+
+  iree_prng_xoroshiro128_initialize(/*seed=*/UINT64_MAX, &state);
+  EXPECT_EQ(14878039250348781289ull,
+            iree_prng_xoroshiro128plus_next_uint60(&state));
+  EXPECT_FALSE(iree_prng_xoroshiro128plus_next_bool(&state));
+  EXPECT_EQ(137u, iree_prng_xoroshiro128plus_next_uint8(&state));
+  EXPECT_EQ(2111322015u, iree_prng_xoroshiro128plus_next_uint32(&state));
+  EXPECT_EQ(138107609852220106ull,
+            iree_prng_xoroshiro128starstar_next_uint64(&state));
+}
+
+}  // namespace
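As an aside on the "NOTE: wrap to 0" expectations in RoundingTest above, a short trace (illustrative only, not part of the patch) shows why any 32-bit input above 2^31 wraps: the or-shift cascade saturates to UINT32_MAX and the final increment overflows to zero.

  uint32_t n = 0x80000001u;
  n--;           // 0x80000000
  n |= n >> 1;   // 0xC0000000
  n |= n >> 2;   // 0xF0000000
  n |= n >> 4;   // 0xFF000000
  n |= n >> 8;   // 0xFFFF0000
  n |= n >> 16;  // 0xFFFFFFFF
  n += 1;        // 0x00000000, matching EXPECT_EQ(0u, ...) above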
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
index 7025221..5930e1f 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.cpp
@@ -14,6 +14,7 @@
 
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/SymbolTable.h"
 
 namespace mlir {
@@ -24,5 +25,14 @@
          SymbolTable::Visibility::Public;
 }
 
+unsigned getNumOuterParallelLoops(linalg::LinalgOp op) {
+  return op.iterator_types()
+      .getValue()
+      .take_while([](Attribute attr) -> bool {
+        return linalg::isParallelIteratorType(attr);
+      })
+      .size();
+}
+
 }  // namespace iree_compiler
 }  // namespace mlir
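getNumOuterParallelLoops counts only the leading run of "parallel" iterator types and stops at the first non-parallel one. A plain C++ analogue of that logic, outside of MLIR (illustrative only; the helper name and sample inputs below are made up):

  #include <string>
  #include <vector>

  // Mirrors the take_while above: count the leading "parallel" iterators.
  static unsigned numOuterParallelLoops(const std::vector<std::string> &iters) {
    unsigned n = 0;
    while (n < iters.size() && iters[n] == "parallel") ++n;
    return n;
  }

  // matmul iterators {"parallel", "parallel", "reduction"} -> 2
  // a pure reduction {"reduction"}                         -> 0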
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
index 66c85bd..b305cc7 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
@@ -15,6 +15,7 @@
 #ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_FUNCTIONUTILS_H_
 #define IREE_COMPILER_CONVERSION_CODEGENUTILS_FUNCTIONUTILS_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Function.h"
 
 namespace mlir {
@@ -33,6 +34,9 @@
   return "operand_result_index";
 }
 
+/// Returns the number of outer parallel loops of a linalgOp.
+unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
index 452b662..be9760b 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
@@ -14,18 +14,24 @@
 
 #include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
 
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 
 #define DEBUG_TYPE "workgroup-calculation"
 
 namespace mlir {
 namespace iree_compiler {
+
 FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
                           llvm::StringRef numWorkgroupsFnAttr) {
   SymbolRefAttr attr =
@@ -43,9 +49,23 @@
   return numWorkgroupsFn;
 }
 
-/// Computes the bounds of the parallel loops partitioned across workgroups.
-static Optional<SmallVector<Value, 2>> getParallelLoopRange(
-    PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
+// TODO: This method is templated on the builder type since `OpBuilder`
+// doesn't have an erase method. Erasing the op directly leads to segfaults
+// when the builder is a `PatternRewriter` since the rewriter doesn't know the
+// op was deleted. This can be simplified once that issue is fixed.
+template <typename BuilderTy>
+static void eraseOp(BuilderTy &builder, Operation *op) {
+  builder.eraseOp(op);
+}
+template <>
+void eraseOp(OpBuilder &builder, Operation *op) {
+  op->erase();
+}
+
+/// Computes the bounds of the loops of the `linalgOp`.
+template <typename BuilderTy>
+static Optional<SmallVector<Value, 4>> getLoopUpperBounds(
+    BuilderTy &builder, Location loc, FuncOp numWorkgroupsFn,
     linalg::LinalgOp linalgOp) {
   if (!numWorkgroupsFn.empty()) {
     numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
@@ -56,79 +76,105 @@
                  << numWorkgroupsFn.getName();
   });
 
-  rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
+  builder.createBlock(&numWorkgroupsFn.getBody(), /*insertPt=*/{},
+                      numWorkgroupsFn.getType().getInputs());
   llvm::SetVector<Operation *> slice;
   getBackwardSlice(linalgOp, &slice);
   BlockAndValueMapping mapper;
   for (Operation *op : slice) {
-    rewriter.clone(*op, mapper);
+    builder.clone(*op, mapper);
   }
   // Clone the linalg operation just to compute the loop bounds.
   linalg::LinalgOp clonedLinalgOp =
-      rewriter.clone(*linalgOp.getOperation(), mapper);
-  SmallVector<Range, 4> ranges = clonedLinalgOp.createLoopRanges(rewriter, loc);
-  SmallVector<Value, 4> bounds;
-  bounds.reserve(ranges.size());
-  for (Range r : ranges) bounds.push_back(r.size);
-  unsigned numParallelLoops = linalgOp.iterator_types()
-                                  .getValue()
-                                  .take_while([](Attribute attr) -> bool {
-                                    return attr.cast<StringAttr>().getValue() ==
-                                           getParallelIteratorTypeName();
-                                  })
-                                  .size();
-  SmallVector<Value, 2> returnVals(bounds.begin(),
-                                   std::next(bounds.begin(), numParallelLoops));
-  rewriter.eraseOp(clonedLinalgOp);
-  return returnVals;
+      builder.clone(*linalgOp.getOperation(), mapper);
+  auto loopRange = clonedLinalgOp.createLoopRanges(builder, loc);
+  if (llvm::any_of(loopRange, [](Range range) {
+        return !matchPattern(range.stride, m_One()) ||
+               !matchPattern(range.offset, m_Zero());
+      })) {
+    linalgOp.emitError("unhandled non-unit stride loop range");
+    return llvm::None;
+  }
+  SmallVector<Value, 4> bounds = llvm::to_vector<4>(
+      llvm::map_range(loopRange, [](Range range) { return range.size; }));
+  eraseOp<BuilderTy>(builder, clonedLinalgOp);
+  return bounds;
 }
 
 /// Utility method to build IR that computes ceil(`numerator` / `denominator`)
-static Value buildCeilDiv(PatternRewriter &rewriter, Location loc,
-                          Value numerator, Value denominator) {
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  Value t = rewriter.create<AddIOp>(
-      loc, numerator, rewriter.create<SubIOp>(loc, denominator, one));
-  return rewriter.create<SignedDivIOp>(loc, t, denominator);
+static Value buildCeilDiv(OpBuilder &builder, Location loc, Value numerator,
+                          Value denominator) {
+  Value one = builder.create<ConstantIndexOp>(loc, 1);
+  Value t = builder.create<AddIOp>(
+      loc, numerator, builder.create<SubIOp>(loc, denominator, one));
+  return builder.create<SignedDivIOp>(loc, t, denominator);
 }
 
 /// Utility method to build IR that computes ceil(`numerator` / `denominator`)
 /// when denominator is a constant.
-static Value buildCeilDivConstDenominator(PatternRewriter &rewriter,
-                                          Location loc, Value numerator,
-                                          int64_t denominator) {
-  return buildCeilDiv(rewriter, loc, numerator,
-                      rewriter.create<ConstantIndexOp>(loc, denominator));
+static Value buildCeilDiv(OpBuilder &builder, Location loc, Value numerator,
+                          int64_t denominator) {
+  return buildCeilDiv(
+      builder, loc, numerator,
+      builder.create<ConstantIndexOp>(loc, denominator).getResult());
 }
 
-LogicalResult createNumWorkgroupsFromResultShape(
-    PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
-    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
+template <class BuilderTy>
+static LogicalResult createNumWorkgroupsFromResultShapeImpl(
+    BuilderTy &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes,
+    ArrayRef<unsigned> distributedLoops) {
   FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
       linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
   if (!numWorkgroupsFn) return failure();
 
   Location loc = linalgOp.getLoc();
-  OpBuilder::InsertionGuard guard(rewriter);
-  Optional<SmallVector<Value, 2>> parallelLoopRange =
-      getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
-  if (!parallelLoopRange) return failure();
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  SmallVector<Value, 3> returnValues(3, one);
-  for (size_t i = 0, e = std::min<size_t>(parallelLoopRange->size(), 3); i != e;
-       ++i) {
-    if (tileSizes[e - i - 1] != 0) {
-      returnValues[i] = buildCeilDivConstDenominator(
-          rewriter, loc, (*parallelLoopRange)[e - i - 1], tileSizes[e - i - 1]);
+  OpBuilder::InsertionGuard guard(builder);
+  auto loopRange = getLoopUpperBounds(builder, loc, numWorkgroupsFn, linalgOp);
+  if (!loopRange) return failure();
+
+  SmallVector<Value, 4> numWorkgroups;
+  DenseSet<unsigned> distributedLoopsSet(distributedLoops.begin(),
+                                         distributedLoops.end());
+  for (auto size : enumerate(tileSizes)) {
+    if (size.value() && distributedLoopsSet.count(size.index())) {
+      Value num =
+          buildCeilDiv(builder, loc, (*loopRange)[size.index()], size.value());
+      numWorkgroups.push_back(num);
     }
   }
-  rewriter.create<mlir::ReturnOp>(loc, returnValues);
+  SmallVector<Value, 4> resultValues =
+      llvm::to_vector<4>(llvm::reverse(numWorkgroups));
+  Value one = builder.template create<ConstantIndexOp>(loc, 1);
+  resultValues.resize(3, one);
+  builder.template create<mlir::ReturnOp>(loc, resultValues);
   return success();
 }
 
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+LogicalResult createNumWorkgroupsFromResultShape(
+    OpBuilder &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes,
+    ArrayRef<unsigned> distributedLoops) {
+  return createNumWorkgroupsFromResultShapeImpl<OpBuilder>(
+      builder, linalgOp, entryPointFn, numWorkgroupsFnAttr, tileSizes,
+      distributedLoops);
+}
+
+LogicalResult createNumWorkgroupsFromResultShape(
     PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
-    llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX) {
+    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
+  SmallVector<unsigned, 4> distributedLoops =
+      llvm::to_vector<4>(llvm::seq<unsigned>(
+          0, std::min<unsigned>(3, getNumOuterParallelLoops(linalgOp))));
+  return createNumWorkgroupsFromResultShapeImpl<PatternRewriter>(
+      rewriter, linalgOp, entryPointFn, numWorkgroupsFnAttr, tileSizes,
+      distributedLoops);
+}
+
+LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+    ConversionPatternRewriter &rewriter, linalg::LinalgOp linalgOp,
+    FuncOp entryPointFn, llvm::StringRef numWorkgroupsFnAttr,
+    int64_t workgroupSizeX) {
   FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
       linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
   if (!numWorkgroupsFn) return failure();
@@ -144,16 +190,17 @@
 
   Location loc = linalgOp.getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
-  Optional<SmallVector<Value, 2>> parallelLoopRange =
-      getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
-  if (!parallelLoopRange) return failure();
+  Optional<SmallVector<Value, 4>> loopRange =
+      getLoopUpperBounds(rewriter, loc, numWorkgroupsFn, linalgOp);
+  if (!loopRange) return failure();
+  unsigned numParallelLoops = getNumOuterParallelLoops(linalgOp);
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   SmallVector<Value, 3> returnValues(3, one);
-  for (auto range : *parallelLoopRange) {
+  for (auto range : ArrayRef<Value>(*loopRange).take_front(numParallelLoops)) {
     returnValues[0] = rewriter.create<MulIOp>(loc, range, returnValues[0]);
   }
-  returnValues[0] = buildCeilDivConstDenominator(rewriter, loc, returnValues[0],
-                                                 workgroupSizeX);
+  returnValues[0] =
+      buildCeilDiv(rewriter, loc, returnValues[0], workgroupSizeX);
   rewriter.create<mlir::ReturnOp>(loc, returnValues);
   return success();
 }
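The arithmetic that the generated num-workgroups function encodes can be modeled without any MLIR machinery. A sketch of what createNumWorkgroupsFromResultShapeImpl computes, assuming every loop with a non-zero tile size is distributed (illustrative only; the helper names and the sample bounds/tile sizes are made up):

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  // Matches buildCeilDiv: ceil(numerator / denominator) for positive values.
  static int64_t ceil_div(int64_t numerator, int64_t denominator) {
    return (numerator + denominator - 1) / denominator;
  }

  // Each distributed loop with a non-zero tile size contributes
  // ceil(bound / tile); the counts are reversed into [x, y, z] order and
  // padded with 1, as in createNumWorkgroupsFromResultShapeImpl.
  static std::vector<int64_t> num_workgroups(const std::vector<int64_t> &bounds,
                                             const std::vector<int64_t> &tiles) {
    std::vector<int64_t> counts;
    for (size_t i = 0; i < tiles.size() && i < bounds.size(); ++i) {
      if (tiles[i] != 0) counts.push_back(ceil_div(bounds[i], tiles[i]));
    }
    std::reverse(counts.begin(), counts.end());
    counts.resize(3, 1);
    return counts;
  }

  // num_workgroups({100, 60}, {32, 16}) == {4, 4, 1}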
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
index 4e28779..ec40ffe 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -32,33 +32,43 @@
 namespace iree_compiler {
 
 /// Generates a function that computes the number of workgroups as
-///  [ceil(`parallelLoopRange`[2] / `tileSizes`[2]),
-///   ceil(`parallelLoopRange`[1] / `tileSizes`[1]),
-///   ceil(`parallelLoopRange`[0] / `tileSizes`[0])]
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
-/// distributed across workgroups.
+///  [ceil(`loopUpperBounds`[2] / `tileSizes`[2]),
+///   ceil(`loopUpperBounds`[1] / `tileSizes`[1]),
+///   ceil(`loopUpperBounds`[0] / `tileSizes`[0])]
+/// where `loopUpperBounds` are the ranges of the parallel loops of `linalgOp`
+/// distributed across workgroups. `distributedLoops` are the loop dimensions
+/// that are distributed.
+LogicalResult createNumWorkgroupsFromResultShape(
+    OpBuilder &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+    llvm::StringRef numWorkgroupsFnAttr, llvm::ArrayRef<int64_t> tileSizes,
+    llvm::ArrayRef<unsigned> distributedLoops);
+
+/// Generates a function that computes the number of workgroups as
+///  [ceil(`loopUpperBounds`[2] / `tileSizes`[2]),
+///   ceil(`loopUpperBounds`[1] / `tileSizes`[1]),
+///   ceil(`loopUpperBounds`[0] / `tileSizes`[0])]
+/// where `loopUpperBounds` are the ranges of the parallel loops of `linalgOp`
+/// distributed across workgroups. Assumes that up to 3 outer parallel loops of
+/// the `linalgOp` are distributed.
 LogicalResult createNumWorkgroupsFromResultShape(
     PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
     llvm::StringRef numWorkgroupsFnAttr, llvm::ArrayRef<int64_t> tileSizes);
 
 /// Generates a function that computes the number of workgroups as
-///  ceil(`parallelLoopRange`[0] * `parallelLoopRange`[1] * ... *
-///       `parallelLoopRange`[n-1]  /  `workgroupSizeX`)
-/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
+///  ceil(`loopUpperBounds`[0] * `loopUpperBounds`[1] * ... *
+///       `loopUpperBounds`[n-1]  /  `workgroupSizeX`)
+/// where `loopUpperBounds` are the ranges of the parallel loops of `linalgOp`
 /// distributed across workgroups.
 LogicalResult createNumWorkgroupsFromLinearizedResultShape(
-    PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
-    llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
+    ConversionPatternRewriter &rewriter, linalg::LinalgOp linalgOp,
+    FuncOp entryPointFn, llvm::StringRef numWorkgroupsFnAttr,
+    int64_t workgroupSizeX);
 
 /// For a given `entryPointFn` return the function that computes the number of
 /// workgroups to use at launch time.
 FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
                           llvm::StringRef numWorkgroupsFnAttr);
 
-LogicalResult createNumWorkgroupsFromLinearizedResultShape(
-    PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
-    llvm::StringRef numWorkgroupsFnAttr, int64_t workgroupSizeX);
-
 /// The codegeneration emits a function `numWorkgroupsFn` for each entry point
 /// function. This function has arguments the !shapex.ranked_shape for all the
 /// input and output shaped types. Using this the function returns the number of
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
index ab2a9c1..a7af32d 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
@@ -47,6 +47,12 @@
   return "copy_to_workgroup_memory";
 }
 
+// This marker is needed because we tile a convolution op multiple times:
+// 1) across workgroups, 2) across invocations, and 3) along the filter's
+// height/width and input channels to generate loops for a single GPU
+// invocation. This marker is for step 3).
+StringRef getConvFilterTileMarker() { return "tile_conv_filter"; }
+
 StringRef getVectorizeMarker() { return "vectorize"; }
 
 StringRef getDeleteMarker() { return "delete"; }
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
index 5072293..5371aba 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
@@ -45,6 +45,9 @@
 /// Workgroup memory.
 StringRef getCopyToWorkgroupMemoryMarker();
 
+/// Marker for tiling along convolution filter dimensions.
+StringRef getConvFilterTileMarker();
+
 /// Marker for operations that are going to be vectorized.
 StringRef getVectorizeMarker();
 
diff --git a/iree/compiler/Conversion/Common/BUILD b/iree/compiler/Conversion/Common/BUILD
index b8c5b5a..5827238 100644
--- a/iree/compiler/Conversion/Common/BUILD
+++ b/iree/compiler/Conversion/Common/BUILD
@@ -22,22 +22,33 @@
     name = "Common",
     srcs = [
         "DeclareNumWorkgroupsFnPass.cpp",
+        "LaunchConfig.cpp",
         "LegalizeNumWorkgroupsFnPass.cpp",
+        "Transforms.cpp",
+        "VectorTransferOptimization.cpp",
     ],
     hdrs = [
         "Attributes.h",
+        "LaunchConfig.h",
         "Passes.h",
+        "Transforms.h",
     ],
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Dialect/Shape/IR",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:CFGTransforms",
+        "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SPIRVLowering",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:VectorOps",
         "@org_tensorflow//tensorflow/compiler/mlir/hlo",
     ],
 )
diff --git a/iree/compiler/Conversion/Common/CMakeLists.txt b/iree/compiler/Conversion/Common/CMakeLists.txt
index b233805..955f010 100644
--- a/iree/compiler/Conversion/Common/CMakeLists.txt
+++ b/iree/compiler/Conversion/Common/CMakeLists.txt
@@ -19,16 +19,28 @@
     Common
   HDRS
     "Attributes.h"
+    "LaunchConfig.h"
     "Passes.h"
+    "Transforms.h"
   SRCS
     "DeclareNumWorkgroupsFnPass.cpp"
+    "LaunchConfig.cpp"
     "LegalizeNumWorkgroupsFnPass.cpp"
+    "Transforms.cpp"
+    "VectorTransferOptimization.cpp"
   DEPS
+    LLVMSupport
+    MLIRGPU
     MLIRIR
+    MLIRLinalg
+    MLIRLinalgTransforms
     MLIRPass
     MLIRSCFToStandard
+    MLIRSPIRV
+    MLIRSPIRVTransforms
     MLIRStandard
     MLIRTransforms
+    MLIRVector
     iree::compiler::Conversion::CodegenUtils
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/Common/LaunchConfig.cpp b/iree/compiler/Conversion/Common/LaunchConfig.cpp
new file mode 100644
index 0000000..20c5863
--- /dev/null
+++ b/iree/compiler/Conversion/Common/LaunchConfig.cpp
@@ -0,0 +1,169 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LaunchConfig.cpp - Configuration used to drive arch specific codegen ==//
+//
+// This file defines the data structure that is used by code generation to
+// lower to target-specific IR. The values of the parameters are architecture
+// specific. Once set, the same transformations can be used to generate the
+// desired code. This allows sharing codegen infra between different backends.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Name of the StrAttr that can be used to get the key to access the tile size
+/// information.
+static const char kLaunchInfoKey[] = "launch_info_key";
+
+static Optional<StringRef> getKey(Operation *op) {
+  StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
+  if (!attr) return {};
+  return attr.getValue();
+}
+
+static void setKey(Operation *op, StringRef key) {
+  MLIRContext *context = op->getContext();
+  op->setAttr(Identifier::get(kLaunchInfoKey, op->getContext()),
+              StringAttr::get(key, context));
+}
+
+static std::string getOrSetNewKey(Operation *op, int64_t suffix) {
+  Optional<StringRef> key = getKey(op);
+  if (key) return key->str();
+  std::string newKey = llvm::formatv("__op_num_{0}__", suffix).str();
+  setKey(op, StringRef(newKey));
+  return newKey;
+}
+
+void LaunchConfig::finalize(FuncOp funcOp) {
+  funcOp.walk([&](linalg::LinalgOp linalgOp) {
+    linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
+  });
+}
+
+TileSizesListTypeRef LaunchConfig::getTileSizes(Operation *op) const {
+  auto key = getKey(op);
+  if (!key) return {};
+  auto it = tileSizes.find(*key);
+  if (it == tileSizes.end()) return {};
+  return it->second;
+}
+
+ArrayRef<int64_t> LaunchConfig::getTileSizes(Operation *op,
+                                             size_t level) const {
+  auto t = getTileSizes(op);
+  if (level >= t.size()) return {};
+  return t[level];
+}
+
+void LaunchConfig::setTileSizes(Operation *op, TileSizesListType vTileSizes) {
+  tileSizes[getOrSetNewKey(op, tileSizes.size())] = vTileSizes;
+}
+
+void LaunchConfig::setTileSizes(Operation *op, ArrayRef<int64_t> vTileSizes,
+                                size_t level) {
+  tileSizes[getOrSetNewKey(op, tileSizes.size())].emplace_back(
+      vTileSizes.begin(), vTileSizes.end());
+}
+
+static void setArrayVals(std::array<int64_t, 3> &array,
+                         ArrayRef<int64_t> vals) {
+  if (vals.size() > 3) vals = vals.take_front(3);
+  for (auto size : enumerate(vals)) array[size.index()] = size.value();
+  for (unsigned i : llvm::seq<unsigned>(vals.size(), 3)) array[i] = 1;
+}
+
+void LaunchConfig::setWorkgroupSize(ArrayRef<int64_t> vWorkgroupSize) {
+  setArrayVals(workgroupSize, vWorkgroupSize);
+}
+
+void LaunchConfig::setNumSubgroups(ArrayRef<int64_t> vNumSubgroups) {
+  setArrayVals(numSubgroups, vNumSubgroups);
+}
+
+void LaunchConfig::setSameConfig(Operation *source, Operation *target) {
+  assert(getKey(source) && "missing configuration of source operation");
+  setKey(target, *getKey(source));
+}
+
+void LaunchConfig::setVectorize(bool enableVectorize) {
+  vectorize = enableVectorize;
+}
+
+LogicalResult propogateRootOperationLaunchConfig(
+    LaunchConfig &config, linalg::LinalgOp rootOperation,
+    const linalg::LinalgDependenceGraph &dependenceGraph) {
+  // Check the dependencies going into and out of the root operation. For now
+  // only the following dependencies are supported:
+  // - WAW dependencies going into the root operation.
+  // - RAW dependencies going out of the root operation.
+  // - WAW dependencies going out of the root operation.
+  // i.e. there are no RAW dependences going into the root operation.
+  auto inRAWDependencies = dependenceGraph.getDependencesInto(
+      rootOperation, linalg::LinalgDependenceGraph::RAW);
+  if (!inRAWDependencies.empty()) {
+    return rootOperation.getOperation()->emitError(
+        "unhandled fusion of root operation with producer");
+  }
+  auto dependences = dependenceGraph.getDependentOperations(rootOperation);
+  unsigned numOuterParallel = getNumOuterParallelLoops(rootOperation);
+
+  // Check that for all dependences into and out of the root operation,
+  // - The result expressions of the indexing maps of the fused view in the
+  //   producer and consumer must match for the parallel loops.
+  for (auto dependence :
+       dependenceGraph.getDependentOperations(rootOperation)) {
+    unsigned viewIndex = dependence.indexingOpView.operandIndex;
+    AffineMap indexingMap = rootOperation.getIndexingMap(viewIndex);
+    linalg::LinalgOp fusedOp =
+        cast<linalg::LinalgOp>(dependence.dependentOpView.op);
+    unsigned fusedViewIndex = dependence.dependentOpView.operandIndex;
+    AffineMap fusedIndexingMap = fusedOp.getIndexingMap(fusedViewIndex);
+    if (indexingMap.getNumResults() < numOuterParallel ||
+        fusedIndexingMap.getNumResults() < numOuterParallel ||
+        !llvm::all_of(
+            llvm::seq<unsigned>(0, numOuterParallel),
+            [&indexingMap, fusedIndexingMap](unsigned i) {
+              return indexingMap.getResult(i).isa<AffineDimExpr>() &&
+                     fusedIndexingMap.getResult(i).isa<AffineDimExpr>() &&
+                     indexingMap.getResult(i) == fusedIndexingMap.getResult(i);
+            })) {
+      return rootOperation.getOperation()->emitError(
+          "unhandled fusion of root operation with all operations in the "
+          "dispatch region");
+    }
+  }
+  // The dependent operations get the same tile size information as the root
+  // operation. To propagate that information, just use the same key as the root
+  // operation.
+  for (auto dependence : dependences) {
+    config.setSameConfig(rootOperation, dependence.dependentOpView.op);
+  }
+  return success();
+}
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/Common/LaunchConfig.h b/iree/compiler/Conversion/Common/LaunchConfig.h
new file mode 100644
index 0000000..dcb3b14
--- /dev/null
+++ b/iree/compiler/Conversion/Common/LaunchConfig.h
@@ -0,0 +1,140 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LaunchConfig.h - Configuration used to drive arch specific codegen -===//
+//
+// This file declares the data structure that is used by code generation to
+// lower to target-specific IR. The values of the parameters are architecture
+// specific. Once set, the same transformations can be used to generate the
+// desired code. This allows sharing codegen infra between different backends.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_COMMON_LAUNCHCONFIG_H_
+#define IREE_COMPILER_CONVERSION_COMMON_LAUNCHCONFIG_H_
+#include <array>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Stores the tile sizes to use at different levels of tiling as a vector of
+/// vectors.
+/// - First level tiling maps to workgroups.
+/// - Second level tiling maps to subgroups.
+/// - Third level tiling maps to invocations.
+using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
+using TileSizesListTypeRef = ArrayRef<SmallVector<int64_t, 4>>;
+
+/// Configurations for mapping Linalg ops to CPU/GPU parallel hierarchies.
+///
+/// Based on the linalg operations in a dispatch region, the number of levels of
+/// tiling, the tile sizes needed, the workgroup size, etc. need to be
+/// decided. These parameters are called `LaunchConfig`. This class implements
+/// one heuristic to compute these for the different linalg operations on
+/// buffers. This can be adapted later to support multiple configurations that
+/// can be picked based on device or problem-size information. It exposes the
+/// information needed by the code generators and hides the implementation
+/// from the rest of the pipeline.
+class LaunchConfig {
+ public:
+  LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
+
+  /// Removes attributes added to operations for retrieving tile size
+  /// information.
+  void finalize(FuncOp funcOp);
+
+  /// Gets the tile sizes computed for an operation at all levels.
+  TileSizesListTypeRef getTileSizes(Operation *op) const;
+
+  /// Gets the tile sizes computed for an operation at a given level.
+  ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const;
+
+  /// Returns the workgroup size to use based on the tile sizes.
+  ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
+
+  /// Returns the number of subgroups to use.
+  ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
+
+  /// Returns true if tile sizes have been computed for the operation. If tile
+  /// sizes aren't set, the operation is not to be tiled.
+  bool hasTileSizes(Operation *op, size_t level = 0) const {
+    return !getTileSizes(op, level).empty();
+  }
+
+  /// Returns true if vectorization transformations are to be used.
+  bool useVectorize() const { return vectorize; }
+
+  /// Sets the tile sizes to use for all levels of tiling of `op`.
+  void setTileSizes(Operation *op, TileSizesListType vTileSizes);
+
+  /// Sets the tile sizes to use for a given `level` of tiling of `op`.
+  void setTileSizes(Operation *op, ArrayRef<int64_t> vTileSizes, size_t level);
+
+  /// Sets the workgroup size to use for the function.
+  void setWorkgroupSize(ArrayRef<int64_t> vWorkgroupSize);
+
+  /// Sets number of subgroups to use.
+  void setNumSubgroups(ArrayRef<int64_t> vNumSubgroups);
+
+  /// Sets the configuration of the `targetOp` to be the same as that of the
+  /// `sourceOp`.
+  void setSameConfig(Operation *sourceOp, Operation *targetOp);
+
+  /// Sets flag to enable vectorization.
+  void setVectorize(bool enableVectorize);
+
+ protected:
+  /// Current tile size configuration per operation. The key used here to
+  /// retrieve the tile size information per operation is the value of a StrAttr
+  /// added to operations during `init`. When tiled, this attribute is copied
+  /// over to the tiled operation, so the same key can be used to retrieve
+  /// the tile sizes for the next level of tiling. The `finalize` method removes
+  /// these attributes.
+  llvm::StringMap<TileSizesListType> tileSizes;
+
+  /// Workgroup size to use.
+  std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+
+  /// Number of subgroups that are logically distributed along x, y & z.
+  std::array<int64_t, 3> numSubgroups = {1, 1, 1};
+
+  /// Use vectorization.
+  bool vectorize = false;
+};
+
+/// Propagates tile sizes from `rootOperation` to other linalg operations in the
+/// dispatch region. This assumes that each dispatch region has a single root
+/// operation (like matmul, conv, etc.) that determines the tile sizes to use
+/// for tile+fuse+distribute. These are then propagated to the other operations.
+/// Note: This is a temporary solution and might be defunct when the codegen
+/// becomes more sophisticated.
+LogicalResult propogateRootOperationLaunchConfig(
+    LaunchConfig &launchConfig, linalg::LinalgOp rootOperation,
+    const linalg::LinalgDependenceGraph &dependenceGraph);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+#endif  // IREE_COMPILER_CONVERSION_COMMON_LAUNCHCONFIG_H_
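Schematically, a backend-specific pass is expected to drive this class roughly as follows (a sketch, not taken from the patch; `rootOp`, `funcOp`, and `dependenceGraph` are assumed to already exist in the pass, and the tile sizes are placeholders):

  // Pick tile sizes for the root op; successive calls append successive
  // tiling levels (workgroups, then subgroups/invocations).
  LaunchConfig launchConfig;
  launchConfig.setTileSizes(rootOp.getOperation(), {64, 64, 16}, /*level=*/0);
  launchConfig.setTileSizes(rootOp.getOperation(), {8, 8, 4}, /*level=*/1);
  launchConfig.setWorkgroupSize({16, 16, 1});

  // Give every op fused with the root the same configuration.
  if (failed(propogateRootOperationLaunchConfig(launchConfig, rootOp,
                                                dependenceGraph))) {
    return signalPassFailure();
  }

  // ... run tile-and-fuse, querying launchConfig.getTileSizes(op, level) ...

  // Strip the bookkeeping attribute once codegen is done.
  launchConfig.finalize(funcOp);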
diff --git a/iree/compiler/Conversion/Common/Passes.h b/iree/compiler/Conversion/Common/Passes.h
index 00a61e0..1c9ca32 100644
--- a/iree/compiler/Conversion/Common/Passes.h
+++ b/iree/compiler/Conversion/Common/Passes.h
@@ -23,5 +23,8 @@
 /// each entry point function. The function is defined, but is populated later.
 std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass();
 
+/// Pass to optimize vector transfer_read and transfer_write.
+std::unique_ptr<FunctionPass> createVectorTransferOptimizationPass();
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/Common/Transforms.cpp b/iree/compiler/Conversion/Common/Transforms.cpp
new file mode 100644
index 0000000..cf7bc2a
--- /dev/null
+++ b/iree/compiler/Conversion/Common/Transforms.cpp
@@ -0,0 +1,273 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Transforms.cpp - Transformations common to all backends ------------===//
+//
+// Implements transformations that are common to all backends.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/Common/Transforms.h"
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-tile-and-fuse"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+void applyCanonicalizationPatternsForTiling(MLIRContext *context,
+                                            Operation *op) {
+  OwningRewritePatternList canonicalizationPatterns;
+  canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
+  AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  applyPatternsAndFoldGreedily(op, std::move(canonicalizationPatterns));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for tile and fuse.
+//===----------------------------------------------------------------------===//
+
+/// Promotes views used to chain Linalg ops in `fusedOps`
+/// into buffers using the allocation callback in `options`.
+///
+/// Once fused, the views that carry a RAW dependence can be promoted to
+/// workgroup memory, which makes the intermediate storage dead.
+static LogicalResult promoteFusedViews(OpBuilder &builder,
+                                       ArrayRef<linalg::LinalgOp> fusedOps,
+                                       const TileAndFuseOptions &options) {
+  linalg::Aliases aliases;
+  linalg::LinalgDependenceGraph dependenceGraph(aliases, fusedOps);
+  auto fusableDependences =
+      linalg::findAllFusableDependences(fusedOps, dependenceGraph);
+
+  DenseSet<Value> promotedViews;
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(*fusedOps.begin());
+
+  // Scan the list of ops in reverse order. The fusion is performed by creating
+  // a tiled version of the ops within the tiled loops of the last operation in
+  // the sequence, and then proceeding up the sequence.
+  for (linalg::LinalgOp op : llvm::reverse(fusedOps)) {
+    auto dependences = fusableDependences.lookup(op);
+    if (dependences.empty()) continue;
+    if (!llvm::hasSingleElement(dependences)) {
+      return op.emitError(
+          "unable to promote ops with multiple fusable dependences");
+    }
+    auto dependence = dependences.front();
+    unsigned producerIdx = dependence.dependentOpView.operandIndex;
+    linalg::LinalgOp consumer =
+        cast<linalg::LinalgOp>(dependence.indexingOpView.op);
+    unsigned consumerIdx = dependence.indexingOpView.operandIndex;
+    Value consumerView = consumer.getShapedOperand(consumerIdx);
+    Value promotedView = nullptr;
+
+    // If the view has already been promoted, reuse it. The assumption is that
+    // it already matches.
+    if (promotedViews.count(consumerView)) {
+      promotedView = consumerView;
+    } else if (dependence.dependenceType ==
+               linalg::LinalgDependenceGraph::RAW) {
+      SubViewOp promotedViewProducer =
+          op.getShapedOperand(producerIdx).getDefiningOp<SubViewOp>();
+      assert(promotedViewProducer &&
+             "expected producer to be a subview op as well");
+      Optional<linalg::PromotionInfo> promotionInfo =
+          linalg::promoteSubviewAsNewBuffer(
+              builder, op.getLoc(), promotedViewProducer, options.allocationFn);
+      if (!promotionInfo) {
+        return op.emitError("unable to promote RAW dependence");
+      }
+      promotedView = promotionInfo->partialLocalView;
+      consumer.getOperation()->setOperand(consumerIdx, promotedView);
+      promotedViews.insert(promotedView);
+    }
+    if (!promotedView) continue;
+    op.getOperation()->setOperand(producerIdx, promotedView);
+  }
+  return success();
+}
+
+/// Tile+Fuse only tiles the loops that can be fused. Tile any of the unfused
+/// loops in the operation based on the configuration.
+static linalg::LinalgOp tileUnfusedLoops(
+    OpBuilder &builder, linalg::LinalgOp linalgOp,
+    const std::set<unsigned> &fusedLoopDims, ArrayRef<int64_t> tileSizesRef) {
+  SmallVector<int64_t, 4> tileSizes = llvm::to_vector<4>(tileSizesRef);
+  tileSizes.resize(linalgOp.getNumLoops(), 0);
+  // Linalg uses tile size = 0 for a loop to indicate not tiling that loop. Set
+  // the fused loops to be untiled (since they are already tiled during fusion).
+  for (unsigned loopNum : fusedLoopDims) {
+    tileSizes[loopNum] = 0;
+  }
+  if (llvm::all_of(tileSizes, [](int64_t v) { return !v; })) return linalgOp;
+  Optional<linalg::TiledLinalgOp> tiledOp = tileLinalgOp(
+      builder, linalgOp,
+      linalg::LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
+          linalg::LinalgTilingLoopType::ParallelLoops));
+  if (!tiledOp) return nullptr;
+  linalgOp.erase();
+  return tiledOp->op;
+}
+
+/// Tiles the last operation in `fusableOps` and fuses all other operations with
+/// it by creating tiled versions of each within the generated inter-tile loops.
+static Optional<linalg::TiledAndFusedLinalgOps> tileAndFuseLinalgOps(
+    OpBuilder &builder, FuncOp funcOp, ArrayRef<linalg::LinalgOp> fusableOps,
+    const linalg::LinalgDependenceGraph &dependenceGraph,
+    ArrayRef<int64_t> tileSizes, const TileAndFuseOptions &options) {
+  // Get the tile sizes to use from the last fusable op and then tile+fuse all
+  // the ops.
+  linalg::LinalgTilingOptions tilingOptions;
+  tilingOptions.setDistributionOptions(options.distributionOptions)
+      .setTileSizes(tileSizes)
+      .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
+
+  Optional<linalg::TiledAndFusedLinalgOps> tiledAndFusedOps = llvm::None;
+  if (fusableOps.size() == 1) {
+    linalg::LinalgOp linalgOp = fusableOps.front();
+    Optional<linalg::TiledLinalgOp> tiledOp =
+        tileLinalgOp(builder, linalgOp, tilingOptions);
+    if (!tiledOp) {
+      linalgOp.emitError("unable to tile operation");
+      return llvm::None;
+    }
+    tiledAndFusedOps = linalg::TiledAndFusedLinalgOps{tiledOp->op, {}, {}, {}};
+    auto seq = llvm::seq<unsigned>(0, tileSizes.size());
+    tiledAndFusedOps->fusedLoopDims.insert(seq.begin(), seq.end());
+    tiledAndFusedOps->fusedLoops.assign(tiledOp->loops.begin(),
+                                        tiledOp->loops.end());
+  } else {
+    tiledAndFusedOps = tileAndFuseLinalgOps(builder, fusableOps,
+                                            dependenceGraph, tilingOptions);
+  }
+  if (!tiledAndFusedOps) {
+    funcOp.emitError("tile and fuse of linalg operations failed");
+    return llvm::None;
+  }
+
+  // Update the launch configuration.
+  SmallVector<unsigned, 2> distributedLoops =
+      llvm::to_vector<2>(tiledAndFusedOps->fusedLoopDims);
+  if (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+      failed(createNumWorkgroupsFromResultShape(
+          builder, fusableOps.back(), funcOp, getNumWorkgroupsFnAttrName(),
+          tileSizes, distributedLoops))) {
+    funcOp.emitError("failed to update launch configuration");
+    return llvm::None;
+  }
+
+  // Delete all the original operations.
+  for (auto linalgOp : fusableOps) linalgOp.erase();
+
+  // Add workgroup markers to all the tiled and fused operations.
+  for (auto fusedProducer : tiledAndFusedOps->fusedProducers) {
+    setMarker(fusedProducer, getWorkgroupMarker());
+  }
+  setMarker(tiledAndFusedOps->op, getWorkgroupMarker());
+
+  return tiledAndFusedOps;
+}
+
+LogicalResult tileAndFuseLinalgBufferOps(
+    FuncOp funcOp, ArrayRef<linalg::LinalgOp> linalgOps,
+    const linalg::LinalgDependenceGraph &dependenceGraph,
+    const LaunchConfig &launchConfig, const TileAndFuseOptions &options) {
+  // Collect all operations that are to be tiled-and-fused.
+  MLIRContext *context = funcOp.getContext();
+  SmallVector<linalg::LinalgOp, 4> fusableOps;
+  for (Operation *operation : linalgOps) {
+    if (!launchConfig.hasTileSizes(operation)) continue;
+    fusableOps.push_back(cast<linalg::LinalgOp>(operation));
+  }
+  if (fusableOps.empty()) return success();
+
+  OpBuilder builder(context);
+  ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(fusableOps.back(), 0);
+  Optional<linalg::TiledAndFusedLinalgOps> tiledAndFusedOps =
+      tileAndFuseLinalgOps(builder, funcOp, fusableOps, dependenceGraph,
+                           tileSizes, options);
+  if (!tiledAndFusedOps) {
+    return funcOp.emitError("failed to tile and fuse operations");
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After Fusion on buffers ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
+  applyCanonicalizationPatternsForTiling(context, funcOp);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After Canonicalization ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
+  if (options.allocationFn) {
+    SmallVector<linalg::LinalgOp, 4> promoteFusedViewOps =
+        llvm::to_vector<4>(tiledAndFusedOps->fusedProducers);
+    promoteFusedViewOps.push_back(tiledAndFusedOps->op);
+
+    if (failed(promoteFusedViews(builder, promoteFusedViewOps, options))) {
+      return failure();
+    }
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After Promotion ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
+  // Tile the unfused loops. Set the tile sizes for the fused loops to be zero
+  // to avoid tiling them again.
+  for (linalg::LinalgOp &fusedOp : tiledAndFusedOps->fusedProducers) {
+    ArrayRef<int64_t> fusedOpTileSizes = launchConfig.getTileSizes(fusedOp, 0);
+    linalg::LinalgOp tiledOp = tileUnfusedLoops(
+        builder, fusedOp, tiledAndFusedOps->fusedLoopDims, fusedOpTileSizes);
+    if (!tiledOp) {
+      return fusedOp.emitError("unable to tile unfused loops");
+    }
+  }
+  linalg::LinalgOp tiledOp =
+      tileUnfusedLoops(builder, tiledAndFusedOps->op,
+                       tiledAndFusedOps->fusedLoopDims, tileSizes);
+  if (!tiledOp) {
+    return tiledAndFusedOps->op.emitError("unable to tile unfused loops");
+  }
+
+  applyCanonicalizationPatternsForTiling(context, funcOp);
+  return success();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/Common/Transforms.h b/iree/compiler/Conversion/Common/Transforms.h
new file mode 100644
index 0000000..d10eec7
--- /dev/null
+++ b/iree/compiler/Conversion/Common/Transforms.h
@@ -0,0 +1,58 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- Transforms.h - Transformations common to all backends --------------===//
+//
+// Defines transformations that are common to all backends.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_COMMON_TRANSFORMS_H_
+#define IREE_COMPILER_CONVERSION_COMMON_TRANSFORMS_H_
+
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+void applyCanonicalizationPatternsForTiling(MLIRContext *context,
+                                            Operation *op);
+
+struct TileAndFuseOptions {
+  linalg::LinalgLoopDistributionOptions distributionOptions;
+  linalg::AllocBufferCallbackFn allocationFn = nullptr;
+};
+/// Tiles and fuses the sequence of Linalg operations in `linalgOps`, using the
+/// tile sizes for the first level of tiling specified in `launchConfig`.
+/// Proceeds by
+/// 1) Finding the common loops around `linalgOps` that can be fused.
+/// 2) Tiling the fusable loops in the last operation in the sequence.
+/// 3) Creating tiled versions of the other ops within the inter-tile loops
+///    generated in step 2.
+/// 4) Tiling the unfused loops of all the tiled+fused ops as specified by
+///    `launchConfig`.
+LogicalResult tileAndFuseLinalgBufferOps(
+    FuncOp funcOp, ArrayRef<linalg::LinalgOp> linalgOps,
+    const linalg::LinalgDependenceGraph &dependenceGraph,
+    const LaunchConfig &launchConfig, const TileAndFuseOptions &options);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CONVERSION_COMMON_TRANSFORMS_H_
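
A minimal sketch of how a backend pass is expected to drive this entry point, mirroring the LLVM-side usage that appears later in this diff. The names backendDistributionOptions and allocateLocalMemory are placeholders for a backend-specific distribution scheme and promotion allocator.

// Sketch only: backendDistributionOptions and allocateLocalMemory are
// placeholders (cf. LinalgTileAndDistributePass.cpp below for real values).
static LogicalResult runTileAndFuseSketch(
    FuncOp funcOp, ArrayRef<linalg::LinalgOp> linalgOps,
    const linalg::LinalgDependenceGraph &dependenceGraph,
    const LaunchConfig &launchConfig) {
  // Bundle the distribution scheme and the allocation callback used for
  // promoting RAW-fused views.
  TileAndFuseOptions options = {backendDistributionOptions,
                                allocateLocalMemory};
  return tileAndFuseLinalgBufferOps(funcOp, linalgOps, dependenceGraph,
                                    launchConfig, options);
}
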
diff --git a/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
new file mode 100644
index 0000000..91ef5a0
--- /dev/null
+++ b/iree/compiler/Conversion/Common/VectorTransferOptimization.cpp
@@ -0,0 +1,40 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct VectorTransferOptimizationPass
+    : public PassWrapper<VectorTransferOptimizationPass, FunctionPass> {
+  void runOnFunction() override { vector::transferOpflowOpt(getFunction()); }
+};
+
+}  // namespace
+
+std::unique_ptr<FunctionPass> createVectorTransferOptimizationPass() {
+  return std::make_unique<VectorTransferOptimizationPass>();
+}
+
+static PassRegistration<VectorTransferOptimizationPass> pass(
+    "iree-codegen-optimize-vector-transfer",
+    "Run optimization transformations on vector transfer operations",
+    [] { return std::make_unique<VectorTransferOptimizationPass>(); });
+
+}  // namespace iree_compiler
+}  // namespace mlir
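
The new pass is meant to run as a function-level pass inside a codegen pipeline; a minimal sketch of the expected wiring, assuming an existing OpPassManager as in the backend Passes.cpp files touched later in this diff:

// Sketch only: shows where the pass would slot into a FuncOp-nested pipeline.
static void addVectorTransferOptimizationSketch(OpPassManager &passManager) {
  passManager.addNestedPass<FuncOp>(createVectorTransferOptimizationPass());
}
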
diff --git a/iree/compiler/Conversion/HLOToHLO/BUILD b/iree/compiler/Conversion/HLOToHLO/BUILD
index fccd391..2dbe2b7 100644
--- a/iree/compiler/Conversion/HLOToHLO/BUILD
+++ b/iree/compiler/Conversion/HLOToHLO/BUILD
@@ -22,13 +22,18 @@
     name = "HLOToHLO",
     srcs = [
         "DecomposeHLOClamp.cpp",
+        "DemoteF32ToF16.cpp",
     ],
     hdrs = [
         "Passes.h",
     ],
     deps = [
+        "//iree/compiler/Dialect/Flow/IR",
+        "//iree/compiler/Dialect/IREE/IR",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
         "@org_tensorflow//tensorflow/compiler/mlir/hlo",
     ],
diff --git a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
index c067d1b..2b6f0ae 100644
--- a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
@@ -21,10 +21,15 @@
     "Passes.h"
   SRCS
     "DecomposeHLOClamp.cpp"
+    "DemoteF32ToF16.cpp"
   DEPS
+    LLVMSupport
     MLIRIR
     MLIRPass
+    MLIRSupport
     MLIRTransformUtils
+    iree::compiler::Dialect::Flow::IR
+    iree::compiler::Dialect::IREE::IR
     tensorflow::mlir_hlo
   PUBLIC
 )
diff --git a/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp
new file mode 100644
index 0000000..7e0f286
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp
@@ -0,0 +1,196 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+struct ConvertF32ToF16Pass
+    : public PassWrapper<ConvertF32ToF16Pass, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+
+ private:
+};
+
+/// Any fp32 derived type is illegal.
+static bool isIllegalType(Type type) {
+  if (type.isF32()) return true;
+  if (auto ptrType = type.dyn_cast<IREE::PtrType>()) {
+    return isIllegalType(ptrType.getTargetType());
+  }
+  if (auto shapedType = type.dyn_cast<ShapedType>()) {
+    return isIllegalType(shapedType.getElementType());
+  }
+  return false;
+}
+
+class F32ToF16ConversionTarget : public ConversionTarget {
+ public:
+  using ConversionTarget::ConversionTarget;
+
+ protected:
+  // Operations are legal if they don't contain any illegal type.
+  bool isDynamicallyLegal(Operation *op) const override {
+    if (auto varOp = dyn_cast<IREE::Flow::VariableOp>(op)) {
+      return !isIllegalType(varOp.type());
+    }
+    if (auto funcOp = dyn_cast<FuncOp>(op)) {
+      for (Type type : funcOp.getType().getInputs()) {
+        if (isIllegalType(type)) return false;
+      }
+      for (Type type : funcOp.getType().getResults()) {
+        if (isIllegalType(type)) return false;
+      }
+    }
+    for (Type type : op->getResultTypes()) {
+      if (isIllegalType(type)) return false;
+    }
+    for (Type type : op->getOperandTypes()) {
+      if (isIllegalType(type)) return false;
+    }
+    return true;
+  }
+};
+
+class FloatTypeConverter : public TypeConverter {
+ public:
+  static Type convertTensor(RankedTensorType type) {
+    if (!type.getElementType().isF32()) return type;
+    auto newType = RankedTensorType::get(type.getShape(),
+                                         Float16Type::get(type.getContext()));
+    return newType;
+  }
+  explicit FloatTypeConverter() {
+    addConversion([](Type type) { return type; });
+    addConversion([&](FloatType type) {
+      if (type.isF32()) return FloatType::getF16(type.getContext());
+      return type;
+    });
+    addConversion(convertTensor);
+    addConversion([&](IREE::PtrType ptrType) {
+      if (auto tensorType =
+              ptrType.getTargetType().dyn_cast<RankedTensorType>()) {
+        return IREE::PtrType::get(convertTensor(tensorType));
+      }
+      return ptrType;
+    });
+  }
+};
+
+// Generic pattern to convert FP32 values and attributes to FP16.
+class GenericTypeConvert : public ConversionPattern {
+ public:
+  GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(0, converter, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<NamedAttribute, 4> newAttr;
+    convertAttributes(op->getAttrs(), rewriter, newAttr);
+    llvm::SmallVector<Type, 4> newResults;
+    getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, newAttr, op->getSuccessors());
+    for (Region &r : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      getTypeConverter()->convertSignatureArgs(newRegion->getArgumentTypes(),
+                                               result);
+      rewriter.applySignatureConversion(newRegion, result);
+    }
+    Operation *newOp = rewriter.createOperation(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+
+ protected:
+  static void convertAttributes(ArrayRef<NamedAttribute> attrs,
+                                ConversionPatternRewriter &rewriter,
+                                SmallVectorImpl<NamedAttribute> &newAttrs) {
+    for (auto attr : attrs) {
+      if (auto fpAttr = attr.second.dyn_cast<DenseFPElementsAttr>()) {
+        std::vector<llvm::APFloat> args;
+        if (!fpAttr.getType().getElementType().isF32()) continue;
+        for (llvm::APFloat f : fpAttr.getFloatValues()) {
+          bool losesInfo;
+          f.convert(APFloat::IEEEhalf(), APFloat::rmTowardZero, &losesInfo);
+          args.push_back(f);
+        }
+        auto tensorType = RankedTensorType::get(fpAttr.getType().getShape(),
+                                                rewriter.getF16Type());
+        newAttrs.push_back(std::make_pair(
+            attr.first, DenseElementsAttr::get(tensorType, args)));
+      } else if (auto typeAttr = attr.second.dyn_cast<TypeAttr>()) {
+        if (isIllegalType(typeAttr.getValue())) {
+          if (auto tensorType =
+                  typeAttr.getValue().dyn_cast<RankedTensorType>()) {
+            Type newType = RankedTensorType::get(tensorType.getShape(),
+                                                 rewriter.getF16Type());
+            newAttrs.push_back(
+                std::make_pair(attr.first, TypeAttr::get(newType)));
+          }
+        }
+      } else {
+        newAttrs.push_back(attr);
+      }
+    }
+  }
+};
+
+void ConvertF32ToF16Pass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp moduleOp = getOperation();
+
+  FloatTypeConverter converter;
+  OwningRewritePatternList patterns;
+  patterns.insert<GenericTypeConvert>(context, converter);
+  populateFuncOpTypeConversionPattern(patterns, context, converter);
+  F32ToF16ConversionTarget target(*context);
+  target.markUnknownOpDynamicallyLegal();
+  if (failed(applyFullConversion(moduleOp, target, std::move(patterns))))
+    return signalPassFailure();
+}
+}  // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass entry point and registration
+//===----------------------------------------------------------------------===//
+std::unique_ptr<OperationPass<ModuleOp>> createDemoteF32ToF16Pass() {
+  return std::make_unique<ConvertF32ToF16Pass>();
+}
+
+static PassRegistration<ConvertF32ToF16Pass> pass(
+    "iree-convert-f32-to-f16",
+    "Convert f32 operations and values into equivalent f16 ones");
+}  // namespace iree_compiler
+}  // namespace mlir
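
Like the other HLO-to-HLO passes, the demotion pass is intended to run on the whole module before lowering; a minimal sketch of adding it to a pipeline (the exact insertion point in the IREE pipeline is not shown in this diff):

// Sketch only: assumes a module-level pass manager.
static void addF32ToF16DemotionSketch(OpPassManager &passManager) {
  passManager.addPass(createDemoteF32ToF16Pass());
}
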
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.h b/iree/compiler/Conversion/HLOToHLO/Passes.h
index cf50702..df93c30 100644
--- a/iree/compiler/Conversion/HLOToHLO/Passes.h
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.h
@@ -33,6 +33,10 @@
 /// Creates a pass to decompose XLA-HLO clamp ops into primitive ops.
 std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
 
+/// Creates a pass to convert a model using f32 types to an equivalent one
+/// using f16 types.
+std::unique_ptr<OperationPass<ModuleOp>> createDemoteF32ToF16Pass();
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Conversion/HLOToHLO/test/BUILD b/iree/compiler/Conversion/HLOToHLO/test/BUILD
new file mode 100644
index 0000000..1e3b7bb
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for common transforms.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-opt",
+    ],
+)
diff --git a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "${_GLOB_X_MLIR}"
+  DATA
+    iree::tools::IreeFileCheck
+    iree::tools::iree-opt
+)
diff --git a/iree/compiler/Conversion/HLOToHLO/test/f32Tof16.mlir b/iree/compiler/Conversion/HLOToHLO/test/f32Tof16.mlir
new file mode 100644
index 0000000..6cd8325
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/f32Tof16.mlir
@@ -0,0 +1,38 @@
+// RUN: iree-opt -split-input-file -iree-convert-f32-to-f16 %s | IreeFileCheck %s
+
+//       CHECK: flow.variable {{.*}} : tensor<4xf16>
+// CHECK-LABEL: func @simple_f32() -> tensor<4xf16>
+//  CHECK-NEXT: %{{.*}} = flow.variable.address @__global : !iree.ptr<tensor<4xf16>>
+//  CHECK-NEXT: %{{.*}} = flow.variable.load.indirect %{{.*}} : !iree.ptr<tensor<4xf16>> -> tensor<4xf16>
+//  CHECK-NEXT: return %{{.*}} : tensor<4xf16>
+module {
+  flow.variable @"__global" dense<"0x000020410000A040000020410000A040"> : tensor<4xf32> attributes {sym_visibility = "private"}
+  func @simple_f32() -> (tensor<4xf32>) {
+    %0 = flow.variable.address @"__global" : !iree.ptr<tensor<4xf32>>
+    %1 = flow.variable.load.indirect %0 : !iree.ptr<tensor<4xf32>> -> tensor<4xf32>
+    return %1 : tensor<4xf32>
+  }
+}
+
+// -----
+
+// CHECK: flow.variable
+// CHECK-NOT: f32
+// CHECK-LABEL: func @nested_region_f32()
+// CHECK-NOT: f32
+// CHECK: return %{{.*}} : tensor<4xf16>
+module {
+  flow.variable @"__global" dense<"0x000020410000A040000020410000A040"> : tensor<4xf32> attributes {sym_visibility = "private"}
+  func @nested_region_f32() -> (tensor<4xf32>) {
+    %0 = flow.variable.address @"__global" : !iree.ptr<tensor<4xf32>>
+    %1 = flow.variable.load.indirect %0 : !iree.ptr<tensor<4xf32>> -> tensor<4xf32>
+    %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4x4xf32>
+    %4 = mhlo.constant dense<0xFF800000> : tensor<f32>
+    %3 = "mhlo.reduce"(%2, %4) ( {
+    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
+      %5393 = mhlo.maximum %arg3, %arg4 : tensor<f32>
+      "mhlo.return"(%5393) : (tensor<f32>) -> ()
+    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<4xf32>
+    return %3 : tensor<4xf32>
+  }
+}
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index d178038..90fe4dc 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -116,7 +116,7 @@
 
 /// Returns the MemRefType to use for a given `tensorType`.
 static MemRefType getMemrefTypeForTensor(
-    RankedTensorType tensorType, ArrayRef<AffineMap> affineMapComposition = {},
+    ShapedType tensorType, ArrayRef<AffineMap> affineMapComposition = {},
     unsigned memorySpace = 0) {
   return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                          affineMapComposition, memorySpace);
@@ -197,6 +197,14 @@
     for (auto result : llvm::enumerate(op->getResults())) {
       Value resultBuffer = resultTensorToBufferMap.lookup(result.value());
       if (!resultBuffer) {
+        if (auto shapedType = result.value().getType().dyn_cast<ShapedType>()) {
+          if (shapedType.hasStaticShape()) {
+            resultBuffer = rewriter.create<AllocOp>(
+                op->getLoc(), getMemrefTypeForTensor(shapedType));
+          }
+        }
+      }
+      if (!resultBuffer) {
         return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
           diag << "failed to create buffer for result #" << result.index();
         });
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 057e909..1abc9c2 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -472,3 +472,54 @@
 //       CHECK:   %[[RET0_RESHAPE2:.+]] = linalg.reshape %[[RET0]]
 //  CHECK-SAME:     memref<1x1x1x1000xf32> into memref<1x1000xf32>
 //       CHECK:   linalg.copy(%[[RET0_RESHAPE2]], %[[RET1]])
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func @matmul_add() {
+    %c0 = constant 0 : index
+    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+      : tensor<32x48xf32>
+    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0
+      : tensor<48x64xf32>
+    %2 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0
+      : tensor<32x64xf32>
+    %3 = "mhlo.dot"(%0, %1)
+      : (tensor<32x48xf32>, tensor<48x64xf32>) -> tensor<32x64xf32>
+    %4 = linalg.generic {
+      indexing_maps = [#map0, #map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%2, %3 : tensor<32x64xf32>, tensor<32x64xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32):
+          %5 = addf %arg0, %arg1 : f32
+          linalg.yield %5 : f32
+      } -> tensor<32x64xf32>
+    hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0
+      : tensor<32x64xf32>
+    return
+  }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// CHECK-LABEL: func @matmul_add
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0}
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0}
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1}
+//   CHECK-DAG:   %[[ARG2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2}
+//   CHECK-DAG:   %[[TEMP:.+]] = alloc()
+//       CHECK:   linalg.fill(%[[TEMP]], %{{.+}})
+//       CHECK:   linalg.matmul ins(%[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:     ) outs(%[[TEMP]]
+//  CHECK-SAME:     )
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     ins(%[[ARG2]], %[[TEMP]]
+//  CHECK-SAME:     ) outs(%[[RET0]]
+//  CHECK-SAME:     )
+//       CHECK:   return
+
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index 5dbbadd..bb93b48 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -27,6 +27,7 @@
         "LinalgTileAndDistributePass.cpp",
         "LinalgTileAndVectorizePass.cpp",
         "Passes.cpp",
+        "PlanConvLoopOrder.cpp",
     ],
     hdrs = [
         "KernelDispatch.h",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 4018a29..df9c39a 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -27,6 +27,7 @@
     "LinalgTileAndDistributePass.cpp"
     "LinalgTileAndVectorizePass.cpp"
     "Passes.cpp"
+    "PlanConvLoopOrder.cpp"
   DEPS
     LLVMSupport
     MLIRAffineToStandard
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
index d37efed..0d81a0d 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -23,6 +23,14 @@
 namespace mlir {
 namespace iree_compiler {
 
+// TODO(ravishankarm): This needs to be put in a common place for the CPU and
+// GPU backends to use.
+static llvm::cl::list<unsigned> clLLVMTileSizes(
+    "iree-llvm-tile-size",
+    llvm::cl::desc("Set tile sizes to use for tiling Linalg operations in "
+                   "LLVM code generation"),
+    llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
+
 static llvm::cl::opt<int> matmulWorkgroupTileSize(
     "iree-codegen-linalg-to-llvm-kernel-dispatch-matmul-workgroup-tile-size",
     llvm::cl::desc(
@@ -151,5 +159,48 @@
 
 #undef DEFINE_TILE_SIZE_FN
 
+Optional<LaunchConfig> initCPULaunchConfig(
+    MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+    ArrayRef<linalg::LinalgOp> linalgOps) {
+  LaunchConfig config;
+  if (!clLLVMTileSizes.empty()) {
+    SmallVector<int64_t, 3> tileSizes(clLLVMTileSizes.begin(),
+                                      clLLVMTileSizes.end());
+    for (linalg::LinalgOp linalgOp : linalgOps) {
+      config.setTileSizes(linalgOp.getOperation(), tileSizes, 0);
+    }
+    return config;
+  }
+
+  Optional<linalg::LinalgOp> rootOperation = llvm::None;
+  for (auto linalgOp : linalgOps) {
+#define DISPATCH(opType)                                                     \
+  if (opType op = dyn_cast<opType>(linalgOp.getOperation())) {               \
+    if (rootOperation) {                                                     \
+      op.emitError("unhandled multiple root operations in dispatch region"); \
+      return llvm::None;                                                     \
+    }                                                                        \
+    rootOperation = linalgOp;                                                \
+    config.setTileSizes(                                                     \
+        op,                                                                  \
+        TileOpParameters::getSizes<opType, TilingLevel::WorkGroupTiles>(op), \
+        0);                                                                  \
+    continue;                                                                \
+  }
+
+    DISPATCH(linalg::MatmulOp)
+    DISPATCH(linalg::BatchMatmulOp)
+
+#undef DISPATCH
+  }
+  if (!rootOperation) {
+    return config;
+  }
+  if (failed(propogateRootOperationLaunchConfig(config, *rootOperation,
+                                                dependenceGraph)))
+    return llvm::None;
+  return config;
+}
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
index f2d7711..20972af 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
@@ -14,6 +14,7 @@
 
 #include <cstdint>
 
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
@@ -28,7 +29,8 @@
   // Tile linalg operation on workgroup thread into L1 block tiles.
   Level1Tiles = 1,
   // Tile linalg operations on L1 block tiles into vector tiles.
-  Level2Tiles = 2
+  Level2Tiles = 2,
+  NumTileLevels = 3
 };
 
 class CPUKernelDispatch {
@@ -44,5 +46,9 @@
                                          Operation *operation);
 };
 
+Optional<LaunchConfig> initCPULaunchConfig(
+    MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+    ArrayRef<linalg::LinalgOp> linalgOps);
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
index 6e8929c..cd0cdd2 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgTileAndDistributePass.cpp
@@ -17,11 +17,13 @@
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
 #include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
 #include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -41,11 +43,6 @@
   LinalgTileAndDistributePass() = default;
   LinalgTileAndDistributePass(const LinalgTileAndDistributePass &pass) {}
   void runOnOperation() override;
-
- private:
-  ListOption<int64_t> tileSizes{
-      *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
 };
 }  // namespace
 
@@ -124,6 +121,34 @@
 
 }  // namespace
 
+Optional<Value> allocateThreadLocalMemory(OpBuilder &b, SubViewOp subview,
+                                          ArrayRef<Value> boundingSubViewSize,
+                                          OperationFolder *folder) {
+  // Allocate the memory in the entry block of the parent FuncOp. This better
+  // aligns with the semantics of this memory, which is available from the
+  // entry of the function.
+  OpBuilder::InsertionGuard guard(b);
+  FuncOp funcOp = subview.getParentOfType<FuncOp>();
+  if (!funcOp) {
+    subview.emitError("expected op to be within std.func");
+    return llvm::None;
+  }
+  b.setInsertionPointToStart(&(*funcOp.getBody().begin()));
+  // The bounding subview size is expected to be constant. This specifies the
+  // shape of the allocation.
+  SmallVector<int64_t, 2> shape = llvm::to_vector<2>(
+      llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t {
+        APInt value;
+        if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue();
+        return -1;
+      }));
+  if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {};
+  MemRefType allocType =
+      MemRefType::get(shape, subview.getType().getElementType());
+  Value buffer = b.create<AllocaOp>(subview.getLoc(), allocType);
+  return buffer;
+}
+
 void LinalgTileAndDistributePass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
@@ -147,57 +172,57 @@
        linalg::DistributionMethod::CyclicNumProcsEqNumIters,
        linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
 
-  CPUKernelDispatch cpuKernelDispatch;
-
   for (FuncOp funcOp : module.getOps<FuncOp>()) {
     if (!isEntryPoint(funcOp)) continue;
 
-    // Compute the Linalg Dependence Graph.
+    Region &body = funcOp.getBody();
+    if (!llvm::hasSingleElement(body.getBlocks())) {
+      funcOp.emitError("unhandled dispatch function with multiple blocks");
+      return signalPassFailure();
+    }
+    Block &block = body.front();
+    auto linalgOps = block.getOps<linalg::LinalgOp>();
+    if (linalgOps.empty()) continue;
+
+    SmallVector<linalg::LinalgOp, 4> linalgOpsVec =
+        llvm::to_vector<4>(llvm::map_range(linalgOps, [](Operation *op) {
+          return cast<linalg::LinalgOp>(op);
+        }));
     linalg::Aliases aliases;
-    linalg::LinalgDependenceGraph dependenceGraph =
-        linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
+    linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOpsVec);
+    Optional<LaunchConfig> launchConfigOpt =
+        initCPULaunchConfig(context, dependenceGraph, linalgOpsVec);
+    if (!launchConfigOpt) {
+      funcOp.emitError("unable to find launch configuration");
+      return signalPassFailure();
+    }
+    LaunchConfig &launchConfig = *launchConfigOpt;
 
-    OwningRewritePatternList patterns;
-
-    auto linalgTilingOptions =
-        linalg::LinalgTilingOptions()
-            .setDistributionOptions(workgroupDistributionOptions)
-            .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops);
-    tileSizes.empty()
-        ? linalgTilingOptions.setTileSizeComputationFunction(
-              [&cpuKernelDispatch](
-                  OpBuilder &builder,
-                  Operation *operation) -> SmallVector<Value, 4> {
-                return TileSizeFn::get<TilingLevel::WorkGroupTiles>(
-                    cpuKernelDispatch, builder, operation);
-              })
-        : linalgTilingOptions.setTileSizes(ArrayRef<int64_t>(tileSizes));
-    patterns.insert<TileAndFuseToCPUThreads<linalg::MatmulOp>,
-                    TileAndFuseToCPUThreads<linalg::BatchMatmulOp>,
-                    TileToCPUThreads<linalg::MatmulOp>,
-                    TileToCPUThreads<linalg::BatchMatmulOp>>(
-        context, dependenceGraph, cpuKernelDispatch, linalgTilingOptions,
-        linalg::LinalgMarker(ArrayRef<Identifier>(),
-                             Identifier::get(getWorkgroupMarker(), context)));
-
-    // Tile and distribute to CPU threads.
-    applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-
-    // Apply canonicalization patterns.
-    OwningRewritePatternList canonicalizationPatterns;
-    canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
-    AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns,
-                                               context);
-    AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-    SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-
-    applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
-
-    // Delete the ops that are marked for deletion.
-    funcOp.walk([](linalg::LinalgOp linalgOp) {
-      if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
-        linalgOp.getOperation()->erase();
+    LLVM_DEBUG({
+      llvm::dbgs() << "@func " << funcOp.getName() << "\n";
+      for (auto op : linalgOps) {
+        llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
+        TileSizesListTypeRef configTileSizes = launchConfig.getTileSizes(op);
+        llvm::dbgs() << "{";
+        std::string sep = "";
+        for (auto &level : enumerate(configTileSizes)) {
+          llvm::dbgs() << sep << level.index() << " : [";
+          sep = ", ";
+          interleaveComma(level.value(), llvm::dbgs());
+          llvm::dbgs() << "]";
+        }
+        llvm::dbgs() << "}\n";
+      }
     });
+
+    TileAndFuseOptions tileAndFuseOptions = {workgroupDistributionOptions,
+                                             allocateThreadLocalMemory};
+    if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOpsVec, dependenceGraph,
+                                          launchConfig, tileAndFuseOptions))) {
+      return signalPassFailure();
+    }
+
+    launchConfig.finalize(funcOp);
   }
 }
 
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 348b425..b3f2c9c 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -54,6 +54,7 @@
   if (convImg2ColConversion) {
     passManager.addNestedPass<FuncOp>(createConvImg2ColMatmulConversionPass());
   }
+  passManager.addNestedPass<FuncOp>(createPlanConvLoopOrderPass());
 
   passManager.addNestedPass<FuncOp>(
       createLinalgTileAndVectorizeWorkgroupsPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 33c74d4..0a325dc 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -24,10 +24,14 @@
 /// linalg::MatmulOp.
 std::unique_ptr<FunctionPass> createConvImg2ColMatmulConversionPass();
 
-/// Distribute linalg ops among iree.workgroup logical threads.
+/// Converts linalg.conv into linalg.generic with a CPU-friendly iteration
+/// order.
+std::unique_ptr<FunctionPass> createPlanConvLoopOrderPass();
+
+/// Distributes linalg ops among iree.workgroup logical threads.
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndDistributePass();
 
-/// Vectorize linalg ops executed in the same iree.workgroup.
+/// Vectorizes linalg ops executed in the same iree.workgroup.
 std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass();
 
 std::unique_ptr<OperationPass<ModuleOp>>
@@ -38,7 +42,7 @@
 void populateConvImg2ColMatmulConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
 
-/// Pass to perform final conversion to LLVM dialect.
+/// Performs the final conversion to LLVM dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
 
 /// Populates passes needed to lower a XLA HLO op to LLVM dialect via the
diff --git a/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
new file mode 100644
index 0000000..db898a4
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/PlanConvLoopOrder.cpp
@@ -0,0 +1,77 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+struct PlanConvLoopOrderPass
+    : PassWrapper<PlanConvLoopOrderPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+  void runOnFunction() override;
+};
+
+}  // namespace
+
+void PlanConvLoopOrderPass::runOnFunction() {
+  auto funcOp = getOperation();
+  auto context = funcOp.getContext();
+
+  auto marker = Identifier::get("generalized_from_conv", context);
+  linalg::LinalgMarker firstStepMarker(
+      /*matchDisjunction=*/ArrayRef<Identifier>(),
+      /*replacement=*/marker);
+  linalg::LinalgMarker secondStepMarker(
+      /*matchDisjunction=*/marker,
+      /*replacement=*/llvm::None);
+
+  SmallVector<unsigned, 8> loopOrder = {
+      /*batch=*/0,
+      /*output_height=*/1,
+      /*output_width=*/2,
+      /*filter_height=*/5,
+      /*filter_width=*/6,
+      /*input_channel=*/4,
+      /*output_channel=*/3,
+  };
+
+  OwningRewritePatternList patterns;
+  linalg::populateLinalgConvGeneralizationPatterns(context, patterns,
+                                                   firstStepMarker);
+  patterns.insert<linalg::LinalgInterchangePattern<linalg::GenericOp>>(
+      context, loopOrder, secondStepMarker);
+
+  applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
+std::unique_ptr<FunctionPass> createPlanConvLoopOrderPass() {
+  return std::make_unique<PlanConvLoopOrderPass>();
+}
+
+static PassRegistration<PlanConvLoopOrderPass> pass(
+    "iree-codegen-linalg-to-llvm-plan-conv-loop-order",
+    "Convert linalg.conv to linalg.generic with a CPU-friendly iterator order",
+    [] { return std::make_unique<PlanConvLoopOrderPass>(); });
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir
new file mode 100644
index 0000000..0cc8ace
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir
@@ -0,0 +1,15 @@
+// RUN: iree-opt -iree-codegen-linalg-to-llvm-plan-conv-loop-order %s | IreeFileCheck %s
+
+func @conv(%filter: memref<3x3x3x32xf32>, %input: memref<1x225x225x3xf32>, %output: memref<1x112x112x32xf32>) {
+  linalg.conv(%filter, %input, %output) {dilations = [1, 1], strides = [2, 2]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+  return
+}
+
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK:  #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2 + d3, d2 * 2 + d4, d5)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"]
+
+
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
index baffba1..35f816b 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir
@@ -1,35 +1,35 @@
-// RUN: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute=tile-sizes=2,4,1 -cse -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt --iree-codegen-llvm-linalg-tile-and-distribute -iree-llvm-tile-size=2,4,1 -cse -split-input-file %s | IreeFileCheck %s
 
 func @dynamic_matmul(%lhs: memref<?x?xf32>, %rhs: memref<?x?xf32>, %result: memref<?x?xf32>) {
   linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>) outs(%result : memref<?x?xf32>)
   return
 }
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (2, s1 - s0 * 2)>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (4, s1 - s0 * 4)>
-// CHECK: func @dynamic_matmul(%[[LHS:.+]]: memref<?x?xf32>, %[[RHS:.+]]: memref<?x?xf32>, %[[RESULT:.+]]: memref<?x?xf32>)
-// CHECK: %[[CONST_0:.+]] = constant 0 : index
-// CHECK: %[[CONST_1:.+]] = constant 1 : index
-// CHECK: %[[DIM_K:.+]] = dim %[[LHS]], %[[CONST_1]]
-// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id  {dimension = "x"} : index
-// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id  {dimension = "y"} : index
-// CHECK:  scf.for %[[K:.+]] = %[[CONST_0]] to %[[DIM_K]]
-// CHECK:     %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
-// CHECK:     %[[DIM_I:.+]] = dim %[[LHS]], %[[CONST_0]]
-// CHECK:     %[[I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
-// CHECK:     %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [%[[I_OFFSET]], 1] [1, 1] 
-// CHECK:     %[[J:.+]] = affine.apply #[[MAP3]]()[%[[THREAD_X_ID]]]
-// CHECK:     %[[DIM_J:.+]] = dim %[[RHS]], %[[CONST_1]] 
-// CHECK:     %[[J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
-// CHECK:     %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, %[[J_OFFSET]]] [1, 1]  
-// CHECK:     %[[DIM_I:.+]] = dim %[[RESULT]], %[[CONST_0]]
-// CHECK:     %[[DIM_I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
-// CHECK:     %[[DIM_J:.+]] = dim %[[RESULT]], %[[CONST_1]]
-// CHECK:     %[[DIM_J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
-// CHECK:     %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [%[[DIM_I_OFFSET]], %[[DIM_J_OFFSET]]] [1, 1]
-// CHECK:      linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<?x1xf32, #[[MAP2]]>, memref<1x?xf32, #[[MAP2]]>) outs(%[[RESULT_SUBVIEW]] : memref<?x?xf32, #[[MAP2]]>)
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (2, s1 - s0 * 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1] -> (4, s1 - s0 * 4)>
+//     CHECK: func @dynamic_matmul(%[[LHS:.+]]: memref<?x?xf32>, %[[RHS:.+]]: memref<?x?xf32>, %[[RESULT:.+]]: memref<?x?xf32>)
+// CHECK-DAG: %[[CONST_0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CONST_1:.+]] = constant 1 : index
+// CHECK-DAG: %[[DIM_K:.+]] = dim %[[LHS]], %[[CONST_1]]
+// CHECK-DAG: %[[THREAD_X_ID:.+]] = iree.workgroup_id  {dimension = "x"} : index
+// CHECK-DAG: %[[THREAD_Y_ID:.+]] = iree.workgroup_id  {dimension = "y"} : index
+//     CHECK:  scf.for %[[K:.+]] = %[[CONST_0]] to %[[DIM_K]]
+//     CHECK:     %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
+//     CHECK:     %[[DIM_I:.+]] = dim %[[LHS]], %[[CONST_0]]
+//     CHECK:     %[[I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
+//     CHECK:     %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [%[[I_OFFSET]], 1] [1, 1] 
+//     CHECK:     %[[J:.+]] = affine.apply #[[MAP3]]()[%[[THREAD_X_ID]]]
+//     CHECK:     %[[DIM_J:.+]] = dim %[[RHS]], %[[CONST_1]] 
+//     CHECK:     %[[J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
+//     CHECK:     %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, %[[J_OFFSET]]] [1, 1]  
+//     CHECK:     %[[DIM_I:.+]] = dim %[[RESULT]], %[[CONST_0]]
+//     CHECK:     %[[DIM_I_OFFSET:.+]] = affine.min #[[MAP1]]()[%[[THREAD_Y_ID]], %[[DIM_I]]]
+//     CHECK:     %[[DIM_J:.+]] = dim %[[RESULT]], %[[CONST_1]]
+//     CHECK:     %[[DIM_J_OFFSET:.+]] = affine.min #[[MAP4]]()[%[[THREAD_X_ID]], %[[DIM_J]]]
+//     CHECK:     %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [%[[DIM_I_OFFSET]], %[[DIM_J_OFFSET]]] [1, 1]
+//     CHECK:      linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<?x1xf32, #[[MAP2]]>, memref<1x?xf32, #[[MAP2]]>) outs(%[[RESULT_SUBVIEW]] : memref<?x?xf32, #[[MAP2]]>)
 
 // -----
 
@@ -37,20 +37,20 @@
   linalg.matmul ins(%lhs, %rhs : memref<16x4xf32>, memref<4x8xf32>) outs(%result : memref<16x8xf32>)
   return
 }
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
-// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
-// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
-// CHECK: func @static_matmul(%[[LHS:.+]]: memref<16x4xf32>, %[[RHS:.+]]: memref<4x8xf32>, %[[RESULT:.+]]: memref<16x8xf32>)
-// CHECK: %[[CONST_0:.+]] = constant 0 : index
-// CHECK: %[[CONST_4:.+]] = constant 4 : index
-// CHECK: %[[CONST_1:.+]] = constant 1 : index
-// CHECK: %[[THREAD_X_ID:.+]] = iree.workgroup_id  {dimension = "x"} : index
-// CHECK: %[[THREAD_Y_ID:.+]] = iree.workgroup_id  {dimension = "y"} : index
-// CHECK:  scf.for %[[K:.+]] = %[[CONST_0]] to %[[CONST_4]] step %[[CONST_1]] 
-// CHECK:    %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
-// CHECK:    %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [2, 1] [1, 1]  : memref<16x4xf32> to memref<2x1xf32, #[[MAP1]]>
-// CHECK:    %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
-// CHECK:    %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1]  : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
-// CHECK:    %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1]  : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
-// CHECK:    linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%6 : memref<2x4xf32, #[[MAP3]]>)
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+//     CHECK: func @static_matmul(%[[LHS:.+]]: memref<16x4xf32>, %[[RHS:.+]]: memref<4x8xf32>, %[[RESULT:.+]]: memref<16x8xf32>)
+// CHECK-DAG: %[[CONST_0:.+]] = constant 0 : index
+// CHECK-DAG: %[[CONST_4:.+]] = constant 4 : index
+// CHECK-DAG: %[[CONST_1:.+]] = constant 1 : index
+// CHECK-DAG: %[[THREAD_X_ID:.+]] = iree.workgroup_id  {dimension = "x"} : index
+// CHECK-DAG: %[[THREAD_Y_ID:.+]] = iree.workgroup_id  {dimension = "y"} : index
+//     CHECK:  scf.for %[[K:.+]] = %[[CONST_0]] to %[[CONST_4]] step %[[CONST_1]] 
+//     CHECK:    %[[I:.+]] = affine.apply #[[MAP0]]()[%[[THREAD_Y_ID]]]
+//     CHECK:    %[[LHS_SUBVIEW:.+]] = subview %[[LHS]][%[[I]], %[[K]]] [2, 1] [1, 1]  : memref<16x4xf32> to memref<2x1xf32, #[[MAP1]]>
+//     CHECK:    %[[J:.+]] = affine.apply #[[MAP2]]()[%[[THREAD_X_ID]]]
+//     CHECK:    %[[RHS_SUBVIEW:.+]] = subview %[[RHS]][%[[K]], %[[J]]] [1, 4] [1, 1]  : memref<4x8xf32> to memref<1x4xf32, #[[MAP3]]>
+//     CHECK:    %[[RESULT_SUBVIEW:.+]] = subview %[[RESULT]][%[[I]], %[[J]]] [2, 4] [1, 1]  : memref<16x8xf32> to memref<2x4xf32, #[[MAP3]]>
+//     CHECK:    linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]] : memref<2x1xf32, #[[MAP1]]>, memref<1x4xf32, #[[MAP3]]>) outs(%6 : memref<2x4xf32, #[[MAP3]]>)
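
A worked example of the index math the static_matmul checks above encode: with 2x4 workgroup tiles over a 16x4 * 4x8 matmul, workgroup (x, y) owns the output tile whose origin is (2*y, 4*x), which is exactly what #[[MAP0]] and #[[MAP2]] compute. A minimal standalone C++ sketch of that arithmetic (tile sizes and shapes are taken from the test; the program itself is illustrative only):

    #include <cstdio>

    int main() {
      const int M = 16, N = 8, tileM = 2, tileN = 4;
      // Workgroup grid: y spans M / tileM = 8 tile rows, x spans N / tileN = 2 tile columns.
      for (int y = 0; y < M / tileM; ++y) {
        for (int x = 0; x < N / tileN; ++x) {
          // Same arithmetic as affine_map<()[s0] -> (s0 * 2)> and (s0 * 4).
          int rowOrigin = tileM * y;
          int colOrigin = tileN * x;
          std::printf("workgroup (x=%d, y=%d) -> result subview origin [%d, %d], size 2x4\n",
                      x, y, rowOrigin, colOrigin);
        }
      }
      return 0;
    }
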
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index d281d79..8771e7f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -754,6 +754,8 @@
                   MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
                   MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
                   MapLinalgOpToLocalInvocationId<linalg::FillOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
                   MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
                   MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
                   MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 7f0198a..4178fcc 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -25,6 +25,7 @@
 
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
@@ -49,10 +50,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-/// Name of the StrAttr that can be used to get the key to access the tile size
-/// information.
-static const char kLaunchInfoKey[] = "launch_info_key";
-
 /// Given `nprocs` try to distribute it evenly across 2 logical x and y.
 static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
   int64_t nprocs_x = std::max<int64_t>(
@@ -61,6 +58,13 @@
   return std::make_tuple(nprocs_x, nprocs / nprocs_x);
 }
 
+/// Returns the minimum of `shape` and `tileSize` if `shape` is static; if
+/// `shape` is dynamic, returns `tileSize`.
+static int64_t getMinIfShapeStatic(int64_t shape, int64_t tileSize) {
+  if (shape == ShapedType::kDynamicSize) return tileSize;
+  return std::min(shape, tileSize);
+}
+
 namespace {
 struct LaunchConfigInfo {
   std::array<int64_t, 3> workgroupSize = {1, 1, 1};
@@ -83,6 +87,51 @@
   return op.emitError("undefined launch config for tiled operation");
 }
 
+static void getMaliBestMatMulTileSizes(Type elementType,
+                                       SmallVectorImpl<int64_t> &tileSizes) {
+  if (elementType.isF16()) {
+    tileSizes.append({16, 64, 8});
+  } else {
+    tileSizes.append({8, 64, 4});
+  }
+}
+
+/// Launch configuration for linalg.batch_matmul on Mali GPUs.
+static LogicalResult getMaliSpecificConfig(
+    linalg::BatchMatmulOp op, const spirv::TargetEnv &targetEnv,
+    const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
+    std::array<int64_t, 3> &workgroupSize,
+    std::array<int64_t, 3> &numSubgroups) {
+  if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
+
+  auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
+  auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
+  assert(lhsType.getElementType() == rhsType.getElementType());
+  // Pick ideal tile size based on the type.
+  SmallVector<int64_t, 4> workgroupLevelTs(1, 1);
+  getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
+  // Fall back to the non-vectorized path for cases we don't handle.
+  if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
+      lhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
+      rhsType.getDimSize(2) % workgroupLevelTs[2] != 0 ||
+      lhsType.getDimSize(2) % workgroupLevelTs[3] != 0) {
+    return failure();
+  }
+
+  workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
+  workgroupSize[1] = 1;
+  workgroupSize[2] = 1;
+  tileSizes.emplace_back(workgroupLevelTs);
+  // No tiling at the subgroup level since this target doesn't use subgroup ops
+  // or shared memory.
+  tileSizes.emplace_back();
+  SmallVector<int64_t, 4> invocationLevelTs = {
+      workgroupLevelTs[0], workgroupLevelTs[1],
+      workgroupLevelTs[2] / workgroupSize[0], workgroupLevelTs[3]};
+  tileSizes.emplace_back(invocationLevelTs);
+  return success();
+}
+
 /// Launch config for `linalg.batchmatmul`.
 template <>
 LogicalResult getOpLaunchConfig(linalg::BatchMatmulOp op,
@@ -90,6 +139,13 @@
                                 const SPIRVCodegenOptions &options,
                                 TileSizesListType &tileSizes,
                                 LaunchConfigInfo &config) {
+  if (options.enableVectorization &&
+      succeeded(getMaliSpecificConfig(op, targetEnv, options, tileSizes,
+                                      config.workgroupSize,
+                                      config.numSubgroups))) {
+    config.vectorize = true;
+    return success();
+  }
   unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
                                   .max_compute_workgroup_invocations()
                                   .getInt();
@@ -220,16 +276,12 @@
   assert(lhsType.getElementType() == rhsType.getElementType());
   // Pick ideal tile size based on the type.
   SmallVector<int64_t, 4> workgroupLevelTs;
-  if (lhsType.getElementType().isF16()) {
-    workgroupLevelTs.append({16, 64, 8});
-  } else {
-    workgroupLevelTs.append({8, 64, 4});
-  }
+  getMaliBestMatMulTileSizes(lhsType.getElementType(), workgroupLevelTs);
 
   // Fall back to the none vectorize path for cases we don't handle.
   if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
       lhsType.getDimSize(0) % workgroupLevelTs[0] != 0 ||
-      rhsType.getDimSize(0) % workgroupLevelTs[1] != 0 ||
+      rhsType.getDimSize(1) % workgroupLevelTs[1] != 0 ||
       lhsType.getDimSize(1) % workgroupLevelTs[2] != 0) {
     return failure();
   }
@@ -283,19 +335,74 @@
     tileSizeK = 32;
   }
   assert(tileSizes.empty());
-  SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * config.workgroupSize[1],
-                                nColsPerWorkitem * config.workgroupSize[0],
-                                tileSizeK};
+  int64_t M = op.inputs()[0].getType().cast<ShapedType>().getShape()[0];
+  int64_t N = op.inputs()[1].getType().cast<ShapedType>().getShape()[1];
+  int64_t K = op.inputs()[0].getType().cast<ShapedType>().getShape()[1];
+  SmallVector<int64_t, 4> ts = {
+      getMinIfShapeStatic(M, nRowsPerWorkitem * config.workgroupSize[1]),
+      getMinIfShapeStatic(N, nColsPerWorkitem * config.workgroupSize[0]),
+      getMinIfShapeStatic(K, tileSizeK)};
   tileSizes.emplace_back(std::move(ts));
   return success();
 }
 
+static LogicalResult getMaliSpecificConfig(linalg::ConvOp op,
+                                           TileSizesListType &tileSizes,
+                                           LaunchConfigInfo &config) {
+  auto inputType = op.getInput(1).getType().cast<MemRefType>();
+  auto outputType = op.getOutputBufferType(0).cast<MemRefType>();
+  if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+    return failure();
+
+  const int tileWidth = 8;
+  const int tileChannel = 32;
+
+  auto outputShape = outputType.getShape();
+  bool isInputTilable = inputType.getDimSize(3) % 4 == 0;
+  bool isOutputTilable = outputShape[0] == 1 &&
+                         outputShape[2] % tileWidth == 0 &&
+                         outputShape[3] % tileChannel == 0;
+  if (!isInputTilable || !isOutputTilable) return failure();
+
+  config.workgroupSize = {8, 2, 1};
+
+  SmallVector<int64_t, 4> workgroupLevel = {/*batch=*/0, /*output_height=*/1,
+                                            /*output_width=*/tileWidth,
+                                            /*output_channel=*/tileChannel};
+  tileSizes.emplace_back(std::move(workgroupLevel));
+
+  // No tiling at the subgroup level given that we don't use subgroup-level
+  // synchronization or shared memory.
+  tileSizes.emplace_back();
+
+  SmallVector<int64_t, 4> invocationLevel = {
+      /*batch=*/0, /*output_height=*/1,
+      /*output_width=*/tileWidth / config.workgroupSize[1],
+      /*output_channel=*/tileChannel / config.workgroupSize[0]};
+  tileSizes.emplace_back(invocationLevel);
+
+  // Finally, for each invocation, we use tiling to generate loops to loop over
+  // the filter's height (step 1), width (step 1), and input channel (step 4)
+  // dimensions.
+  SmallVector<int64_t, 4> fourthLevel = {0, 0, 0, 0, 4, 1, 1};
+  tileSizes.emplace_back(fourthLevel);
+
+  config.vectorize = true;
+
+  return success();
+}
+
 template <>
 LogicalResult getOpLaunchConfig(linalg::ConvOp op,
                                 const spirv::TargetEnv &targetEnv,
                                 const SPIRVCodegenOptions &options,
                                 TileSizesListType &tileSizes,
                                 LaunchConfigInfo &config) {
+  if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
+      succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
+    return success();
+  }
+
   unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
                                   .max_compute_workgroup_invocations()
                                   .getInt();
@@ -346,55 +453,43 @@
 
 #undef DEFINE_POOLINGOP_CONFIG
 
-Optional<StringRef> LaunchConfig::getKey(Operation *op) const {
-  StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
-  if (!attr) return {};
-  return attr.getValue();
-}
-
-LogicalResult LaunchConfig::init(
+Optional<LaunchConfig> initGPULaunchConfig(
     MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
     const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps) {
-  unsigned numTiledOps = 0;
-  auto setKey = [&](Operation *op) -> std::string {
-    std::string key = llvm::formatv("__op_num_{0}__", numTiledOps++).str();
-    op->setAttr(Identifier::get(kLaunchInfoKey, context),
-                StringAttr::get(key, context));
-    return key;
-  };
-
+  LaunchConfig launchConfig;
   if (!options.workgroupSize.empty()) {
-    for (linalg::LinalgOp linalgOp : linalgOps)
-      tileSizes[setKey(linalgOp)].emplace_back(options.tileSizes.begin(),
-                                               options.tileSizes.end());
-    workgroupSize = {1, 1, 1};
-    for (unsigned i = 0,
-                  e = std::min<unsigned>(3, options.workgroupSize.size());
-         i != e; ++i)
-      workgroupSize[i] = options.workgroupSize[i];
-    return success();
+    SmallVector<int64_t, 3> tileSizes(options.tileSizes.begin(),
+                                      options.tileSizes.end());
+    for (linalg::LinalgOp linalgOp : linalgOps) {
+      launchConfig.setTileSizes(linalgOp.getOperation(), tileSizes, 0);
+    }
+    SmallVector<int64_t, 3> workgroupSize(options.workgroupSize.begin(),
+                                          options.workgroupSize.end());
+    launchConfig.setWorkgroupSize(workgroupSize);
+    return launchConfig;
   }
 
-  if (linalgOps.empty()) return success();
+  if (linalgOps.empty()) return launchConfig;
 
   spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(*linalgOps.begin()));
 
   Optional<linalg::LinalgOp> rootOperation = {};
   LaunchConfigInfo config;
   for (linalg::LinalgOp linalgOp : linalgOps) {
-#define DISPATCH(opName)                                                  \
-  if (auto op = dyn_cast<opName>(linalgOp.getOperation())) {              \
-    if (rootOperation) {                                                  \
-      return op.emitError(                                                \
-          "unhandled multiple root operations in dispatch region");       \
-    }                                                                     \
-    rootOperation = linalgOp;                                             \
-    TileSizesListType &tileSizesInfo = tileSizes[setKey(*rootOperation)]; \
-    if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo,   \
-                                 config))) {                              \
-      return failure();                                                   \
-    }                                                                     \
-    continue;                                                             \
+#define DISPATCH(opName)                                                     \
+  if (auto op = dyn_cast<opName>(linalgOp.getOperation())) {                 \
+    if (rootOperation) {                                                     \
+      op.emitError("unhandled multiple root operations in dispatch region"); \
+      return llvm::None;                                                     \
+    }                                                                        \
+    rootOperation = linalgOp;                                                \
+    TileSizesListType tileSizesInfo;                                         \
+    if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo,      \
+                                 config))) {                                 \
+      return llvm::None;                                                     \
+    }                                                                        \
+    launchConfig.setTileSizes(op, tileSizesInfo);                            \
+    continue;                                                                \
   }
 
     DISPATCH(linalg::BatchMatmulOp)
@@ -406,73 +501,23 @@
 
 #undef DISPATCH
   }
-  workgroupSize = config.workgroupSize;
-  numSubgroups = config.numSubgroups;
-  vectorize = config.vectorize;
+
+  launchConfig.setWorkgroupSize(config.workgroupSize);
+  launchConfig.setNumSubgroups(config.numSubgroups);
+  launchConfig.setVectorize(config.vectorize);
+
   if (!rootOperation) {
     // No root operations found. Dont need to do anything.
-    return success();
+    return launchConfig;
   }
 
-  // Check the dependencies going into and out of the root operation. For now
-  // only the following dependencies are supported
-  // - WAW dependencies going into the root operation.
-  // - RAW dependencies going out of the root operation.
-  // - WAW dependencies going out of the root operation.
-  // i.e. there are no RAW dependences going into the root operation.
-  auto inRAWDependencies = dependenceGraph.getDependencesInto(
-      *rootOperation, linalg::LinalgDependenceGraph::RAW);
-  if (!inRAWDependencies.empty()) {
-    return rootOperation->getOperation()->emitError(
-        "unhandled fusion of root operation with producer");
-  }
-  auto dependences =
-      dependenceGraph.getDependentOperations(rootOperation.getValue());
-  unsigned numOuterParallel = getNumOuterParallelLoops(*rootOperation);
-
-  // Check that for all dependences into and out of the root operation,
-  // - The result expressions of the indexing maps of the fused view in the
-  //   producer and consumer must match for the parallel loops.
-  for (auto dependence :
-       dependenceGraph.getDependentOperations(rootOperation.getValue())) {
-    unsigned viewIndex = dependence.indexingOpView.operandIndex;
-    AffineMap indexingMap = rootOperation->getIndexingMap(viewIndex);
-    linalg::LinalgOp fusedOp =
-        cast<linalg::LinalgOp>(dependence.dependentOpView.op);
-    unsigned fusedViewIndex = dependence.dependentOpView.operandIndex;
-    AffineMap fusedIndexingMap = fusedOp.getIndexingMap(fusedViewIndex);
-    if (indexingMap.getNumResults() < numOuterParallel ||
-        fusedIndexingMap.getNumResults() < numOuterParallel ||
-        !llvm::all_of(
-            llvm::seq<unsigned>(0, numOuterParallel),
-            [&indexingMap, fusedIndexingMap](unsigned i) {
-              return indexingMap.getResult(i).isa<AffineDimExpr>() &&
-                     fusedIndexingMap.getResult(i).isa<AffineDimExpr>() &&
-                     indexingMap.getResult(i) == fusedIndexingMap.getResult(i);
-            })) {
-      return rootOperation->getOperation()->emitError(
-          "unhandled fusion of root operation with all operations in the "
-          "dispatch region");
-    }
-  }
-  // The dependent operations get the same tile size information as the root
-  // operation. To propogate that information, just use the same key as the root
-  // operation.
-  for (auto dependence : dependences) {
-    dependence.dependentOpView.op->setAttr(
-        Identifier::get(kLaunchInfoKey, context),
-        StringAttr::get(getKey(*rootOperation).getValue(), context));
-  }
+  if (failed(propogateRootOperationLaunchConfig(launchConfig, *rootOperation,
+                                                dependenceGraph)))
+    return llvm::None;
 
   // TODO(ravishankarm): Verify that the set configurations is within the device
   // limits.
-  return success();
-}
-
-void LaunchConfig::finalize(FuncOp funcOp) {
-  funcOp.walk([&](linalg::LinalgOp linalgOp) {
-    linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
-  });
+  return launchConfig;
 }
 
 template <typename OpTy>
@@ -493,8 +538,14 @@
         op.getAccType().cast<VectorType>().getElementType(),
         op.getResultType().cast<VectorType>().getElementType());
   } else {
+    unsigned lastParalleldim = 0;
+    for (auto it : llvm::enumerate(op.iterator_types())) {
+      if (isParallelIterator(it.value())) lastParalleldim = it.index();
+    }
+    SmallVector<int64_t, 4> nativeSize(op.iterator_types().size(), 1);
+    nativeSize[lastParalleldim] = 4;
     // Map to vec4 fma operations.
-    return SmallVector<int64_t, 4>({1, 4, 1});
+    return nativeSize;
   }
 }
 
@@ -508,10 +559,13 @@
     // the contract.
     return SmallVector<int64_t, 4>(op.getVectorType().getDimSize(0),
                                    op.getVectorType().getDimSize(1));
-  } else {
-    // Map to load4.
-    return SmallVector<int64_t, 4>({1, 4});
   }
+
+  // Map to load4.
+  auto rank = op.getVectorType().getRank();
+  SmallVector<int64_t, 4> nativeSize(rank, 1);
+  nativeSize.back() = 4;
+  return nativeSize;
 }
 
 template <>
@@ -524,10 +578,13 @@
     // the contract.
     return SmallVector<int64_t, 4>(op.getVectorType().getDimSize(0),
                                    op.getVectorType().getDimSize(1));
-  } else {
-    // Map to store4.
-    return SmallVector<int64_t, 4>({1, 4});
   }
+
+  // Map to store4.
+  auto rank = op.getVectorType().getRank();
+  SmallVector<int64_t, 4> nativeSize(rank, 1);
+  nativeSize.back() = 4;
+  return nativeSize;
 }
 
 Optional<SmallVector<int64_t, 4>> getNativeVectorSize(Operation *op) {
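
The Mali-specific configurations above only take the vectorized path when every shape is static and evenly divisible by the workgroup-level tile sizes returned by getMaliBestMatMulTileSizes ({8, 64, 4} for f32, {16, 64, 8} for f16); otherwise they return failure() and codegen falls back to the non-vectorized path. A minimal standalone sketch of that predicate for the f32 matmul case, where M/N/K stand for the lhs/rhs dimensions checked above (kDynamic stands in for ShapedType::kDynamicSize; the function name is illustrative, not part of the patch):

    #include <array>
    #include <cstdint>

    constexpr int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamicSize

    // True if an MxNxK f32 matmul may use the Mali vectorized path: every
    // dimension must be static and divisible by its workgroup-level tile size.
    bool useMaliVectorizedMatmulF32(int64_t M, int64_t N, int64_t K) {
      const std::array<int64_t, 3> tile = {8, 64, 4};  // f32 tile sizes from above
      const std::array<int64_t, 3> dims = {M, N, K};
      for (int i = 0; i < 3; ++i) {
        if (dims[i] == kDynamic || dims[i] % tile[i] != 0) return false;
      }
      return true;
    }

    int main() {
      bool a = useMaliVectorizedMatmulF32(1024, 1024, 1024);  // true: all divisible
      bool b = useMaliVectorizedMatmulF32(100, 64, 4);        // false: 100 % 8 != 0
      bool c = useMaliVectorizedMatmulF32(kDynamic, 64, 4);   // false: dynamic dim
      return (a && !b && !c) ? 0 : 1;
    }
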
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 72ee6d2..8f01683 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -19,11 +19,13 @@
 // the number of workgroups to use for launch, etc.
 //
 //===----------------------------------------------------------------------===//
+
 #ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
 #define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
 
 #include <array>
 
+#include "iree/compiler/Conversion/Common/LaunchConfig.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
@@ -41,94 +43,9 @@
 namespace mlir {
 namespace iree_compiler {
 
-/// Store the tile sizes to use at different levels of tiling as a vector of
-/// vectors.
-/// - First level tiling maps to workgroups.
-/// - Second level tiling maps to subgroups.
-using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
-
-/// Based on the linalg operations in a dispatch region, the number of levels of
-/// tiling, the tile sizes needed, the workgroup size, etc. need to be
-/// decided. These parameters are called `LaunchConfig`. This class implements
-/// one heuristic to compute these for the different linalg operations on
-/// buffers. This can be adapted later to support multiple configurations that
-/// can be picked based on device information/problem size information. It
-/// exposes the information needed by the codegenerators, and hides the
-/// implementation from the rest of the pipeline.
-class LaunchConfig {
- public:
-  LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
-
-  /// Given the sequence of `linalgOps` (and `options`), decide the launch
-  /// configuration by deciding
-  /// - the number of levels of tiling,
-  /// - tile sizes for each level,
-  /// - the workgroup size, and
-  /// - number of subgroups to use.
-  LogicalResult init(MLIRContext *context,
-                     const linalg::LinalgDependenceGraph &dependenceGraph,
-                     const SPIRVCodegenOptions &options,
-                     ArrayRef<linalg::LinalgOp> linalgOps);
-
-  /// Remove attributed added to operations for retrieving tile size
-  /// information.
-  void finalize(FuncOp funcOp);
-
-  /// Gets the tile size computed for an operation at all levels.
-  TileSizesListType getTileSizes(Operation *op) const {
-    auto key = getKey(op);
-    if (!key) return {};
-    auto it = tileSizes.find(*key);
-    return it->second;
-  }
-
-  /// Gets the tile size computed for an operation for an level.
-  ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
-    auto key = getKey(op);
-    if (!key) return {};
-    auto it = tileSizes.find(*key);
-    if (it == tileSizes.end() || level >= it->second.size()) return {};
-    return it->second[level];
-  }
-
-  /// Returns the workgroup size to use based on the tile sizes.
-  ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
-
-  /// Returns the number of subgroups to use.
-  ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
-
-  /// Returns true if tile sizes have been computed for the operation. If tile
-  /// sizes arent set, it implies operation is not to be tiled.
-  bool hasTileSizes(Operation *op, size_t level = 0) const {
-    return !getTileSizes(op, level).empty();
-  }
-
-  /// Use vectorize transformations.
-  bool useVectorize() const { return vectorize; }
-
- protected:
-  /// Current tile size configuration per operation. They key used here to
-  /// retrieve the tile size information per operation is the value of a StrAttr
-  /// added to operations during `init`. When tiled this attribute is copied
-  /// over to the tiled operation, thereby the same key can be used to retrieve
-  /// the tile sizes for the next level of tiling. The `finalize` method removes
-  /// these attributes.
-  llvm::StringMap<TileSizesListType> tileSizes;
-
-  /// Workgroup size to use.
-  std::array<int64_t, 3> workgroupSize;
-
-  /// Number of subgroups that are logically distributed along x, y & z.
-  std::array<int64_t, 3> numSubgroups;
-
-  /// Use vectorization.
-  bool vectorize = false;
-
- private:
-  /// Retrieves the key to use to get the `tileSizes` for a given
-  /// `operation`. Returns llvm::None on failure.
-  Optional<StringRef> getKey(Operation *op) const;
-};
+Optional<LaunchConfig> initGPULaunchConfig(
+    MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+    const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps);
 
 /// Returns the size of instruction in `vector` dialect that maps directly to
 /// the hardware.
@@ -136,4 +53,5 @@
 
 }  // namespace iree_compiler
 }  // namespace mlir
+
 #endif  // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_DISPATCHUTILS_H_
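
getNativeVectorSize, updated in KernelDispatchUtils.cpp above and declared through this header, no longer hard-codes {1, 4, 1} and {1, 4}: for contractions it places the 4-wide dimension on the last parallel iterator, and for vector transfer ops on the trailing dimension of the vector type. A self-contained sketch of just that shape logic, using plain strings in place of MLIR iterator attributes (names and types here are illustrative, not the real API):

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    // Contract case: all ones, with 4 on the last parallel iterator. For a plain
    // matmul (parallel, parallel, reduction) this reproduces the old {1, 4, 1};
    // for batch_matmul (parallel, parallel, parallel, reduction) it gives
    // {1, 1, 4, 1}.
    std::vector<int64_t> nativeContractSize(const std::vector<std::string> &iters) {
      std::size_t lastParallel = 0;
      for (std::size_t i = 0; i < iters.size(); ++i)
        if (iters[i] == "parallel") lastParallel = i;
      std::vector<int64_t> nativeSize(iters.size(), 1);
      nativeSize[lastParallel] = 4;
      return nativeSize;
    }

    // Transfer read/write case: all ones, with 4 on the trailing dimension
    // (maps to load4/store4).
    std::vector<int64_t> nativeTransferSize(unsigned rank) {
      std::vector<int64_t> nativeSize(rank, 1);
      nativeSize.back() = 4;
      return nativeSize;
    }

    int main() {
      for (int64_t d : nativeContractSize({"parallel", "parallel", "parallel", "reduction"}))
        std::printf("%lld ", static_cast<long long>(d));
      std::printf("| ");
      for (int64_t d : nativeTransferSize(3))
        std::printf("%lld ", static_cast<long long>(d));
      std::printf("\n");
      return 0;
    }
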
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 34c6632..5be1613 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -21,13 +21,14 @@
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
 #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
-#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
 #include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/Common/Transforms.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
@@ -35,6 +36,8 @@
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/Matchers.h"
@@ -73,6 +76,32 @@
   return linalg::LinalgMarker(markers, Identifier::get(replaceMarker, context));
 }
 
+/// Returns the distribution options for operations when targeting workgroups.
+static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
+  linalg::LinalgLoopDistributionOptions options;
+
+  options.procInfo = [](OpBuilder &builder, Location loc,
+                        ArrayRef<Range> parallelLoopRanges) {
+    return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
+        builder, loc, parallelLoopRanges.size());
+  };
+  options.distributionMethod.assign(
+      3, linalg::DistributionMethod::CyclicNumProcsEqNumIters);
+
+  return options;
+}
+
+/// Applies canonicalization over index calculation inside the given `funcOp`.
+static void applyIndexCalculationCanonicalization(FuncOp funcOp) {
+  MLIRContext *context = funcOp.getContext();
+  OwningRewritePatternList canonicalizationPatterns;
+  DimOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  AddIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  SubIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  SignedDivIOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
+}
+
 //===----------------------------------------------------------------------===//
 // Main pass
 //===----------------------------------------------------------------------===//
@@ -100,154 +129,6 @@
 }  // namespace
 
 //===----------------------------------------------------------------------===//
-// Patterns to tile computation to map to workgroups
-//===----------------------------------------------------------------------===//
-
-/// Returns the distribution options for operations when targeting workgroups.
-static linalg::LinalgLoopDistributionOptions getWorkgroupDistributionOptions() {
-  linalg::LinalgLoopDistributionOptions options;
-
-  options.procInfo = [](OpBuilder &builder, Location loc,
-                        ArrayRef<Range> parallelLoopRanges) {
-    return getGPUProcessorIdsAndCounts<gpu::BlockIdOp, gpu::GridDimOp>(
-        builder, loc, parallelLoopRanges.size());
-  };
-  options.distributionMethod = {
-      linalg::DistributionMethod::CyclicNumProcsEqNumIters,
-      linalg::DistributionMethod::CyclicNumProcsEqNumIters,
-      linalg::DistributionMethod::CyclicNumProcsEqNumIters};
-
-  return options;
-}
-
-namespace {
-/// Pattern for tiling operations. Updates the workgroup size in the surrounding
-/// function operation if tiling succeeds, and generates the function that
-/// computes the number of workgroups for the launch.
-template <typename LinalgOpTy>
-class TileToWorkgroupsPattern : public linalg::LinalgBaseTilingPattern {
- public:
-  TileToWorkgroupsPattern(MLIRContext *context,
-                          const linalg::LinalgDependenceGraph &dependenceGraph,
-                          linalg::LinalgTilingOptions options,
-                          linalg::LinalgMarker marker,
-                          const LaunchConfig &launchConfig,
-                          PatternBenefit benefit = 1)
-      : Base(LinalgOpTy::getOperationName(), context, options, marker, benefit),
-        dependenceGraph(dependenceGraph),
-        launchConfig(launchConfig) {}
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
-    // erased.
-    FuncOp funcOp = op->getParentOfType<FuncOp>();
-    SmallVector<Value, 4> tensorResults;
-    linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
-    if (!funcOp || dependenceGraph.hasDependentOperations(linalgOp) ||
-        failed(Base::matchAndRewriteBase(op, rewriter, tensorResults)) ||
-        !tensorResults.empty() ||
-        failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
-        (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
-         failed(createNumWorkgroupsFromResultShape(
-             rewriter, linalgOp, funcOp, getNumWorkgroupsFnAttrName(),
-             launchConfig.getTileSizes(op, 0))))) {
-      return failure();
-    }
-    setMarker(op, getDeleteMarker());
-    return success();
-  }
-
- private:
-  using Base = linalg::LinalgBaseTilingPattern;
-
-  const linalg::LinalgDependenceGraph &dependenceGraph;
-  const LaunchConfig &launchConfig;
-};
-
-/// Pattern for tile + fuse of operations. Updates the workgroup size in the
-/// surrounding function operation if tiling succeeds, and generates the
-/// function that computes the number of workgroups for the launch..
-template <typename LinalgOpTy>
-class TileAndFuseToWorkgroupsPattern
-    : public linalg::LinalgTileAndFusePattern<LinalgOpTy> {
- public:
-  TileAndFuseToWorkgroupsPattern(
-      MLIRContext *context,
-      const linalg::LinalgDependenceGraph &dependenceGraph,
-      linalg::LinalgTilingOptions tilingOptions, linalg::LinalgMarker marker,
-      const LaunchConfig &launchConfig, PatternBenefit benefit = 1)
-      : Base(context, dependenceGraph, tilingOptions,
-             linalg::LinalgFusionOptions().setIndicesToFuse({2}), marker,
-             marker, getLinalgReplaceMarker(getDeleteMarker(), context),
-             benefit),
-        dependenceGraph(dependenceGraph),
-        launchConfig(launchConfig) {}
-
-  virtual LogicalResult matchAndRewrite(Operation *op,
-                                        PatternRewriter &rewriter) const {
-    FuncOp funcOp = op->getParentOfType<FuncOp>();
-    linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
-    if (!funcOp || !dependenceGraph.hasDependentOperations(linalgOp) ||
-        failed(Base::matchAndRewrite(op, rewriter)) ||
-        failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
-        (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
-         failed(createNumWorkgroupsFromResultShape(
-             rewriter, linalgOp, funcOp, getNumWorkgroupsFnAttrName(),
-             launchConfig.getTileSizes(op, 0))))) {
-      return failure();
-    }
-    return success();
-  }
-
- private:
-  using Base = linalg::LinalgTileAndFusePattern<LinalgOpTy>;
-
-  const linalg::LinalgDependenceGraph &dependenceGraph;
-  const LaunchConfig &launchConfig;
-};
-}  // namespace
-
-/// Populate patterns for first-level tiling.
-static void populateTilingToWorkgroupPatterns(
-    MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
-    const LaunchConfig &launchConfig, OwningRewritePatternList &patterns) {
-  // Function to compute first level tiling values.
-  auto getOuterTileSizeFn = [&launchConfig](
-                                OpBuilder &builder,
-                                Operation *operation) -> SmallVector<Value, 4> {
-    ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 0);
-    if (tileSizes.empty()) return {};
-    SmallVector<Value, 4> tileSizesVal;
-    tileSizesVal.reserve(tileSizes.size());
-    for (auto val : tileSizes) {
-      tileSizesVal.push_back(
-          builder.create<ConstantIndexOp>(operation->getLoc(), val));
-    }
-    return tileSizesVal;
-  };
-
-  patterns.insert<TileAndFuseToWorkgroupsPattern<linalg::BatchMatmulOp>,
-                  TileAndFuseToWorkgroupsPattern<linalg::ConvOp>,
-                  TileAndFuseToWorkgroupsPattern<linalg::MatmulOp>,
-                  TileAndFuseToWorkgroupsPattern<linalg::PoolingMaxOp>,
-                  TileAndFuseToWorkgroupsPattern<linalg::PoolingMinOp>,
-                  TileAndFuseToWorkgroupsPattern<linalg::PoolingSumOp>,
-                  TileToWorkgroupsPattern<linalg::BatchMatmulOp>,
-                  TileToWorkgroupsPattern<linalg::ConvOp>,
-                  TileToWorkgroupsPattern<linalg::MatmulOp>,
-                  TileToWorkgroupsPattern<linalg::PoolingMaxOp>,
-                  TileToWorkgroupsPattern<linalg::PoolingMinOp>,
-                  TileToWorkgroupsPattern<linalg::PoolingSumOp>>(
-      context, dependenceGraph,
-      linalg::LinalgTilingOptions()
-          .setDistributionOptions(getWorkgroupDistributionOptions())
-          .setTileSizeComputationFunction(getOuterTileSizeFn)
-          .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
-      getLinalgReplaceMarker(getWorkgroupMarker(), context), launchConfig);
-}
-
-//===----------------------------------------------------------------------===//
 // Patterns to promote subviews to workgroup memory
 //===----------------------------------------------------------------------===//
 
@@ -404,35 +285,36 @@
         return tileSizesVal;
       };
 
-  auto getThreadProcInfoFn = [&launchConfig](
-                                 OpBuilder &builder, Location loc,
-                                 ArrayRef<Range> parallelLoopRanges) {
-    Type indexType = builder.getIndexType();
-    SmallVector<linalg::ProcInfo, 2> procInfo(2);
-    procInfo[1] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
-                                                   builder.getStringAttr("x")),
-                   builder.create<ConstantIndexOp>(
-                       loc, launchConfig.getWorkgroupSize()[0])};
-    procInfo[0] = {builder.create<gpu::ThreadIdOp>(loc, indexType,
-                                                   builder.getStringAttr("y")),
-                   builder.create<ConstantIndexOp>(
-                       loc, launchConfig.getWorkgroupSize()[1])};
-    return procInfo;
+  auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
+                                ArrayRef<Range> parallelLoopRanges) {
+    return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
+        builder, loc, parallelLoopRanges.size());
   };
-  linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+  linalg::LinalgLoopDistributionOptions invocationDistributionOptions = {
       getThreadProcInfoFn,
       {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+       linalg::DistributionMethod::CyclicNumProcsEqNumIters,
        linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
-  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::FillOp>>(
-      context,
+
+  auto tilingOptions =
       linalg::LinalgTilingOptions()
           .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
           .setTileSizeComputationFunction(getInnerTileSizeFn)
-          .setDistributionOptions(subgroupDistributionOptions),
+          .setDistributionOptions(invocationDistributionOptions);
+
+  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
+                  linalg::LinalgTilingPattern<linalg::FillOp>,
+                  linalg::LinalgTilingPattern<linalg::BatchMatmulOp>>(
+      context, tilingOptions,
       getLinalgMatchAndReplaceMarker(
           {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
           getVectorizeMarker(), context));
+
+  patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
+      context, tilingOptions,
+      getLinalgMatchAndReplaceMarker(
+          {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
+          getConvFilterTileMarker(), context));
 }
 
 //====---------------------------------------------------------------------===//
@@ -443,22 +325,12 @@
                                           const LaunchConfig &launchConfig,
                                           OwningRewritePatternList &patterns) {
   patterns.insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>,
+                  linalg::LinalgVectorizationPattern<linalg::BatchMatmulOp>,
                   linalg::LinalgVectorizationPattern<linalg::FillOp>>(
       context,
       linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
 }
 
-/// Apply canonicalizations related to tiling to make promotion/vectorization
-/// easier.
-static void applyCanonicalizationPatterns(MLIRContext *context, Operation *op) {
-  OwningRewritePatternList canonicalizationPatterns;
-  canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
-  AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
-  applyPatternsAndFoldGreedily(op, std::move(canonicalizationPatterns));
-}
-
 //====---------------------------------------------------------------------===//
 // Patterns for unrolling vectors
 //====---------------------------------------------------------------------===//
@@ -507,6 +379,46 @@
 }
 
 //====---------------------------------------------------------------------===//
+// Patterns to tile convolution window dimensions
+//====---------------------------------------------------------------------===//
+
+static void populateTilingConvFilterPatterns(MLIRContext *context,
+                                             OwningRewritePatternList &patterns,
+                                             const LaunchConfig &launchConfig,
+                                             linalg::LinalgMarker marker) {
+  auto getTileSizeFn = [&launchConfig](OpBuilder &builder, Operation *op) {
+    SmallVector<Value, 4> tileSizes;
+    ArrayRef<int64_t> fourthLevel = launchConfig.getTileSizes(op, 3);
+    tileSizes.reserve(fourthLevel.size());
+
+    Location loc = op->getLoc();
+    for (int64_t size : fourthLevel) {
+      tileSizes.push_back(builder.create<ConstantIndexOp>(loc, size));
+    }
+    return tileSizes;
+  };
+
+  // TODO(antiagainst): move this to launch configuration.
+  SmallVector<unsigned, 8> loopOrder = {
+      /*batch=*/0,
+      /*output_height=*/1,
+      /*output_width=*/2,
+      /*output_channel=*/3,
+      /*filter_height=*/5,
+      /*filter_width=*/6,
+      /*input_channel=*/4,
+  };
+
+  auto tilingOptions = linalg::LinalgTilingOptions()
+                           .setLoopType(linalg::LinalgTilingLoopType::Loops)
+                           .setInterchange(loopOrder)
+                           .setTileSizeComputationFunction(getTileSizeFn);
+
+  patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
+      context, tilingOptions, marker);
+}
+
+//====---------------------------------------------------------------------===//
 // Main pass implementation
 //====---------------------------------------------------------------------===//
 
@@ -528,18 +440,19 @@
     auto linalgOps = block.getOps<linalg::LinalgOp>();
     if (linalgOps.empty()) continue;
 
-    LaunchConfig launchConfig;
     SmallVector<linalg::LinalgOp, 4> linalgOpsVec =
         llvm::to_vector<4>(llvm::map_range(linalgOps, [](Operation *op) {
           return cast<linalg::LinalgOp>(op);
         }));
     linalg::Aliases aliases;
     linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOpsVec);
-    if (failed(launchConfig.init(context, dependenceGraph, options,
-                                 linalgOpsVec))) {
+    Optional<LaunchConfig> launchConfigOpt =
+        initGPULaunchConfig(context, dependenceGraph, options, linalgOpsVec);
+    if (!launchConfigOpt) {
       funcOp.emitError("unable to find launch configuration");
       return signalPassFailure();
     }
+    LaunchConfig &launchConfig = *launchConfigOpt;
 
     LLVM_DEBUG({
       llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
@@ -547,7 +460,7 @@
       llvm::dbgs() << "]\n";
       for (auto op : linalgOps) {
         llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
-        TileSizesListType const &tileSizes = launchConfig.getTileSizes(op);
+        TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op);
         llvm::dbgs() << "{";
         std::string sep = "";
         for (auto &level : enumerate(tileSizes)) {
@@ -560,26 +473,38 @@
       }
     });
 
-    {
-      OwningRewritePatternList firstLevelTilingPatterns;
-      populateTilingToWorkgroupPatterns(context, dependenceGraph, launchConfig,
-                                        firstLevelTilingPatterns);
-      applyPatternsAndFoldGreedily(funcOp, std::move(firstLevelTilingPatterns));
-      applyCanonicalizationPatterns(context, funcOp);
-
-      // Delete the ops that are marked for deletion.
-      funcOp.walk([](linalg::LinalgOp linalgOp) {
-        if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
-          linalgOp.getOperation()->erase();
-      });
+    TileAndFuseOptions tileAndFuseOptions = {getWorkgroupDistributionOptions(),
+                                             allocateWorkgroupMemory};
+    if (failed(tileAndFuseLinalgBufferOps(funcOp, linalgOpsVec, dependenceGraph,
+                                          launchConfig, tileAndFuseOptions)) ||
+        failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize()))) {
+      return signalPassFailure();
     }
 
     LLVM_DEBUG({
-      llvm::dbgs() << "--- After First level of tile+distribute ---\n";
+      llvm::dbgs() << "--- After first level of tiling and distribution ---\n";
       funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
       llvm::dbgs() << "\n\n";
     });
 
+    // In the above we distributed ops to workgroup dimensions and populated a
+    // function for calculating the number of workgroups. In the following
+    // steps we will need to query the workgroup count function to simplify GPU
+    // processor ID uses. That relies on constant upper bounds, so we need to
+    // canonicalize the workgroup count function first.
+    if (funcOp.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName())) {
+      FuncOp numWorkGroupFunc =
+          getNumWorkgroupsFn(funcOp, getNumWorkgroupsFnAttrName());
+      applyIndexCalculationCanonicalization(numWorkGroupFunc);
+
+      LLVM_DEBUG({
+        llvm::dbgs()
+            << "--- After canonicalizing workgroup count function  ---\n";
+        numWorkGroupFunc.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+        llvm::dbgs() << "\n\n";
+      });
+    }
+
     if (options.useWorkgroupMemory) {
       // The promotion patterns are put separate from the tiling patterns to
       // make sure that the allocated scratchspace memory is constant sizes
@@ -587,7 +512,7 @@
       OwningRewritePatternList promotionPatterns;
       populatePromotionPatterns(context, promotionPatterns);
       applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns));
-      applyCanonicalizationPatterns(context, funcOp);
+      applyCanonicalizationPatternsForTiling(context, funcOp);
 
       LLVM_DEBUG({
         llvm::dbgs() << "--- After Promotion  ---\n";
@@ -603,7 +528,7 @@
                                          secondLevelTilingPatterns);
         applyPatternsAndFoldGreedily(funcOp,
                                      std::move(secondLevelTilingPatterns));
-        applyCanonicalizationPatterns(context, funcOp);
+        applyCanonicalizationPatternsForTiling(context, funcOp);
         promoteSingleIterationLoops(funcOp);
 
         LLVM_DEBUG({
@@ -619,7 +544,7 @@
                                            thirdLevelTilingPatterns);
         applyPatternsAndFoldGreedily(funcOp,
                                      std::move(thirdLevelTilingPatterns));
-        applyCanonicalizationPatterns(context, funcOp);
+        applyCanonicalizationPatternsForTiling(context, funcOp);
         promoteSingleIterationLoops(funcOp);
 
         LLVM_DEBUG({
@@ -630,9 +555,29 @@
       }
 
       {
+        OwningRewritePatternList tilingPatterns;
+        auto marker = getLinalgMatchAndReplaceMarker(
+            getConvFilterTileMarker(), getVectorizeMarker(), context);
+        populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig,
+                                         marker);
+        populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
+        tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(
+            context);
+        applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
+        applyCanonicalizationPatternsForTiling(context, funcOp);
+
+        LLVM_DEBUG({
+          llvm::dbgs() << "--- After tiling linalg.conv  ---\n";
+          funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+          llvm::dbgs() << "\n\n";
+        });
+      }
+
+      {
         OwningRewritePatternList vectorizationPatterns;
         populateVectorizationPatterns(context, launchConfig,
                                       vectorizationPatterns);
+        populateVectorizeLinalgConvPatterns(context, vectorizationPatterns);
         applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns));
         LLVM_DEBUG({
           llvm::dbgs() << "--- After Vectorization ---\n";
@@ -645,11 +590,6 @@
     }
 
     launchConfig.finalize(funcOp);
-    SmallVector<linalg::LinalgOp, 1> toDelete;
-    funcOp.walk([&](linalg::LinalgOp linalgOp) {
-      if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
-        linalgOp.erase();
-    });
   }
 }
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index b014e19..25a20af 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -69,14 +69,15 @@
   //   - All Linalg ops have buffer semantics.
   //
   // Post-conditions:
+  //   - If there are multiple linalg operations in the dispatch region, they
+  //     are fused using a tile+fuse approach.
+  //     - The fused loops are distributed across workgroups.
   //   - The operations that cannot be fused at buffer levels are split into
   //     separate entry points.
-  //   - If the input Linalg ops are tilable:
-  //     - loop.parallel ops are generated for mapping to workgroups.
-  //     - Linalg ops are nested inside loop.parallel ops and ready for mapping
-  //       to workitems.
-  //     - If multiple linalg operations are present they get tiled and fused to
-  //       get outer loop.parallel ops which can be mapped to workitems.
+  //   - If there is a single linalg operation in the dispatch region, it is
+  //     tiled and the generated parallel loop is distributed.
+  //     - The tiled linalg operation can be tiled again one or more times and
+  //       then vectorized.
   //   - Otherwise:
   //     - The Linalg op is kept untouched.
   //
@@ -147,6 +148,9 @@
   //   - Load/store on std.subview ops are converted into load/store on the
   //     original buffers.
   //===--------------------------------------------------------------------===//
+  if (options.enableVectorization) {
+    pm.addNestedPass<FuncOp>(createVectorTransferOptimizationPass());
+  }
   pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 09ceb74..f195f10 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -84,12 +85,16 @@
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
     ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
     ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMaxOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
     ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMinOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
     ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingSumOp,
                      linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::MatmulOp, linalg::GenericOp,
+                     linalg::LinalgDependenceGraph::DependenceType::RAW)
 
 #undef ADD_FUSABLE_PAIR
   }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 6e01249..26ed11d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -163,15 +163,5 @@
 getGPUProcessorIdsAndCounts<GPUGlobalId, GPUGlobalCount>(OpBuilder &builder,
                                                          Location loc,
                                                          unsigned numDims);
-
-unsigned getNumOuterParallelLoops(linalg::LinalgOp op) {
-  return op.iterator_types()
-      .getValue()
-      .take_while([](Attribute attr) -> bool {
-        return attr.cast<StringAttr>().getValue() ==
-               getParallelIteratorTypeName();
-      })
-      .size();
-}
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index ea1316c..0296226 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -56,10 +56,6 @@
 SmallVector<linalg::ProcInfo, 2> getGPUProcessorIdsAndCounts(OpBuilder &builder,
                                                              Location loc,
                                                              unsigned numDims);
-
-/// Function to get number of outer parallel loops of a linalgOp
-unsigned getNumOuterParallelLoops(linalg::LinalgOp op);
-
 /// Updates the workgroup size used for the dispatch region.
 LogicalResult updateWorkGroupSize(FuncOp funcOp,
                                   ArrayRef<int64_t> workGroupSize);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
index bf032b7..82a94f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
@@ -248,7 +248,7 @@
 };
 
 // Lower vector contract to a single scalar or vector mulf+addf. Insert casts to
-// convert from 2D vector to 1D vector or scalar.
+// convert from N-D vector to 1D vector or scalar.
 class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
  public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -256,19 +256,20 @@
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     auto iteratorTypes = op.iterator_types().getValue();
-    if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
-        !isParallelIterator(iteratorTypes[1]) ||
-        !isReductionIterator(iteratorTypes[2]) ||
-        !isRowMajorMatmul(op.indexing_maps())) {
+    if (!isReductionIterator(iteratorTypes.back()) ||
+        op.getContractingDimMap().size() > 1)
       return failure();
-    }
     if (op.getLhsType().getNumElements() != 1) return failure();
-    unsigned vecSize = op.getAccType().cast<VectorType>().getNumElements();
-    if (!(vecSize >= 1 && vecSize <= 4)) return failure();
+    auto accType = op.getAccType().cast<VectorType>();
+    auto rhsType = op.getRhsType();
+    unsigned vecSize = accType.getNumElements();
+    if (accType != rhsType || !(vecSize >= 1 && vecSize <= 4) ||
+        accType.getShape().back() != vecSize)
+      return failure();
     auto loc = op.getLoc();
     VectorType vecType = VectorType::get(
         vecSize, op.getResultType().cast<VectorType>().getElementType());
-    std::array<int64_t, 2> zero = {0, 0};
+    llvm::SmallVector<int64_t, 4> zero(iteratorTypes.size() - 1, 0);
     Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero);
     Value rhs, acc;
     if (vecSize == 1) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
index 5a610d3..2f3616d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -84,6 +84,7 @@
   if (memrefType && !memrefType.getElementType().isa<VectorType>() &&
       (kMaxVectorizationSizeInBits % memrefType.getElementTypeBitWidth() ==
        0) &&
+      memrefType.getRank() > 0 &&
       !ShapedType::isDynamic(memrefType.getShape().back()) &&
       getUsesIfAllTransferOp(v, uses)) {
     return calculateMemrefVecSize(uses);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
new file mode 100644
index 0000000..f5aee9a
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
@@ -0,0 +1,333 @@
+// RUN: iree-opt -split-input-file -pass-pipeline="iree-codegen-linalg-tile-and-fuse,canonicalize,cse" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+      [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+       StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+       UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+       GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+       GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+       VariablePointersStorageBuffer],
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+       SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+      ARM:IntegratedGPU,
+      {max_compute_shared_memory_size = 32768 : i32,
+       max_compute_workgroup_invocations = 512 : i32,
+       max_compute_workgroup_size = dense<512> : vector<3xi32>,
+       subgroup_size = 16 : i32}>} {
+  func @batch_matmul_static_shape()
+    attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+    %arg0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
+    %arg1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
+    %ret0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
+    linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+    return
+  }
+  func @matmul_static_shape__num_workgroups__
+    (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+     !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
+//      CHECK: func @batch_matmul_static_shape
+//  CHECK-DAG:  %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+//  CHECK-DAG:  %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+//  CHECK-DAG:  %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//  CHECK-DAG:  %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:  %[[CST:.+]] = constant 0.0
+//  CHECK-DAG:  %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:  %[[C2:.+]] = constant 2 : index
+//  CHECK-DAG:  %[[C3:.+]] = constant 3 : index
+//  CHECK-DAG:  %[[C4:.+]] = constant 4 : index
+//  CHECK-DAG:  %[[C5:.+]] = constant 5 : index
+//  CHECK-DAG:  %[[C6:.+]] = constant 6 : index
+//  CHECK-DAG:  %[[C7:.+]] = constant 7 : index
+//      CHECK:  %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK:  %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK:  %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+//      CHECK:  %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:  %[[BOFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+//      CHECK:  %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
+// CHECK-SAME:      [%[[BIDZ]], %[[BOFFSET_Y]], %[[BOFFSET_X]]] [1, 8, 64]
+//      CHECK:  %[[IIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+//      CHECK:  %[[IIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+//      CHECK:  %[[IIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
+//      CHECK:  %[[IOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[IIDY]]]
+//      CHECK:  %[[IOFFSET_X:.+]] = affine.apply #[[MAP2]]()[%[[IIDX]]]
+//      CHECK:  %[[SUBVIEW_RESULT_2:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME:      [%[[IIDZ]], %[[IOFFSET_Y]], %[[IOFFSET_X]]] [1, 8, 4]
+//  CHECK-DAG:  %[[READ_INIT_0:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_1:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C1]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_2:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C2]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_3:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C3]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_4:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C4]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_5:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C5]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_6:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C6]],  %[[C0]]]
+//  CHECK-DAG:  %[[READ_INIT_7:.+]] = vector.transfer_read
+// CHECK-SAME:    %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C7]],  %[[C0]]]
+
+//      CHECK:  %[[FOR_RES:.+]]:8 = scf.for %[[IV0:.+]] = {{.*}} to
+// CHECK-SAME:  iter_args(%[[ACC_0:.+]] = %[[READ_INIT_0]],
+// CHECK-SAME:  %[[ACC_1:.+]] = %[[READ_INIT_1]],
+// CHECK-SAME:  %[[ACC_2:.+]] = %[[READ_INIT_2]],
+// CHECK-SAME:  %[[ACC_3:.+]] = %[[READ_INIT_3]],
+// CHECK-SAME:  %[[ACC_4:.+]] = %[[READ_INIT_4]],
+// CHECK-SAME:  %[[ACC_5:.+]] = %[[READ_INIT_5]],
+// CHECK-SAME:  %[[ACC_6:.+]] = %[[READ_INIT_6]],
+// CHECK-SAME:  %[[ACC_7:.+]] = %[[READ_INIT_7]])
+//      CHECK:    %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME:      [%[[BIDZ]], %[[BOFFSET_Y]], %[[IV0]]] [1, 8, 4]
+//      CHECK:    %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME:      [%[[BIDZ]], %[[IV0]], %[[BOFFSET_X]]] [1, 4, 64]
+//      CHECK:    %[[SUBVIEW_LHS_2:.+]] = subview %[[SUBVIEW_LHS]]
+// CHECK-SAME:      [%[[IIDZ]], %[[IOFFSET_Y]], 0] [1, 8, 4] [1, 1, 1]
+//      CHECK:    %[[SUBVIEW_RHS_2:.+]] = subview %[[SUBVIEW_RHS]]
+// CHECK-SAME:      [%[[IIDZ]], 0, %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
+
+//  CHECK-DAG:    %[[READ_LHS_0:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C0]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_1:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C1]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_2:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C2]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_3:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C3]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_4:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C4]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_5:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C5]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_6:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C6]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_LHS_7:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_LHS_2]][%[[C0]], %[[C7]], %[[C0]]]
+
+//  CHECK-DAG:    %[[READ_RHS_0:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C0]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_RHS_1:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C1]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_RHS_2:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C2]], %[[C0]]]
+//  CHECK-DAG:    %[[READ_RHS_3:.+]] = vector.transfer_read
+// CHECK-SAME:      %[[SUBVIEW_RHS_2]][%[[C0]], %[[C3]], %[[C0]]]
+
+//  CHECK-DAG:    %[[READ_LHS_0_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_0]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_0_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_0]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_0_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_0]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_0_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_0]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_1_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_1]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_1_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_1]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_1_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_1]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_1_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_1]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_2_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_2]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_2_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_2]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_2_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_2]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_2_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_2]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_3_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_3]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_3_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_3]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_3_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_3]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_3_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_3]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_4_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_4]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_4_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_4]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_4_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_4]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_4_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_4]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_5_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_5]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_5_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_5]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_5_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_5]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_5_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_5]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_6_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_6]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_6_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_6]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_6_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_6]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_6_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_6]] {offsets = [0, 0, 3]
+//  CHECK-DAG:    %[[READ_LHS_7_0:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_7]] {offsets = [0, 0, 0]
+//  CHECK-DAG:    %[[READ_LHS_7_1:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_7]] {offsets = [0, 0, 1]
+//  CHECK-DAG:    %[[READ_LHS_7_2:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_7]] {offsets = [0, 0, 2]
+//  CHECK-DAG:    %[[READ_LHS_7_3:.+]] = vector.extract_strided_slice
+// CHECK-SAME:      %[[READ_LHS_7]] {offsets = [0, 0, 3]
+
+//      CHECK:    %[[CONTRACT_0_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_0_0]], %[[READ_RHS_0]], %[[ACC_0]]
+//      CHECK:    %[[CONTRACT_0_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_0_1]], %[[READ_RHS_1]], %[[CONTRACT_0_0]]
+//      CHECK:    %[[CONTRACT_0_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_0_2]], %[[READ_RHS_2]], %[[CONTRACT_0_1]]
+//      CHECK:    %[[CONTRACT_0_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_0_3]], %[[READ_RHS_3]], %[[CONTRACT_0_2]]
+
+//      CHECK:    %[[CONTRACT_1_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_1_0]], %[[READ_RHS_0]], %[[ACC_1]]
+//      CHECK:    %[[CONTRACT_1_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_1_1]], %[[READ_RHS_1]], %[[CONTRACT_1_0]]
+//      CHECK:    %[[CONTRACT_1_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_1_2]], %[[READ_RHS_2]], %[[CONTRACT_1_1]]
+//      CHECK:    %[[CONTRACT_1_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_1_3]], %[[READ_RHS_3]], %[[CONTRACT_1_2]]
+
+//      CHECK:    %[[CONTRACT_2_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_2_0]], %[[READ_RHS_0]], %[[ACC_2]]
+//      CHECK:    %[[CONTRACT_2_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_2_1]], %[[READ_RHS_1]], %[[CONTRACT_2_0]]
+//      CHECK:    %[[CONTRACT_2_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_2_2]], %[[READ_RHS_2]], %[[CONTRACT_2_1]]
+//      CHECK:    %[[CONTRACT_2_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_2_3]], %[[READ_RHS_3]], %[[CONTRACT_2_2]]
+
+//      CHECK:    %[[CONTRACT_3_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_3_0]], %[[READ_RHS_0]], %[[ACC_3]]
+//      CHECK:    %[[CONTRACT_3_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_3_1]], %[[READ_RHS_1]], %[[CONTRACT_3_0]]
+//      CHECK:    %[[CONTRACT_3_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_3_2]], %[[READ_RHS_2]], %[[CONTRACT_3_1]]
+//      CHECK:    %[[CONTRACT_3_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_3_3]], %[[READ_RHS_3]], %[[CONTRACT_3_2]]
+
+//      CHECK:    %[[CONTRACT_4_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_4_0]], %[[READ_RHS_0]], %[[ACC_4]]
+//      CHECK:    %[[CONTRACT_4_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_4_1]], %[[READ_RHS_1]], %[[CONTRACT_4_0]]
+//      CHECK:    %[[CONTRACT_4_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_4_2]], %[[READ_RHS_2]], %[[CONTRACT_4_1]]
+//      CHECK:    %[[CONTRACT_4_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_4_3]], %[[READ_RHS_3]], %[[CONTRACT_4_2]]
+
+//      CHECK:    %[[CONTRACT_5_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_5_0]], %[[READ_RHS_0]], %[[ACC_5]]
+//      CHECK:    %[[CONTRACT_5_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_5_1]], %[[READ_RHS_1]], %[[CONTRACT_5_0]]
+//      CHECK:    %[[CONTRACT_5_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_5_2]], %[[READ_RHS_2]], %[[CONTRACT_5_1]]
+//      CHECK:    %[[CONTRACT_5_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_5_3]], %[[READ_RHS_3]], %[[CONTRACT_5_2]]
+
+//      CHECK:    %[[CONTRACT_6_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_6_0]], %[[READ_RHS_0]], %[[ACC_6]]
+//      CHECK:    %[[CONTRACT_6_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_6_1]], %[[READ_RHS_1]], %[[CONTRACT_6_0]]
+//      CHECK:    %[[CONTRACT_6_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_6_2]], %[[READ_RHS_2]], %[[CONTRACT_6_1]]
+//      CHECK:    %[[CONTRACT_6_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_6_3]], %[[READ_RHS_3]], %[[CONTRACT_6_2]]
+
+//      CHECK:    %[[CONTRACT_7_0:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_7_0]], %[[READ_RHS_0]], %[[ACC_7]]
+//      CHECK:    %[[CONTRACT_7_1:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_7_1]], %[[READ_RHS_1]], %[[CONTRACT_7_0]]
+//      CHECK:    %[[CONTRACT_7_2:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_7_2]], %[[READ_RHS_2]], %[[CONTRACT_7_1]]
+//      CHECK:    %[[CONTRACT_7_3:.+]] = vector.contract
+// CHECK-SAME:      %[[READ_LHS_7_3]], %[[READ_RHS_3]], %[[CONTRACT_7_2]]
+
+//      CHECK:  scf.yield %[[CONTRACT_0_3]], %[[CONTRACT_1_3]],
+// CHECK-SAME:    %[[CONTRACT_2_3]], %[[CONTRACT_3_3]], %[[CONTRACT_4_3]],
+// CHECK-SAME:    %[[CONTRACT_5_3]], %[[CONTRACT_6_3]], %[[CONTRACT_7_3]]
+
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#0, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C0]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#1, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C1]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#2, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C2]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#3, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C3]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#4, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C4]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#5, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C5]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#6, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C6]], %[[C0]]]
+//  CHECK-DAG:  vector.transfer_write %[[FOR_RES]]#7, %[[SUBVIEW_RESULT_2]][%[[C0]], %[[C7]], %[[C0]]]
+
+// -----
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+      [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+       StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+       UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+       GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+       GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+       VariablePointersStorageBuffer],
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+       SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+      ARM:IntegratedGPU,
+      {max_compute_shared_memory_size = 32768 : i32,
+       max_compute_workgroup_invocations = 512 : i32,
+       max_compute_workgroup_size = dense<512> : vector<3xi32>,
+       subgroup_size = 16 : i32}>} {
+  func @batch_matmul_fused_fillop()
+    attributes {vkspv.num_workgroups_fn = @batch_matmul_fused_fillop__num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %arg0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
+    %arg1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
+    %ret0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
+    linalg.fill(%ret0, %cst) : memref<4x1024x1024xf32>, f32
+    linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+    return
+  }
+  func @batch_matmul_fused_fillop__num_workgroups__
+    (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+     !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+
+//    CHECK-LABEL: func @batch_matmul_fused_fillop
+//  CHECK-COUNT-8:   vector.transfer_write
+//  CHECK-COUNT-8:   vector.transfer_read
+//          CHECK:   %[[FOR_RES:.+]]:8 = scf.for
+// CHECK-COUNT-12:     vector.transfer_read
+// CHECK-COUNT-32:     vector.contract
+//          CHECK:     scf.yield
+//  CHECK-COUNT-8:   vector.transfer_write %[[FOR_RES]]
+//          CHECK:   return
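
The CHECK-COUNT constants in this fused test follow from the tiling exercised in the first test above: each invocation owns a 1x8x4 output tile and steps the reduction dimension by 4, so the vectorized fill issues 8 transfer writes, the initialized accumulator is read back as 8 vectors, each loop iteration performs 8 LHS + 4 RHS transfer reads (12 total) and 8 x 4 = 32 vector.contract ops, and the loop results are stored with a final 8 transfer writes.
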
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 593a693..75cd37d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -1,38 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse %s | IreeFileCheck %s
-
-// Test to check that convolution with padding is not tiled.
-module attributes {
-  spv.target_env =
-    #spv.target_env<#spv.vce<v1.3,
-    [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-    {max_compute_workgroup_invocations = 128 : i32,
-     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @conv_padding() {
-    %0 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0} : memref<?x?x?x?xf32>
-    %1 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1} : memref<?x?x?x?xf32>
-    %2 = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0} : memref<?x?x?x?xf32>
-    linalg.conv(%0, %1, %2)
-      {dilations = [1, 1],
-       padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} :
-      memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-    return
-  }
-  hal.interface @legacy_io attributes {sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
-  }
-}
-//       CHECK: func @conv_padding()
-//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0}
-//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1}
-//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0}
-//       CHECK:   linalg.conv
-//  CHECK-SAME:     %[[ARG0]]
-//  CHECK-SAME:     %[[ARG1]]
-//  CHECK-SAME:     %[[RET0]]
-
-// -----
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse -iree-spirv-enable-vectorization -canonicalize -cse %s | IreeFileCheck %s
 
 module attributes {
   spv.target_env =
@@ -63,26 +29,24 @@
   }
 }
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 32)>
 //       CHECK: func @conv_no_padding()
 //  CHECK-SAME:   hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
 //  CHECK-SAME:   local_size = dense<[32, 4, 1]>
-//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //   CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
 //       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]]
+//       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
+//       CHECK:   %[[SV_ARG1:.+]] = subview %[[ARG1]]
 //  CHECK-SAME:     [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
-//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]]
-//  CHECK-SAME:     [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
+//       CHECK:   %[[SV_RET0:.+]] = subview %[[RET0]]
+//  CHECK-SAME:     [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
 //       CHECK:   linalg.conv
-//  CHECK-SAME:     %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+//  CHECK-SAME:     %[[ARG0]], %[[SV_ARG1]], %[[SV_RET0]]
 //  CHECK-SAME:     "workgroup"
 //       CHECK: func private @[[NUM_WORKGROUPS_FN]]
 //   CHECK-DAG:   %[[C0:.+]] = constant 0
@@ -98,10 +62,10 @@
 //   CHECK-DAG:   %[[DIM0:.+]] = dim %[[ARG1]], %[[C0]]
 //   CHECK-DAG:   %[[DIM1:.+]] = dim %[[RET0]], %[[C1]]
 //   CHECK-DAG:   %[[DIM2:.+]] = dim %[[RET0]], %[[C2]]
-//       CHECK:   %[[T0:.+]] = addi %[[DIM2]], %[[C31]]
-//   CHECK-DAG:   %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
 //       CHECK:   %[[T1:.+]] = addi %[[DIM1]], %[[C3]]
 //   CHECK-DAG:   %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
+//       CHECK:   %[[T0:.+]] = addi %[[DIM2]], %[[C31]]
+//   CHECK-DAG:   %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
 //       CHECK:   return %[[NBX]], %[[NBY]], %[[DIM0]]
 
 
@@ -140,24 +104,24 @@
 //       CHECK: func @matmul()
 //  CHECK-SAME:   hal.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
 //  CHECK-SAME:   local_size = dense<[16, 8, 1]>
-//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //   CHECK-NOT:   scf.parallel
 //   CHECK-NOT:   scf.for
 //       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+//       CHECK:   %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
 //       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
-//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+//       CHECK:   %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+//       CHECK:   %[[SV_RET0:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
 //       CHECK:   linalg.matmul
 //  CHECK-SAME:     "workgroup"
-//  CHECK-SAME:     ins(%[[VIEW0]], %[[VIEW1]]
-//  CHECK-SAME:     outs(%[[VIEW2]]
+//  CHECK-SAME:     ins(%[[SV_ARG0]], %[[SV_ARG1]]
+//  CHECK-SAME:       )
+//  CHECK-SAME:     outs(%[[SV_RET0]]
+//  CHECK-SAME:       )
 //       CHECK: func private @[[NUM_WORKGROUPS_FN]]
 //   CHECK-DAG:   %[[C8:.+]] = constant 8 : index
 //   CHECK-DAG:   %[[C7:.+]] = constant 7 : index
@@ -165,12 +129,14 @@
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
 //   CHECK-DAG:   %[[C16:.+]] = constant 16 : index
 //   CHECK-DAG:   %[[C15:.+]] = constant 15 : index
-//       CHECK:   %[[DIM0:.+]] = dim %{{.*}}, %[[C0]]
-//       CHECK:   %[[DIM1:.+]] = dim %{{.*}}, %[[C1]]
-//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
-//       CHECK:   %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@arg1
+//       CHECK:   %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[DIM1:.+]] = dim %[[ARG1]], %[[C1]]
 //       CHECK:   %[[T2:.+]] = addi %[[DIM0]], %[[C7]]
 //       CHECK:   %[[T3:.+]] = divi_signed %[[T2]], %[[C8]]
+//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
+//       CHECK:   %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
 //       CHECK:   return %[[T1]], %[[T3]], %[[C1]]
 
 // -----
@@ -215,12 +181,10 @@
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], %[[LBX]]]
-//       CHECK:   %[[LBY2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY2]], %[[LBX2]]]
+//       CHECK:   %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], %[[LBX]]]
+//       CHECK:   %[[SV_RET0:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
 //       CHECK:   linalg.pooling_max
-//  CHECK-SAME:     %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
+//  CHECK-SAME:     %[[SV_ARG0]], %[[ARG1]], %[[SV_RET0]]
 //  CHECK-SAME:     "workgroup"
 //       CHECK: func private @[[NUM_WORKGROUPS_FN]]
 //   CHECK-DAG:   %[[C0:.+]] = constant 0
@@ -232,10 +196,10 @@
 //   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@ret0
 //   CHECK-DAG:   %[[DIM0:.+]] = dim %[[RET0]], %[[C0]]
 //   CHECK-DAG:   %[[DIM1:.+]] = dim %[[RET0]], %[[C1]]
-//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
-//   CHECK-DAG:   %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
 //       CHECK:   %[[T1:.+]] = addi %[[DIM0]], %[[C3]]
 //   CHECK-DAG:   %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
+//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
+//   CHECK-DAG:   %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
 //       CHECK:   return %[[NBX]], %[[NBY]], %[[C1]]
 
 // -----
@@ -281,12 +245,10 @@
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]][0, %[[LBY]], %[[LBX]], 0]
-//       CHECK:   %[[LBY2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX2:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]][0, %[[LBY2]], %[[LBX2]], 0]
+//       CHECK:   %[[SV_ARG0:.+]] = subview %[[ARG0]][0, %[[LBY]], %[[LBX]], 0]
+//       CHECK:   %[[SV_RET0:.+]] = subview %[[RET0]][0, %[[LBY]], %[[LBX]], 0]
 //       CHECK:   linalg.pooling_max
-//  CHECK-SAME:     %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
+//  CHECK-SAME:     %[[SV_ARG0]], %[[ARG1]], %[[SV_RET0]]
 //  CHECK-SAME:     "workgroup"
 //       CHECK: func private @[[NUM_WORKGROUPS_FN]]
 //   CHECK-DAG:   %[[C1:.+]] = constant 1
@@ -297,14 +259,15 @@
 //   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.+}} {binding = @legacy_io::@ret0
 //   CHECK-DAG:   %[[DIM0:.+]] = dim %[[RET0]], %[[C1]]
 //   CHECK-DAG:   %[[DIM1:.+]] = dim %[[RET0]], %[[C2]]
-//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
-//   CHECK-DAG:   %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
 //       CHECK:   %[[T1:.+]] = addi %[[DIM0]], %[[C3]]
 //   CHECK-DAG:   %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
+//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C31]]
+//   CHECK-DAG:   %[[NBX:.+]] = divi_signed %[[T0]], %[[C32]]
 //       CHECK:   return %[[NBX]], %[[NBY]], %[[C1]]
 
 // -----
 
+
 module attributes {
   spv.target_env =
     #spv.target_env<#spv.vce<v1.3,
@@ -324,10 +287,9 @@
       outs(%2 : memref<?x?xf32>)
     return
   }
-  func @matmul_fusion__num_workgroups__
+  func private @matmul_fusion__num_workgroups__
     (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
      !shapex.ranked_shape<[?,?]>) -> (index, index, index)
-    attributes {sym_visibility = "private"}
   hal.interface @legacy_io attributes {sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -338,28 +300,41 @@
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
 //   CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
 //       CHECK: func @matmul_fusion()
+//  CHECK-SAME:   hal.num_workgroups_fn = @[[NWGFN:[a-zA-Z0-9_]+]]
 //  CHECK-SAME:   local_size = dense<[16, 8, 1]>
-//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.+}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.+}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //   CHECK-NOT:   scf.parallel
 //   CHECK-NOT:   scf.for
 //       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+//       CHECK:   %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
 //       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
-//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
-//       CHECK:   %[[VIEW3:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
-//       CHECK:   linalg.fill(%[[VIEW3]], %{{.+}})
+//       CHECK:   %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+//       CHECK:   %[[SV_RET0_1:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
+//       CHECK:   %[[SV_RET0_2:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
+//       CHECK:   linalg.fill(%[[SV_RET0_2]], %{{.+}})
 //  CHECK-SAME:     "workgroup"
 //       CHECK:   linalg.matmul
 //  CHECK-SAME:     "workgroup"
-//  CHECK-SAME:     ins(%[[VIEW0]], %[[VIEW1]]
-//  CHECK-SAME:     outs(%[[VIEW2]]
+//  CHECK-SAME:     ins(%[[SV_ARG0]], %[[SV_ARG1]]
+//  CHECK-SAME:     outs(%[[SV_RET0_1]]
+
+//       CHECK: func private @[[NWGFN]]
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.+}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.+}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
+//       CHECK:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+//       CHECK:   %[[N:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:   %[[WGY_N:.+]] = addi %[[M]], %{{.+}}
+//       CHECK:   %[[WGY:.+]] = divi_signed %[[WGY_N]], %{{.+}}
+//       CHECK:   %[[WGX_N:.+]] = addi %[[N]], %{{.+}}
+//       CHECK:   %[[WGX:.+]] = divi_signed %[[WGX_N]], %{{.+}}
+//       CHECK:   return %[[WGX]], %[[WGY]], %[[C1]]
 
 // -----
 
@@ -369,7 +344,9 @@
     [Shader], [SPV_KHR_storage_buffer_storage_class]>,
     {max_compute_workgroup_invocations = 128 : i32,
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @conv_no_padding_fusion() {
+  func @conv_no_padding_fusion()
+    attributes {
+      hal.num_workgroups_fn = @conv_no_padding_fusion__num_workgroups__} {
     %0 = iree.placeholder for "interace buffer"
       {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
     %1 = iree.placeholder for "interace buffer"
@@ -382,6 +359,9 @@
       memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
     return
   }
+  func private @conv_no_padding_fusion__num_workgroups__
+    (!shapex.ranked_shape<[?,?,?,?]>, !shapex.ranked_shape<[?,?,?,?]>,
+     !shapex.ranked_shape<[?,?,?,?]>) -> (index, index, index)
   hal.interface @legacy_io attributes {sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -391,25 +371,185 @@
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
 //       CHECK: func @conv_no_padding_fusion()
+//  CHECK-SAME:   hal.num_workgroups_fn = @[[NWGFN:[a-zA-Z0-9_]+]]
 //  CHECK-SAME:   local_size = dense<[32, 4, 1]>
-//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
-//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
-//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //   CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
 //       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 //       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]]
+//       CHECK:   %[[SV_ARG1:.+]] = subview %[[ARG1]]
 //  CHECK-SAME:     [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
-//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
-//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]]
-//  CHECK-SAME:     [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
-//       CHECK:   %[[VIEW3:.+]] = subview %[[RET0]]
-//  CHECK-SAME:     [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
-//       CHECK:   linalg.fill(%[[VIEW3]], %{{.*}})
+//       CHECK:   %[[SV_RET0:.+]] = subview %[[RET0]]
+//  CHECK-SAME:     [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
+//       CHECK:   linalg.fill(%[[SV_RET0]], %{{.*}})
 //  CHECK-SAME:     "workgroup"
 //       CHECK:   linalg.conv
-//  CHECK-SAME:     %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+//  CHECK-SAME:     %[[ARG0]], %[[SV_ARG1]], %[[SV_RET0]]
 //  CHECK-SAME:     "workgroup"
+
+//       CHECK:   func private @[[NWGFN]]
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
+//       CHECK:   %[[N:.+]] = dim %[[ARG1]], %[[C0]]
+//       CHECK:   %[[R:.+]] = dim %[[RET0]], %[[C1]]
+//       CHECK:   %[[S:.+]] = dim %[[RET0]], %[[C2]]
+//       CHECK:   %[[WGY_N:.+]] = addi %[[R]], %{{.+}}
+//       CHECK:   %[[WGY:.+]] = divi_signed %[[WGY_N]], %{{.+}}
+//       CHECK:   %[[WGX_N:.+]] = addi %[[S]], %{{.+}}
+//       CHECK:   %[[WGX:.+]] = divi_signed %[[WGX_N]], %{{.+}}
+//       CHECK:   return %[[WGX]], %[[WGY]], %[[N]]
+
+// -----
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+    [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @three_op_fusion()
+    attributes {
+      hal.num_workgroups_fn = @three_op_fusion__num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
+      : memref<?x?xf32>
+    %1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_index = 1 : i32}
+      : memref<?x?xf32>
+    %d0 = dim %0, %c0 : memref<?x?xf32>
+    %d1 = dim %1, %c1 : memref<?x?xf32>
+    %2 = alloc(%d0, %d1) : memref<?x?xf32>
+    %3 =  iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg2, operand_result_index = 2 : i32}
+      : memref<?xf32>
+    %4 =  iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_index = 3 : i32}
+      : memref<?x?xf32>
+    linalg.fill(%2, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%0, %1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%2 : memref<?x?xf32>)
+    linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%2, %3 : memref<?x?xf32>, memref<?xf32>)
+      outs(%4 : memref<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32) :
+        %5 = addf %arg0, %arg1 : f32
+        linalg.yield %5 : f32
+      }
+    return
+  }
+  func private @three_op_fusion__num_workgroups__
+    (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+     !shapex.ranked_shape<[?]>, !shapex.ranked_shape<[?,?]>)
+    -> (index, index, index)
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write"
+  }
+}
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 16)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (16, s1 - s0 * 16)>
+//       CHECK: func @three_op_fusion
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[ALLOC:.+]] = alloc() : memref<8x16xf32, 3>
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} @legacy_io::@arg1
+//   CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[N:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[ARG2:.+]] = iree.placeholder {{.*}} @legacy_io::@arg2
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} @legacy_io::@ret0
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//   CHECK-NOT:   scf.parallel
+//   CHECK-NOT:   scf.for
+//       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[TILE_M:.+]] = affine.min #[[MAP1]]()[%[[BIDY]], %[[M]]]
+//       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
+//       CHECK:   %[[TILE_N:.+]] = affine.min #[[MAP3]]()[%[[BIDX]], %[[N]]]
+//       CHECK:   %[[N_2:.+]] = dim %[[ARG2]], %[[C0]]
+//       CHECK:   %[[TILE_N_2:.+]] = affine.min #[[MAP3]]()[%[[BIDX]], %[[N_2]]]
+//       CHECK:   %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[LBX]]] [%[[TILE_N_2]]]
+//       CHECK:   %[[M_2:.+]] = dim %[[RET0]], %[[C0]]
+//       CHECK:   %[[TILE_M_2:.+]] = affine.min #[[MAP1]]()[%[[BIDY]], %[[M_2]]]
+//       CHECK:   %[[N_3:.+]] = dim %[[RET0]], %[[C1]]
+//       CHECK:   %[[TILE_N_3:.+]] = affine.min #[[MAP3]]()[%[[BIDX]], %[[N_3]]]
+//       CHECK:   %[[SV_RET0:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]
+//  CHECK-SAME:     [%[[TILE_M_2]], %[[TILE_N_3]]]
+//       CHECK:   %[[K:.+]] = dim %[[ARG0]], %[[C1]]
+//       CHECK:   %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+//  CHECK-SAME:     [%[[TILE_M]], %[[K]]]
+//       CHECK:   %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+//  CHECK-SAME:     [%[[K]], %[[TILE_N]]]
+//       CHECK:   %[[SV_ALLOC:.+]] = subview %[[ALLOC]][0, 0]
+//  CHECK-SAME:     [%[[TILE_M]], %[[TILE_N]]]
+//       CHECK:   linalg.fill(%[[SV_ALLOC]], %{{.+}})
+//  CHECK-SAME:     "workgroup"
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     "workgroup"
+//  CHECK-SAME:     ins(%[[SV_ARG0]], %[[SV_ARG1]]
+//  CHECK-SAME:       )
+//  CHECK-SAME:     outs(%[[SV_ALLOC]]
+//  CHECK-SAME:       )
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     ins(%[[SV_ALLOC]], %[[SV_ARG2]]
+//  CHECK-SAME:       )
+//  CHECK-SAME:     outs(%[[SV_RET0]]
+//  CHECK-SAME:       )
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, {max_compute_shared_memory_size = 32768 : i32, max_compute_workgroup_invocations = 512 : i32, max_compute_workgroup_size = dense<512> : vector<3xi32>, subgroup_size = 16 : i32}>
+}  {
+  func @conv_tiled_and_vectorized() attributes {hal.num_workgroups_fn = @get_num_workgroups} {
+    %cst = constant 0.000000e+00 : f32
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x112x112x32xf32>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x225x225x16xf32>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x16x32xf32>
+    linalg.fill(%0, %cst) : memref<1x112x112x32xf32>, f32
+    linalg.conv(%2, %1, %0) {dilations = [1, 1], strides = [2, 2]} : memref<3x3x16x32xf32>, memref<1x225x225x16xf32>, memref<1x112x112x32xf32>
+    return
+  }
+
+  func private @get_num_workgroups(!shapex.ranked_shape<[1,225,225,16]>, !shapex.ranked_shape<[3,3,16,32]>, !shapex.ranked_shape<[1,112,112,32]>) -> (index, index, index)
+
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// CHECK-LABEL: func @conv_tiled_and_vectorized()
+
+// CHECK-COUNT-4: vector.transfer_read
+
+// Check tiling loops along filter height/width and input channel.
+// CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
+// CHECK:   scf.for %{{.*}} = %c0 to %c3 step %c1
+// CHECK:     scf.for %{{.*}} = %c0 to %c16 step %c4
+
+// CHECK-COUNT-16: vector.contract
+
+// CHECK-COUNT-3: scf.yield
+// CHECK-COUNT-4: vector.transfer_write
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
index dc4f5c9..6c9a4e3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -46,4 +46,53 @@
 // CHECK-COUNT-32:   spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
 // CHECK-COUNT-8:   spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
 
+// -----
 
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+      [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+       StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+       UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+       GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+       GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+       VariablePointersStorageBuffer],
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+       SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+      ARM:IntegratedGPU,
+      {max_compute_shared_memory_size = 32768 : i32,
+       max_compute_workgroup_invocations = 512 : i32,
+       max_compute_workgroup_size = dense<512> : vector<3xi32>,
+       subgroup_size = 16 : i32}>} {
+  func @matmul_fill_fused()
+    attributes {vkspv.num_workgroups_fn = @matmul_fill_fused__num_workgroups__} {
+    %arg0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf32>
+    %arg1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf32>
+    %ret0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf32>
+    %cst = constant 0.000000e+00 : f32
+    linalg.fill(%ret0, %cst) : memref<4096x4096xf32>, f32
+    linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf32>, memref<4096x4096xf32>)
+                 outs(%ret0 : memref<4096x4096xf32>)
+    return
+  }
+  func @matmul_fill_fused__num_workgroups__
+    (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+     !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+
+//    CHECK-LABEL: spv.func @matmul_fill_fused
+//      CHECK-NOT:   spv.Store "StorageBuffer"
+//      CHECK-NOT:   spv.Load "StorageBuffer"
+//          CHECK:   spv.loop
+// CHECK-COUNT-12:   spv.Load "StorageBuffer" %{{.*}} : vector<4xf32>
+// CHECK-COUNT-32:   spv.FMul %{{.*}}, %{{.*}} : vector<4xf32>
+//  CHECK-COUNT-8:   spv.Store "StorageBuffer" %{{.*}}, %{{.*}} : vector<4xf32>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 7f7ad16..52779a9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,9 +1,11 @@
 // RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
 
 module {
-  // CHECK: func @kernel_fusable_fill_conv_ops
-  // CHECK:   linalg.fill
-  // CHECK:   linalg.conv
+  //     CHECK: func @kernel_fusable_fill_conv_ops
+  //     CHECK:   linalg.fill
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.conv
+  //     CHECK:   return
 
   func @kernel_fusable_fill_conv_ops()
   attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
@@ -20,10 +22,9 @@
     linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
     return
   }
-  func @kernel_fill_conv_ops_num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
-                                              !shapex.ranked_shape<[3,3,512,1]>,
-                                              !shapex.ranked_shape<[?,1,1,512]>)
-                                             -> (index, index, index)
+  func @kernel_fusable_fill_conv_ops_num_workgroups__
+    (!shapex.ranked_shape<[?,2,2,512]>, !shapex.ranked_shape<[3,3,512,1]>,
+     !shapex.ranked_shape<[?,1,1,512]>) -> (index, index, index)
   attributes {sym_visibility = "private"}
   hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -35,9 +36,11 @@
 // -----
 
 module {
-  // CHECK: func @kernel_fusable_fill_matmul_ops
-  // CHECK:   linalg.fill
-  // CHECK:   linalg.matmul
+  //     CHECK: func @kernel_fusable_fill_matmul_ops
+  //     CHECK:   linalg.fill
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.matmul
+  //     CHECK:   return
 
   func @kernel_fusable_fill_matmul_ops()
   attributes {hal.num_workgroups_fn = @kernel_fusable_fill_matmul_ops_num_workgroups__} {
@@ -73,9 +76,11 @@
 // -----
 
 module {
-  // CHECK: func @kernel_fusable_pooling()
-  // CHECK:   linalg.fill
-  // CHECK:   linalg.pooling
+  //     CHECK: func @kernel_fusable_pooling()
+  //     CHECK:   linalg.fill
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.pooling
+  //     CHECK:   return
   func @kernel_fusable_pooling() attributes {hal.num_workgroups_fn = @kernel_fusable_pooling__num_workgroups__} {
     %cst = constant 0.000000e+00 : f32
     %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
@@ -221,7 +226,6 @@
   }
 }
 
-
 // -----
 
 // Nothing to do if there is just one Linalg op.
@@ -246,6 +250,8 @@
   }
 }
 
+
+
 // -----
 
 // Do not split when Linalg and non-Linalg ops are interleaving each other.
@@ -270,7 +276,6 @@
 }
 
 // -----
-
 #map0 = affine_map<(d0, d1) -> (d0 * 12 + d1 + 53)>
 
 module {
@@ -417,3 +422,62 @@
 // CHECK-NEXT:   linalg.copy
 //  CHECK-NOT:   linalg
 //      CHECK:   return
+
+// -----
+
+module {
+  //     CHECK: func @kernel_fusable_fill_matmul_generic_ops
+  //     CHECK:   linalg.fill
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.matmul
+  // CHECK-NOT:   return
+  //     CHECK:   linalg.generic
+  //     CHECK:   return
+
+  func @kernel_fusable_fill_matmul_generic_ops()
+  attributes {hal.num_workgroups_fn = @kernel_fusable_fill_matmul_generic_ops_num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %dimM = hal.interface.load.constant offset = 0 : index
+    %dimN = hal.interface.load.constant offset = 1 : index
+    %shape1 = shapex.make_ranked_shape %dimM : (index) -> !shapex.ranked_shape<[?,512]>
+    %shape2 = shapex.make_ranked_shape %dimN : (index) -> !shapex.ranked_shape<[512,?]>
+    %shape3 = shapex.make_ranked_shape %dimM, %dimN : (index, index) -> !shapex.ranked_shape<[?,?]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x512xf32>
+    %ts0 = shapex.tie_shape %0, %shape1 : memref<?x512xf32>, !shapex.ranked_shape<[?,512]>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<512x?xf32>
+    %ts1 = shapex.tie_shape %1, %shape2 : memref<512x?xf32>, !shapex.ranked_shape<[512, ?]>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x?xf32>
+    %ts2 = shapex.tie_shape %2, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?, ?]>
+    %3 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
+    %ts3 = shapex.tie_shape %3, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+    %4 = alloc(%dimM, %dimN) : memref<?x?xf32>
+    %ts4 = shapex.tie_shape %4, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+    linalg.fill(%ts4, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%ts0, %ts1 : memref<?x512xf32>, memref<512x?xf32>)
+                  outs(%ts4 : memref<?x?xf32>)
+    linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%ts2, %ts4 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%ts3 : memref<?x?xf32>) {
+      ^bb0(%arg0 : f32, %arg1 : f32, %arg2 : f32):
+        %5 = addf %arg0, %arg1 : f32
+        linalg.yield %5 : f32
+    }
+    return
+  }
+  func @kernel_fusable_fill_matmul_generic_ops_num_workgroups__
+    (!shapex.ranked_shape<[?,512]>, !shapex.ranked_shape<[512,?]>,
+     !shapex.ranked_shape<[?,?]>,
+     !shapex.ranked_shape<[?,?]>)
+    -> (index, index, index)
+  attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg2, set=0, binding=2, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
+  }
+}
diff --git a/iree/compiler/Conversion/LinalgToVector/BUILD b/iree/compiler/Conversion/LinalgToVector/BUILD
index 95676fb..53eaa8d 100644
--- a/iree/compiler/Conversion/LinalgToVector/BUILD
+++ b/iree/compiler/Conversion/LinalgToVector/BUILD
@@ -22,6 +22,7 @@
     name = "LinalgToVector",
     srcs = [
         "LoadStoreVectorization.cpp",
+        "VectorizeConv.cpp",
     ],
     hdrs = [
         "Passes.h",
@@ -32,6 +33,7 @@
         "//iree/compiler/Dialect/Shape/IR",
         "//iree/compiler/Dialect/Shape/Transforms",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
diff --git a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
index b292c28..1e56d8c 100644
--- a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
@@ -21,6 +21,7 @@
     "Passes.h"
   SRCS
     "LoadStoreVectorization.cpp"
+    "VectorizeConv.cpp"
   DEPS
     LLVMSupport
     MLIRIR
diff --git a/iree/compiler/Conversion/LinalgToVector/Passes.h b/iree/compiler/Conversion/LinalgToVector/Passes.h
index 98f07bd..4f7340b 100644
--- a/iree/compiler/Conversion/LinalgToVector/Passes.h
+++ b/iree/compiler/Conversion/LinalgToVector/Passes.h
@@ -23,6 +23,16 @@
 /// Creates a pass to vectorize Linalg operations.
 std::unique_ptr<Pass> createLoadStoreVectorizationPass();
 
+/// Creates a pass to vectorize a very specific form of linalg.conv ops.
+std::unique_ptr<Pass> createVectorizeLinalgConvPass();
+
+/// Populates `patterns` with a very specific pattern that vectorizes a
+/// linalg.conv op for a single thread. The linalg.conv op should compute on
+/// static-sized subviews. To match, the output shape must be 1x1xWoxCo, where
+/// Co is a multiple of 4, and the filter shape must be 1x1x4xCo.
+void populateVectorizeLinalgConvPatterns(MLIRContext *context,
+                                         OwningRewritePatternList &patterns);
+
 }  // namespace iree_compiler
 }  // namespace mlir
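
The two entry points above mirror the existing load/store vectorization ones: a standalone pass for testing plus a pattern populator for composing into larger pipelines. A minimal sketch of scheduling the new pass, assuming a pass manager already nested at the function level (the helper name below is hypothetical, not part of this change):

    // Sketch only: assumes |pm| is already nested on the function level,
    // matching how these OperationPass<FuncOp> passes are scheduled.
    #include "iree/compiler/Conversion/LinalgToVector/Passes.h"
    #include "mlir/Pass/PassManager.h"

    static void addLinalgToVectorPasses(mlir::OpPassManager &pm) {
      // Vectorize the 1x1-filter, single-row linalg.conv form first, then
      // vectorize the remaining Linalg load/store ops.
      pm.addPass(mlir::iree_compiler::createVectorizeLinalgConvPass());
      pm.addPass(mlir::iree_compiler::createLoadStoreVectorizationPass());
    }
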
 
diff --git a/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
new file mode 100644
index 0000000..6d4bca5
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
@@ -0,0 +1,234 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-vectorize-conv"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Vectorizes linalg.conv for a single GPU invocation. The linalg.conv op is
+/// therefore expected to be in a very specific form; other patterns are
+/// expected to tile and distribute larger convolutions into this
+/// per-invocation form.
+///
+/// Specifically, the linalg.conv op should satisfy:
+/// - Filter: HfWfCiCo format
+/// - Input : NHiWiCi format
+/// - Output: NHoWoCo format
+/// - For output:
+///   - N must be 1.
+///   - Ho must be 1.
+///   - Co must be a multiple of 4.
+/// - For filter:
+///   - Hf must be 1.
+///   - Wf must be 1.
+///   - Ci must be 4.
+/// - No dilation.
+/// - No padding.
+///
+/// The output channel count is required to be a multiple of 4 so that we can
+/// process it with load4/store4, which is native to GPUs. The input channel
+/// count requirement exists for the same reason.
+struct VectorizeLinalgConv : OpRewritePattern<linalg::ConvOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::ConvOp convOp,
+                                PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n");
+
+    // This pattern does not handle convolutions with dilation.
+    if (auto dilations = convOp.dilations()) {
+      auto values = dilations->getAsValueRange<IntegerAttr>();
+      if (llvm::any_of(values, [](const APInt &value) {
+            return value.getSExtValue() != 1;
+          })) {
+        return failure();
+      }
+    }
+
+    auto filterViewOp = convOp.filter().getDefiningOp<SubViewOp>();
+    auto inputViewOp = convOp.input().getDefiningOp<SubViewOp>();
+    auto outputViewOp = convOp.output().getDefiningOp<SubViewOp>();
+    if (!filterViewOp || !inputViewOp || !outputViewOp) return failure();
+
+    // The filter/input/output view should have static sizes to vectorize.
+    if (!llvm::empty(filterViewOp.getDynamicSizes()) ||
+        !llvm::empty(inputViewOp.getDynamicSizes()) ||
+        !llvm::empty(outputViewOp.getDynamicSizes())) {
+      return failure();
+    }
+
+    // The output batch and height dimensions should be 1. If not, other
+    // patterns can generate parallel loops to distribute them.
+    if (outputViewOp.getStaticSize(0) != 1 ||
+        outputViewOp.getStaticSize(1) != 1) {
+      return failure();
+    }
+
+    // We additionally expect the filter height/width dimensions to both be 1
+    // to simplify vectorization. Other patterns can generate loops to create
+    // 1x1 filter subviews.
+    if (filterViewOp.getStaticSize(0) != 1 ||
+        filterViewOp.getStaticSize(1) != 1) {
+      return failure();
+    }
+
+    int64_t numInputChannels = filterViewOp.getStaticSize(2);
+    int64_t numOutputChannels = filterViewOp.getStaticSize(3);
+    if (numInputChannels != 4 || numOutputChannels % 4 != 0) return failure();
+
+    int64_t numOutputWidths = outputViewOp.getStaticSize(2);
+    int64_t widthStride = convOp.getStride(1);
+
+    // This invocation handles a batch of (numOutputWidths * numOutputChannels).
+    LLVM_DEBUG({
+      llvm::dbgs() << "# output width: " << numOutputWidths << "\n";
+      llvm::dbgs() << "# output channels: " << numOutputChannels << "\n";
+      llvm::dbgs() << "width stride: " << widthStride << "\n";
+    });
+
+    MLIRContext *context = convOp.getContext();
+    Location loc = convOp.getLoc();
+
+    Type elementType = filterViewOp.getType().getElementType();
+    auto filterVectorType =
+        VectorType::get({numInputChannels, numOutputChannels}, elementType);
+    auto vector1x4Type = VectorType::get({1, 4}, elementType);
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+
+    // Load the entire filter subview.
+    SmallVector<Value, 4> filterIndices(4, zero);
+    Value wholeFilter = rewriter.create<vector::TransferReadOp>(
+        loc, filterVectorType, filterViewOp, filterIndices);
+
+    // Get filter slices so that we can later use them for dot products with
+    // the input. Both the height and width dimensions are 1, so we just need
+    // to loop over the input and output channel dimensions.
+    SmallVector<SmallVector<Value, 1>, 4> filterVectors(numInputChannels);
+    for (int ic = 0; ic < numInputChannels; ++ic) {
+      auto &thisInputChannel = filterVectors[ic];
+      thisInputChannel.reserve(numOutputChannels / 4);
+      for (int oc = 0; oc < numOutputChannels / 4; ++oc) {
+        Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, wholeFilter, /*offsets=*/ArrayRef<int64_t>({ic, oc * 4}),
+            /*sizes=*/ArrayRef<int64_t>({1, 4}),
+            /*strides=*/ArrayRef<int64_t>({1, 1}));
+        thisInputChannel.push_back(slice);
+      }
+    }
+
+    // Build indexing maps for a later vector contraction op.
+    AffineExpr dim0 = getAffineDimExpr(0, context);  // M
+    AffineExpr dim1 = getAffineDimExpr(1, context);  // N
+    AffineExpr dim2 = getAffineDimExpr(2, context);  // K
+    auto map02 = AffineMap::get(3, 0, {dim0, dim2}, context);
+    auto map21 = AffineMap::get(3, 0, {dim2, dim1}, context);
+    auto map01 = AffineMap::get(3, 0, {dim0, dim1}, context);
+    ArrayAttr indexingMaps =
+        rewriter.getAffineMapArrayAttr({map02, map21, map01});
+
+    // Also build iterator types for the vector contraction op.
+    ArrayAttr iterators = rewriter.getStrArrayAttr(
+        {getParallelIteratorTypeName(), getParallelIteratorTypeName(),
+         getReductionIteratorTypeName()});
+
+    // Compute the (numOutputWidths * numOutputChannels) batch. Each output
+    // vector only accumulates numInputChannels values along the reduction
+    // dimension, so read in the existing result from the output, compose a
+    // chain of numInputChannels vector dot operations, and then write the
+    // result back out.
+    for (int ow = 0; ow < numOutputWidths; ++ow) {
+      // Read in the input vector for these 4 input channels as a batch. The
+      // input vector is used for computing all output channels so the data
+      // can be reused.
+      SmallVector<Value, 4> inputIndices(4, zero);
+      inputIndices[2] = rewriter.create<ConstantIndexOp>(loc, ow * widthStride);
+      Value inputVector = rewriter.create<vector::TransferReadOp>(
+          loc, vector1x4Type, inputViewOp, inputIndices);
+
+      for (int oc = 0; oc < numOutputChannels / 4; ++oc) {
+        // Read in the initial value for this output vector.
+        SmallVector<Value, 4> outputIndices(4, zero);
+        outputIndices[2] = rewriter.create<ConstantIndexOp>(loc, ow);
+        outputIndices[3] = rewriter.create<ConstantIndexOp>(loc, oc * 4);
+        Value outputVector = rewriter.create<vector::TransferReadOp>(
+            loc, vector1x4Type, outputViewOp, outputIndices);
+
+        // Perform a chain of dot products and accumulations.
+        for (int i = 0; i < numInputChannels; ++i) {
+          auto inputSlice = rewriter.create<vector::ExtractStridedSliceOp>(
+              loc, inputVector, /*offsets=*/ArrayRef<int64_t>({0, i}),
+              /*sizes=*/ArrayRef<int64_t>({1, 1}),
+              /*strides=*/ArrayRef<int64_t>({1, 1}));
+          outputVector = rewriter.create<vector::ContractionOp>(
+              loc, inputSlice, filterVectors[i][oc], outputVector, indexingMaps,
+              iterators);
+        }
+
+        // Write out the output vector.
+        rewriter.create<vector::TransferWriteOp>(loc, outputVector,
+                                                 outputViewOp, outputIndices);
+      }
+    }
+
+    rewriter.eraseOp(convOp);
+    return success();
+  }
+};
+
+struct VectorizeLinalgConvPass
+    : public PassWrapper<VectorizeLinalgConvPass, OperationPass<FuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    patterns.insert<VectorizeLinalgConv>(context);
+    applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+}  // namespace
+
+void populateVectorizeLinalgConvPatterns(MLIRContext *context,
+                                         OwningRewritePatternList &patterns) {
+  patterns.insert<VectorizeLinalgConv>(context);
+}
+
+std::unique_ptr<Pass> createVectorizeLinalgConvPass() {
+  return std::make_unique<VectorizeLinalgConvPass>();
+}
+
+static PassRegistration<VectorizeLinalgConvPass> pass(
+    "iree-codegen-vectorize-linalg-conv",
+    "Vectorize a very specific form of linalg.conv");
+
+}  // namespace iree_compiler
+}  // namespace mlir
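
As a sanity check on the batching logic above: for the shapes in the first test case added below (1x1x4x4 filter, 1x1x7x4 input, 1x1x4x4 output, stride 2), the rewrite should produce one output read/write pair per (output width, output-channel group of 4) and one vector.contract per input channel. A small back-of-the-envelope sketch, using only those shapes from the test:

    // Sketch only: counts the ops the rewrite is expected to emit for the
    // 1x1x4x4 filter / 1x1x7x4 input / 1x1x4x4 output test case below.
    #include <cassert>

    int main() {
      const int numInputChannels = 4;   // filter dim 2 (Ci)
      const int numOutputChannels = 4;  // filter dim 3 (Co)
      const int numOutputWidths = 4;    // output dim 2 (Wo)

      // One output transfer_read/transfer_write pair per (Wo, Co/4) element.
      const int outputVectors = numOutputWidths * (numOutputChannels / 4);
      // One vector.contract per input channel for each of those elements.
      const int contractions = outputVectors * numInputChannels;

      assert(outputVectors == 4 && contractions == 16);
      return 0;
    }
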
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
new file mode 100644
index 0000000..194607a
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_conv.mlir
@@ -0,0 +1,142 @@
+// RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-conv -canonicalize -cse %s | IreeFileCheck %s
+
+func @vectorize_conv(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) {
+  %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  %1 = subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1]  : memref<1x1x7x4xf32> to memref<1x1x7x4xf32>
+  %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x1x7x4xf32>, memref<1x1x4x4xf32>
+  return
+}
+
+// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @vectorize_conv
+// CHECK-SAME: %[[FILTER_ARG:.+]]: memref<1x1x4x4xf32>,
+// CHECK-SAME: %[[INPUT_ARG:.+]]: memref<1x1x7x4xf32>,
+// CHECK-SAME: %[[OUTPUT_ARG:.+]]: memref<1x1x4x4xf32>
+
+// CHECK: %[[FLOAT_ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[FILTER:.+]] = subview %[[FILTER_ARG]]
+// CHECK: %[[INPUT:.+]] = subview %[[INPUT_ARG]]
+// CHECK: %[[OUTPUT:.+]] = subview %[[OUTPUT_ARG]]
+
+// Read in the filter and get slices
+// CHECK: %[[FILTER_VECTOR:.+]] = vector.transfer_read %[[FILTER]][%c0, %c0, %c0, %c0], %cst {masked = [false, false]} : memref<1x1x4x4xf32>, vector<4x4xf32>
+// CHECK: %[[FILTER_0:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+// CHECK: %[[FILTER_1:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+// CHECK: %[[FILTER_2:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+// CHECK: %[[FILTER_3:.+]] = vector.extract_strided_slice %[[FILTER_VECTOR]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf32> to vector<1x4xf32>
+
+// Handle batch #0
+// CHECK: %[[INPUT_0:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c0, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_0_0:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_0]], %[[FILTER_0]], %[[OUTPUT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_0_1:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_0_2:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_0_3:.+]] = vector.extract_strided_slice %[[INPUT_0]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_0_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// Handle batch #1
+// CHECK: %[[INPUT_1:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c2, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c1, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_1_0:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_0]], %[[FILTER_0]], %[[OUTPUT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_1_1:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_1_2:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_1_3:.+]] = vector.extract_strided_slice %[[INPUT_1]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_1_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// Handle batch #2
+// CHECK: %[[INPUT_2:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c4, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_2:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c2, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_2_0:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_0]], %[[FILTER_0]], %[[OUTPUT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_2_1:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_2_2:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_2_3:.+]] = vector.extract_strided_slice %[[INPUT_2]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_2_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c2, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// Handle batch #3
+// CHECK: %[[INPUT_3:.+]] = vector.transfer_read %[[INPUT]][%c0, %c0, %c6, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x7x4xf32>, vector<1x4xf32>
+// CHECK: %[[OUTPUT_3:.+]] = vector.transfer_read %[[OUTPUT]][%c0, %c0, %c3, %c0], %[[FLOAT_ZERO]] {masked = [false, false]} : memref<1x1x4x4xf32>, vector<1x4xf32>
+// CHECK: %[[INPUT_3_0:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_0:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_0]], %[[FILTER_0]], %[[OUTPUT_3]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_3_1:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_1:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_1]], %[[FILTER_1]], %[[DOT_0]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_3_2:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 2], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_2:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_2]], %[[FILTER_2]], %[[DOT_1]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[INPUT_3_3:.+]] = vector.extract_strided_slice %[[INPUT_3]] {offsets = [0, 3], sizes = [1, 1], strides = [1, 1]} : vector<1x4xf32> to vector<1x1xf32>
+// CHECK: %[[DOT_3:.+]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[INPUT_3_3]], %[[FILTER_3]], %[[DOT_2]] : vector<1x1xf32>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: vector.transfer_write %[[DOT_3]], %[[OUTPUT]][%c0, %c0, %c3, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<1x1x4x4xf32>
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_batch
+func @do_not_vectorize_conv_with_non_1_batch(%filter: memref<1x1x4x4xf32>, %input: memref<2x1x7x4xf32>, %output: memref<2x1x4x4xf32>) {
+  %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  %1 = subview %input[0, 0, 0, 0] [2, 1, 7, 4] [1, 1, 1, 1]  : memref<2x1x7x4xf32> to memref<2x1x7x4xf32>
+  %2 = subview %output[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1]  : memref<2x1x4x4xf32> to memref<2x1x4x4xf32>
+  // CHECK: linalg.conv
+  linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<2x1x7x4xf32>, memref<2x1x4x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_output_height
+func @do_not_vectorize_conv_with_non_1_output_height(%filter: memref<1x1x4x4xf32>, %input: memref<1x3x7x4xf32>, %output: memref<1x2x4x4xf32>) {
+  %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  %1 = subview %input[0, 0, 0, 0] [1, 3, 7, 4] [1, 1, 1, 1]  : memref<1x3x7x4xf32> to memref<1x3x7x4xf32>
+  %2 = subview %output[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1]  : memref<1x2x4x4xf32> to memref<1x2x4x4xf32>
+  // CHECK: linalg.conv
+  linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x3x7x4xf32>, memref<1x2x4x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_height
+func @do_not_vectorize_conv_with_non_1_filter_height(%filter: memref<2x1x4x4xf32>, %input: memref<1x2x7x4xf32>, %output: memref<1x1x4x4xf32>) {
+  %0 = subview %filter[0, 0, 0, 0] [2, 1, 4, 4] [1, 1, 1, 1]  : memref<2x1x4x4xf32> to memref<2x1x4x4xf32>
+  %1 = subview %input[0, 0, 0, 0] [1, 2, 7, 4] [1, 1, 1, 1]  : memref<1x2x7x4xf32> to memref<1x2x7x4xf32>
+  %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  // CHECK: linalg.conv
+  linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<2x1x4x4xf32>, memref<1x2x7x4xf32>, memref<1x1x4x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_filter_width
+func @do_not_vectorize_conv_with_non_1_filter_width(%filter: memref<1x2x4x4xf32>, %input: memref<1x1x8x4xf32>, %output: memref<1x1x4x4xf32>) {
+  %0 = subview %filter[0, 0, 0, 0] [1, 2, 4, 4] [1, 1, 1, 1]  : memref<1x2x4x4xf32> to memref<1x2x4x4xf32>
+  %1 = subview %input[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1]  : memref<1x1x8x4xf32> to memref<1x1x8x4xf32>
+  %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  // CHECK: linalg.conv
+  linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [2, 2]} : memref<1x2x4x4xf32>, memref<1x1x8x4xf32>, memref<1x1x4x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_vectorize_conv_with_non_1_dilation
+func @do_not_vectorize_conv_with_non_1_dilation(%filter: memref<1x1x4x4xf32>, %input: memref<1x1x7x4xf32>, %output: memref<1x1x4x4xf32>) {
+  %0 = subview %filter[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  %1 = subview %input[0, 0, 0, 0] [1, 1, 7, 4] [1, 1, 1, 1]  : memref<1x1x7x4xf32> to memref<1x1x7x4xf32>
+  %2 = subview %output[0, 0, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1]  : memref<1x1x4x4xf32> to memref<1x1x4x4xf32>
+  // CHECK: linalg.conv
+  linalg.conv(%0, %1, %2) {dilations = [2, 1], strides = [2, 2]} : memref<1x1x4x4xf32>, memref<1x1x7x4xf32>, memref<1x1x4x4xf32>
+  return
+}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 6d72146..d53c0ee 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -33,6 +33,7 @@
   createDecomposeHLOClampPass();
   createHLOToLinalgOnBuffersPass();
   createHLOToLinalgOnTensorsPass();
+  createDemoteF32ToF16Pass();
 }
 
 inline void registerLinalgToVectorPasses() {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 049b5c1..4750793 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -17,12 +17,19 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 #define DEBUG_TYPE "iree-detail"
 
+static llvm::cl::opt<bool> clEnableMatmulFusion(
+    "iree-enable-matmul-fusion",
+    llvm::cl::desc("Experimental flag to enable fusion of matmul ops with "
+                   "their consumers, used to evaluate fusion"),
+    llvm::cl::init(false));
+
 namespace mlir {
 namespace iree_compiler {
 namespace IREE {
@@ -104,7 +111,7 @@
 }
 
 int OpDispatchPolicy::getAnchorBenefit(Operation *op) {
-  if (isUnsupportedFusionOp(op)) {
+  if (isUnsupportedFusionOp(op) || isFusableWithConsumersOnly(op)) {
     return 100;
   }
 
@@ -142,6 +149,9 @@
   if (isUnsupportedFusionOp(anchorOp) || isUnsupportedFusionOp(inputOp)) {
     return FusionType::DISABLED;
   }
+  if (isFusableWithConsumersOnly(anchorOp) && !isa<mhlo::ReshapeOp>(inputOp)) {
+    return FusionType::DISABLED;
+  }
 
   // By default for operands, they are duplicated into the dispatch region.
   // Typically at the initial fusion stage, there is not a sufficient cost
@@ -167,6 +177,10 @@
   if (isUnsupportedFusionOp(anchorOp) || isUnsupportedFusionOp(outputOp)) {
     return FusionType::DISABLED;
   }
+  if (isFusableWithConsumersOnly(anchorOp) &&
+      !isFusableWithConsumersOnly(outputOp)) {
+    return FusionType::MOVE_INTO;
+  }
 
   // Generally, it is hard to reason locally about the legality of fusing an
   // output, since additional analysis may need to be done to determine
@@ -176,12 +190,16 @@
   return FusionType::DISABLED;
 }
 
+bool OpDispatchPolicy::isFusableWithConsumersOnly(Operation *op) {
+  return clEnableMatmulFusion && isa<mhlo::DotOp>(op);
+}
+
 // TODO(b/144530470): replace with tablegen attributes/interfaces.
 bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
-  return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
-             mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
-             mhlo::TorchIndexSelectOp>(op) ||
-         isRootOnlyOp(op);
+  return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::PadOp,
+             mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::TorchIndexSelectOp>(
+             op) ||
+         (!clEnableMatmulFusion && isa<mhlo::DotOp>(op)) || isRootOnlyOp(op);
 }
 
 bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
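
Taken together, the policy changes above mean that with -iree-enable-matmul-fusion an mhlo.dot anchors its own dispatch region, keeps producers out except for reshapes, and absorbs its non-dot consumers. A toy sketch of that decision table, simplified and with hypothetical names (this is not the OpDispatchPolicy API):

    // Sketch only: approximate decisions for an mhlo.dot anchor when
    // -iree-enable-matmul-fusion is set.

    // Producer side: only a reshape may fall through to the default
    // clone-into-region behavior; every other producer stays out.
    bool mayFuseProducerIntoDotRegion(bool producerIsReshape) {
      return producerIsReshape;
    }

    // Consumer side: consumers are moved into the dot's region unless they
    // are themselves consumers-only fusable (e.g. another dot).
    bool movesConsumerIntoDotRegion(bool consumerIsConsumersOnlyFusable) {
      return !consumerIsConsumersOnlyFusable;
    }
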
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index ee9299f..36e8518 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -41,6 +41,9 @@
   OpDispatchPolicy(Dispatchability &dispatchability)
       : dispatchability(dispatchability) {}
 
+  // Returns true if |op| is only fusable with its consumers.
+  static bool isFusableWithConsumersOnly(Operation *op);
+
   // Returns true if |op| is not able to fuse with either producer or consumer.
   static bool isUnsupportedFusionOp(Operation *op);
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 0982a12..2cbfaba 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -202,7 +202,8 @@
   for (auto &block : regionOp.body().getBlocks()) {
     for (auto &op : block) {
       // A root only op is mergable.
-      if (OpDispatchPolicy::isUnsupportedFusionOp(&op) &&
+      if ((OpDispatchPolicy::isUnsupportedFusionOp(&op) ||
+           OpDispatchPolicy::isFusableWithConsumersOnly(&op)) &&
           !OpDispatchPolicy::isRootOnlyOp(&op)) {
         return false;
       }
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
index 6da8bef..f604b39 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions2.cpp
@@ -254,6 +254,44 @@
   return success();
 }
 
+// Inlining an op into a dispatch region makes the op's operands operands of
+// the dispatch region (if they aren't already defined within the dispatch
+// region). The dispatch region therefore has to be moved to just after the
+// last of those operand definitions for the SSA value uses to remain valid.
+static LogicalResult moveDispatchOp(DispatchRegionOp dispatchRegionOp,
+                                    Operation *inlinedOp) {
+  // Find the operation, among those defining operands of inlinedOp, that
+  // appears last in the block.
+  Optional<Operation *> lastOperandDef = llvm::None;
+  for (Value operand : inlinedOp->getOperands()) {
+    if (Operation *definingOp = operand.getDefiningOp()) {
+      if (!lastOperandDef ||
+          lastOperandDef.getValue()->isBeforeInBlock(definingOp)) {
+        lastOperandDef = definingOp;
+      }
+    }
+  }
+  // If the last operand def is already before the dispatch region, there is
+  // nothing to do.
+  if (!lastOperandDef ||
+      lastOperandDef.getValue()->isBeforeInBlock(dispatchRegionOp)) {
+    return success();
+  }
+
+  // The dispatch region needs to be moved after the lastOperandDef, but before
+  // the first use.
+  Optional<Operation *> firstUse = llvm::None;
+  for (Operation *user : dispatchRegionOp.getOperation()->getUsers()) {
+    if (!firstUse || user->isBeforeInBlock(*firstUse)) {
+      firstUse = user;
+    }
+  }
+  if (firstUse && firstUse.getValue()->isBeforeInBlock(*lastOperandDef))
+    return failure();
+  dispatchRegionOp.getOperation()->moveAfter(lastOperandDef.getValue());
+  return success();
+}
+
 LogicalResult fuseOutputs(DispatchRegion &dispatchRegion,
                           OpDispatchPolicy &policy) {
   LLVM_DEBUG(llvm::dbgs() << "++ FUSING OUTPUT\n");
@@ -275,6 +313,11 @@
       return nextOp->emitError()
              << "cannot fuse output except with MOVE_INTO action";
     }
+    if (failed(moveDispatchOp(dispatchRegion.op, nextOp))) {
+      LLVM_DEBUG(llvm::dbgs() << "- SKIP Fusion due to SSA use-def violation "
+                              << *nextOp << "\n");
+      continue;
+    }
     LLVM_DEBUG(llvm::dbgs() << "- FUSABLE OUTPUT(" << static_cast<int>(action)
                             << "): " << *nextOp << "\n");
     // Since results will be redirected to the region results, need to scan
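
The legality condition in moveDispatchOp above reduces to a block-order check: the fused op's last operand definition must come before the dispatch region's first use, otherwise moving the region would break SSA dominance. A minimal order-only sketch of that check, modeling ops by their position in the block rather than using the MLIR API:

    // Sketch only: positions stand in for block order.
    #include <algorithm>
    #include <climits>
    #include <vector>

    bool canMoveRegionAfterOperandDefs(
        int lastOperandDefPos, const std::vector<int> &regionUsePositions) {
      int firstUsePos = INT_MAX;
      for (int pos : regionUsePositions)
        firstUsePos = std::min(firstUsePos, pos);
      // Moving the region to just after the last operand definition is only
      // valid if no user of the region appears before that definition.
      return lastOperandDefPos < firstUsePos;
    }
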
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir
new file mode 100644
index 0000000..f8d9fa8
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/identify_dispatch_regions2_enable_matmul_fusion.mlir
@@ -0,0 +1,109 @@
+// RUN: iree-opt -split-input-file -iree-flow-dispatchability-analysis -iree-flow-identify-dispatch-regions2 -iree-enable-matmul-fusion %s | IreeFileCheck %s
+
+func @simpleDotAddMul
+  (%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x48xf32>,
+   %arg2 : tensor<16x48xf32>, %arg3 : tensor<16x48xf32>) -> tensor<16x48xf32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) :
+    (tensor<16x32xf32>, tensor<32x48xf32>) -> tensor<16x48xf32>
+  %1 = mhlo.add %0, %arg2 : tensor<16x48xf32>
+  %2 = mhlo.multiply %1, %arg3 : tensor<16x48xf32>
+  return %2 : tensor<16x48xf32>
+}
+// CHECK-LABEL: func @simpleDotAddMul
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<16x32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x48xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<16x48xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<16x48xf32>
+//  CHECK-NEXT:   %[[WORKLOAD:.+]] = constant 768
+//  CHECK-NEXT:   %[[RESULT:.+]] = flow.dispatch.region[%[[WORKLOAD]] : index]
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG0]]
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]] = %[[ARG1]]
+//  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG2]]
+//  CHECK-SAME:     %[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG3]]
+//  CHECK-SAME:     {
+//  CHECK-NEXT:       %[[T1:.+]] = "mhlo.dot"(%[[ARG4]], %[[ARG5]])
+//  CHECK-NEXT:       %[[T2:.+]] = mhlo.add %[[T1]], %[[ARG6]]
+//  CHECK-NEXT:       %[[T3:.+]] = mhlo.multiply %[[T2]], %[[ARG7]]
+//  CHECK-NEXT:       flow.return %[[T3]]
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   return %[[RESULT]]
+
+// -----
+
+func @twoDots
+  (%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x48xf32>,
+   %arg2 : tensor<16x48xf32>, %arg3 : tensor<16x64xf32>,
+   %arg4 : tensor<16x64xf32>) -> tensor<16x64xf32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) :
+    (tensor<16x32xf32>, tensor<32x48xf32>) -> tensor<16x48xf32>
+  %1 = mhlo.add %0, %arg2 : tensor<16x48xf32>
+  %2 = "mhlo.dot"(%1, %arg3) :
+    (tensor<16x48xf32>, tensor<16x64xf32>) -> tensor<16x64xf32>
+  %3 = mhlo.multiply %2, %arg4 : tensor<16x64xf32>
+  return %3 : tensor<16x64xf32>
+}
+// CHECK-LABEL: func @twoDots
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<16x32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x48xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<16x48xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<16x64xf32>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<16x64xf32>
+//  CHECK-NEXT:   %[[WORKLOAD1:.+]] = constant 1024
+//  CHECK-NEXT:   %[[WORKLOAD2:.+]] = constant 768
+//  CHECK-NEXT:   %[[RESULT1:.+]] = flow.dispatch.region[%[[WORKLOAD2]] : index]
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]] = %[[ARG0]]
+//  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG1]]
+//  CHECK-SAME:     %[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG2]]
+//  CHECK-SAME:     {
+//  CHECK-NEXT:       %[[T1:.+]] = "mhlo.dot"(%[[ARG5]], %[[ARG6]])
+//  CHECK-NEXT:       %[[T2:.+]] = mhlo.add %[[T1]], %[[ARG7]]
+//  CHECK-NEXT:       flow.return %[[T2]]
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   %[[RESULT2:.+]] = flow.dispatch.region[%[[WORKLOAD1]] : index]
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]] = %[[RESULT1]]
+//  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9_]+]] = %[[ARG3]]
+//  CHECK-SAME:     %[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG4]]
+//  CHECK-SAME:     {
+//  CHECK-NEXT:       %[[T3:.+]] = "mhlo.dot"(%[[ARG5]], %[[ARG6]])
+//  CHECK-NEXT:       %[[T4:.+]] = mhlo.multiply %[[T3]], %[[ARG7]]
+//  CHECK-NEXT:       flow.return %[[T4]]
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   return %[[RESULT2]]
+
+// -----
+
+func @moveDispatchOp
+  (%arg0 : tensor<1x384x384xf32>, %arg1 : tensor<384x512xf32>,
+   %arg2 : tensor<512xf32>) -> tensor<1x384x512xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<1x384x384xf32>) -> tensor<384x384xf32>
+  %1 = "mhlo.dot"(%0, %arg1) :
+    (tensor<384x384xf32>, tensor<384x512xf32>) -> tensor<384x512xf32>
+  %2 = "mhlo.broadcast_in_dim"(%arg2)
+    {broadcast_dimensions = dense<1> : tensor<1xi64>} :
+    (tensor<512xf32>) -> tensor<384x512xf32>
+  %3 = mhlo.add %1, %2 : tensor<384x512xf32>
+  %4 = "mhlo.reshape"(%3) : (tensor<384x512xf32>) -> tensor<1x384x512xf32>
+  return %4 : tensor<1x384x512xf32>
+}
+// CHECK-LABEL: func @moveDispatchOp
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x384x384xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<384x512xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<512xf32>
+//       CHECK:   %[[RESULT1:.+]] = flow.dispatch.region
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]] = %[[ARG2]]
+//  CHECK-SAME:     {
+//  CHECK-NEXT:       %[[T1:.+]] = "mhlo.broadcast_in_dim"(%[[ARG3]])
+//  CHECK-NEXT:       flow.return %[[T1]]
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   %[[RESULT2:.+]] = flow.dispatch.region
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]] = %[[ARG1]]
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]+]] = %[[RESULT1]]
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]] = %[[ARG0]]
+//  CHECK-SAME:     {
+//  CHECK-NEXT:       %[[T2:.+]] = "mhlo.reshape"(%[[ARG5]])
+//  CHECK-NEXT:       %[[T3:.+]] = "mhlo.dot"(%[[T2]], %[[ARG3]])
+//  CHECK-NEXT:       %[[T4:.+]] = mhlo.add %[[T3]], %[[ARG4]]
+//  CHECK-NEXT:       %[[T5:.+]] = "mhlo.reshape"(%[[T4]])
+//  CHECK-NEXT:       flow.return %[[T5]]
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   return %[[RESULT2]]
diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir
index c5c8748..5974d45 100644
--- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir
+++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: @Reserve
 func @Reserve(%element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !tensorlist.list {
   // CHECK: vm.call @tensorlist.reserve
-  %0 = "tensorlist.Reserve"(%element_shape, %num_elements) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %0 = "tensorlist.Reserve"(%element_shape, %num_elements) { element_type = 50331680 : i32 } : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
   return %0 : !tensorlist.list
 }
 // CHECK: vm.import @tensorlist.reserve
@@ -11,9 +11,9 @@
 // -----
 
 // CHECK-LABEL: @GetItem
-func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view, %element_shape: !hal.buffer_view) -> !hal.buffer_view {
+func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view) -> !hal.buffer_view {
   // CHECK: vm.call @tensorlist.get_item
-  %0 = "tensorlist.GetItem"(%list, %index, %element_shape) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+  %0 = "tensorlist.GetItem"(%list, %index) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
   return %0 : !hal.buffer_view
 }
 // CHECK: vm.import @tensorlist.get_item
@@ -31,9 +31,11 @@
 // -----
 
 // CHECK-LABEL: @Stack
-func @Stack(%list: !tensorlist.list, %element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
+func @Stack(%list: !tensorlist.list, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
+  %dev = hal.ex.shared_device : !hal.device
+  %allocator = hal.device.allocator %dev : !hal.allocator
   // CHECK: vm.call @tensorlist.stack
-  %0 = "tensorlist.Stack"(%list, %element_shape, %num_elements) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+  %0 = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
   return %0 : !hal.buffer_view
 }
 
@@ -41,7 +43,9 @@
 
 // CHECK-LABEL: @Concat
 func @Concat(%list: !tensorlist.list) -> !hal.buffer_view {
+  %dev = hal.ex.shared_device : !hal.device
+  %allocator = hal.device.allocator %dev : !hal.allocator
   // CHECK: vm.call @tensorlist.concat
-  %0 = "tensorlist.Concat"(%list) : (!tensorlist.list) -> !hal.buffer_view
+  %0 = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
   return %0 : !hal.buffer_view
 }
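
The element_type value threaded through @Reserve above looks like the packed HAL element type for f32: 50331680 is 0x03000020, which reads as a numerical-type byte of 0x03 and a bit width of 32, assuming a (numerical_type << 24) | bit_width layout. A one-line check of that arithmetic:

    // Sketch only: checks the arithmetic for the constant above; the
    // (numerical_type << 24) | bit_width layout is an assumption inferred
    // from the hex form of the value, not taken from the HAL headers.
    static_assert(((0x03u << 24) | 32u) == 50331680u,
                  "50331680 == 0x03000020, i.e. a 32-bit float element type");
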
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp
index 04ea528..296e199 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp
@@ -15,6 +15,7 @@
 #include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.h"
 
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/IR/Builders.h"
 
 #define GET_OP_CLASSES
 #include "iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.cpp.inc"
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td
index d7bc2c9..2da11b9 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListOps.td
@@ -27,10 +27,10 @@
   }];
 
   let arguments = (ins
-    // TODO(silvasean): Do we need element_shape?
     HAL_BufferView:$element_shape,
     // TODO(silvasean): Convert to `I32:$count` instead.
-    HAL_BufferView:$count
+    HAL_BufferView:$count,
+    HAL_ElementTypeAttr:$element_type
   );
 
   let results = (outs
@@ -47,9 +47,7 @@
   let arguments = (ins
     TensorList_TensorList:$list,
     // TODO(silvasean): Convert to `I32:$index` instead.
-    HAL_BufferView:$index,
-    // TODO(silvasean): Do we need element_shape?
-    HAL_BufferView:$element_shape
+    HAL_BufferView:$index
   );
 
   let results = (outs
@@ -83,9 +81,7 @@
     a tensorlist `list` of length equal to `tensor`'s leading dimension.
   }];
   let arguments = (ins
-    HAL_BufferView:$tensor,
-    // TODO(silvasean): Do we need element_shape?
-    HAL_BufferView:$element_shape
+    HAL_BufferView:$tensor
   );
   let results = (outs
     TensorList_TensorList:$list
@@ -102,8 +98,8 @@
     Requires all tensors contained in `list` to be the same shape.
   }];
   let arguments = (ins
+    HAL_Allocator:$allocator,
     TensorList_TensorList:$list,
-    HAL_BufferView:$element_shape,
     HAL_BufferView:$num_elements
   );
   let results = (outs
@@ -122,6 +118,7 @@
     the non-leading axes.
   }];
   let arguments = (ins
+    HAL_Allocator:$allocator,
     TensorList_TensorList:$list
   );
   let results = (outs
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir b/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir
index 72cef98..688bc47 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/test/ops.mlir
@@ -3,16 +3,16 @@
 // CHECK-LABEL: @Reserve
 func @Reserve(%element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !tensorlist.list {
   // CHECK: tensorlist.Reserve
-  %0 = "tensorlist.Reserve"(%element_shape, %num_elements) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %0 = "tensorlist.Reserve"(%element_shape, %num_elements) {element_type = 13 : i32 } : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
   return %0 : !tensorlist.list
 }
 
 // -----
 
 // CHECK-LABEL: @GetItem
-func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view, %element_shape: !hal.buffer_view) -> !hal.buffer_view {
+func @GetItem(%list: !tensorlist.list, %index: !hal.buffer_view) -> !hal.buffer_view {
   // CHECK: tensorlist.GetItem
-  %0 = "tensorlist.GetItem"(%list, %index, %element_shape) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+  %0 = "tensorlist.GetItem"(%list, %index) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
   return %0 : !hal.buffer_view
 }
 
@@ -28,17 +28,17 @@
 // -----
 
 // CHECK-LABEL: @Stack
-func @Stack(%list: !tensorlist.list, %element_shape: !hal.buffer_view, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
+func @Stack(%allocator: !hal.allocator, %list: !tensorlist.list, %num_elements: !hal.buffer_view) -> !hal.buffer_view {
   // CHECK: tensorlist.Stack
-  %0 = "tensorlist.Stack"(%list, %element_shape, %num_elements) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+  %0 = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
   return %0 : !hal.buffer_view
 }
 
 // -----
 
 // CHECK-LABEL: @Concat
-func @Concat(%list: !tensorlist.list) -> !hal.buffer_view {
+func @Concat(%allocator: !hal.allocator, %list: !tensorlist.list) -> !hal.buffer_view {
   // CHECK: tensorlist.Concat
-  %0 = "tensorlist.Concat"(%list) : (!tensorlist.list) -> !hal.buffer_view
+  %0 = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
   return %0 : !hal.buffer_view
-}
\ No newline at end of file
+}
diff --git a/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir b/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir
index fa55ef5..6a16add 100644
--- a/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir
+++ b/iree/compiler/Dialect/Modules/TensorList/tensorlist.imports.mlir
@@ -17,15 +17,15 @@
 // Maps to IREE::TensorList::Reserve.
 vm.import @reserve(
   %element_shape : !vm.ref<!hal.buffer_view>,
-  %num_elements : !vm.ref<!hal.buffer_view>
+  %num_elements : !vm.ref<!hal.buffer_view>,
+  %element_type : i32
 ) -> !vm.ref<!tensorlist.list>
 attributes {nosideeffects}
 
 // Maps to IREE::TensorList::GetItem.
 vm.import @get_item(
   %list : !vm.ref<!tensorlist.list>,
-  %index : !vm.ref<!hal.buffer_view>,
-  %element_shape: !vm.ref<!hal.buffer_view>
+  %index : !vm.ref<!hal.buffer_view>
 ) -> !vm.ref<!hal.buffer_view>
 attributes {nosideeffects}
 
@@ -39,21 +39,21 @@
 
 // Maps to IREE:TensorList::FromTensor
 vm.import @from_tensor(
-  %tensor : !vm.ref<!hal.buffer_view>,
-  %element_shape : !vm.ref<!hal.buffer_view>
+  %tensor : !vm.ref<!hal.buffer_view>
 ) -> !vm.ref<!tensorlist.list>
 attributes {nosideeffects}
 
 // Maps to IREE:TensorList::Concat
 vm.import @concat(
+  %allocator : !vm.ref<!hal.allocator>,
   %list : !vm.ref<!tensorlist.list>
 ) -> !vm.ref<!hal.buffer_view>
 attributes {nosideeffects}
 
 // Maps to IREE:TensorList::Stack
 vm.import @stack(
+  %allocator : !vm.ref<!hal.allocator>,
   %list : !vm.ref<!tensorlist.list>,
-  %element_shape : !vm.ref<!hal.buffer_view>,
   %num_elements : !vm.ref<!hal.buffer_view>
 ) -> !vm.ref<!hal.buffer_view>
 attributes {nosideeffects}
diff --git a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
index f4f29c9..4dafc73 100644
--- a/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Conversion/VMToEmitC/CMakeLists.txt
@@ -28,6 +28,9 @@
       MLIREmitC
       MLIRTransforms
       iree::compiler::Dialect::VM::IR
+    INCLUDES
+      "${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
+      "${PROJECT_BINARY_DIR}/third_party/mlir-emitc/include"
     PUBLIC
   )
 endif()
diff --git a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
index 0b0b4da..691c6d3 100644
--- a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
@@ -32,6 +32,8 @@
       MLIRSupport
       iree::compiler::Dialect::VM::IR
       iree::compiler::Dialect::VM::Conversion::VMToEmitC
+    INCLUDES
+      "${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
     PUBLIC
   )
 endif()
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc
index 238384f..ee60e6b 100644
--- a/iree/hal/vulkan/descriptor_set_arena.cc
+++ b/iree/hal/vulkan/descriptor_set_arena.cc
@@ -131,10 +131,11 @@
 
   // Pick a bucket based on the number of descriptors required.
   // NOTE: right now we are 1:1 with bindings.
-  int required_descriptor_count = static_cast<int>(bindings.size() * 1);
-  int max_descriptor_count =
-      std::max(8, RoundUpToNearestPow2(required_descriptor_count));
-  int bucket = TrailingZeros(max_descriptor_count >> 3);
+  uint32_t required_descriptor_count =
+      static_cast<uint32_t>(bindings.size() * 1);
+  uint32_t max_descriptor_count =
+      std::max(8u, iree_math_round_up_to_pow2_u32(required_descriptor_count));
+  uint32_t bucket =
+      iree_math_count_trailing_zeros_u32(max_descriptor_count >> 3);
   if (bucket >= descriptor_pool_buckets_.size()) {
     return OutOfRangeErrorBuilder(IREE_LOC)
            << "Too many descriptors required: " << required_descriptor_count
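
The switch to the iree_math_* helpers keeps the bucket arithmetic the same: for example, 5 bindings round up to 8 and select bucket ctz(8 >> 3) = 0, while 20 bindings round up to 32 and select bucket ctz(32 >> 3) = 2. A small sketch of that arithmetic with plain stand-ins for the helpers (the binding counts are just illustrative):

    // Sketch only: mirrors the bucket computation above with local helpers
    // that behave like the pow2/ctz routines it calls.
    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    static uint32_t RoundUpToPow2(uint32_t x) {
      uint32_t v = 1;
      while (v < x) v <<= 1;
      return v;
    }

    static uint32_t CountTrailingZeros(uint32_t x) {
      uint32_t n = 0;
      while (x != 0 && (x & 1u) == 0u) { x >>= 1; ++n; }
      return n;
    }

    static uint32_t BucketFor(uint32_t binding_count) {
      uint32_t max_count = std::max(8u, RoundUpToPow2(binding_count));
      return CountTrailingZeros(max_count >> 3);
    }

    int main() {
      assert(BucketFor(5) == 0);   // 5 -> 8 -> bucket 0
      assert(BucketFor(20) == 2);  // 20 -> 32 -> bucket 2
      return 0;
    }
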
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.h b/iree/hal/vulkan/internal_vk_mem_alloc.h
index 4541e9f..d2fa59e 100644
--- a/iree/hal/vulkan/internal_vk_mem_alloc.h
+++ b/iree/hal/vulkan/internal_vk_mem_alloc.h
@@ -27,5 +27,6 @@
 // to be omitted and not have VMA poking around where it shouldn't.
 #define VMA_DYNAMIC_VULKAN_FUNCTIONS 0
 
-#include "vk_mem_alloc.h"
+#include <vk_mem_alloc.h>
+
 #endif  // IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc
index a5e2cd0..582a8c3 100644
--- a/iree/hal/vulkan/vulkan_device.cc
+++ b/iree/hal/vulkan/vulkan_device.cc
@@ -171,7 +171,8 @@
     const ref_ptr<DynamicSymbols>& syms) {
   absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues;
 
-  uint64_t compute_queue_count = CountOnes64(compute_queue_set.queue_indices);
+  uint64_t compute_queue_count =
+      iree_math_count_ones_u64(compute_queue_set.queue_indices);
   for (uint32_t i = 0; i < compute_queue_count; ++i) {
     if (!(compute_queue_set.queue_indices & (1ull << i))) continue;
 
@@ -193,7 +194,8 @@
     }
   }
 
-  uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
+  uint64_t transfer_queue_count =
+      iree_math_count_ones_u64(transfer_queue_set.queue_indices);
   for (uint32_t i = 0; i < transfer_queue_count; ++i) {
     if (!(transfer_queue_set.queue_indices & (1ull << i))) continue;
 
@@ -411,8 +413,10 @@
     const ref_ptr<DynamicSymbols>& syms) {
   IREE_TRACE_SCOPE0("VulkanDevice::Wrap");
 
-  uint64_t compute_queue_count = CountOnes64(compute_queue_set.queue_indices);
-  uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
+  uint64_t compute_queue_count =
+      iree_math_count_ones_u64(compute_queue_set.queue_indices);
+  uint64_t transfer_queue_count =
+      iree_math_count_ones_u64(transfer_queue_set.queue_indices);
 
   if (compute_queue_count == 0) {
     return InvalidArgumentErrorBuilder(IREE_LOC)
diff --git a/iree/modules/tensorlist/native_module.cc b/iree/modules/tensorlist/native_module.cc
index b23db9c..18d4770 100644
--- a/iree/modules/tensorlist/native_module.cc
+++ b/iree/modules/tensorlist/native_module.cc
@@ -38,6 +38,14 @@
 namespace {
 class TensorList final : public RefObject<TensorList> {
  public:
+  TensorList(absl::Span<const int32_t> shape, iree_hal_element_type_t dtype)
+      : shape_(shape.begin(), shape.end()), dtype_(dtype) {}
+
+  TensorList(const vm::ref<TensorList>& other)
+      : shape_(other->shape_), dtype_(other->dtype_) {
+    CopyFrom(other);
+  }
+
   void Resize(int32_t num_elements) { list_.resize(num_elements); }
   // Copy from another iree_tensorlist.
   // vm::ref has deleted copy operator=, so we can't use vector's operator=.
@@ -62,6 +70,9 @@
     }
   }
   size_t Size() { return list_.size(); }
+  absl::Span<int32_t> Shape() {
+    return absl::Span<int32_t>(shape_.data(), shape_.size());
+  }
 
   static StatusOr<vm::ref<TensorList>> FromTensor(
       vm::ref<iree_hal_buffer_view_t> tensor) {
@@ -74,15 +85,21 @@
     IREE_RETURN_IF_ERROR(
         iree_hal_buffer_view_shape(tensor.get(), rank, shape.data(), nullptr));
 
-    TensorList* list = new TensorList;
-    list->Resize(shape[0]);
+    auto element_type = iree_hal_buffer_view_element_type(tensor.get());
+
+    int32_t list_elements = shape[0];
+    absl::Span<int32_t> element_shape(shape.data() + 1, shape.size() - 1);
+
+    TensorList* list = new TensorList(element_shape, element_type);
+    list->Resize(list_elements);
+
     // The python pseudocode for this is:
     // for i in range(t.shape[0]):
     //   list[i] = t[i,...]
     absl::InlinedVector<int32_t, 6> start_indices(shape.size());
     absl::InlinedVector<int32_t, 6> lengths = shape;
     lengths[0] = 1;
-    for (int i = 0, e = shape[0]; i < e; i++) {
+    for (int i = 0, e = list_elements; i < e; i++) {
       start_indices[0] = i;
       iree_device_size_t start_offset = 0;
       iree_device_size_t subview_length = 0;
@@ -96,7 +113,7 @@
 
       iree_hal_buffer_view_t* slice = nullptr;
       IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
-          subview_buffer.get(), shape.data() + 1, shape.size() - 1,
+          subview_buffer.get(), element_shape.data(), element_shape.size(),
           iree_hal_buffer_view_element_type(tensor.get()),
           iree_allocator_system(), &slice));
       list->SetItem(i, slice);
@@ -104,35 +121,29 @@
     return list;
   }
 
-  StatusOr<vm::ref<iree_hal_buffer_view_t>> Stack() {
+  StatusOr<vm::ref<iree_hal_buffer_view_t>> Stack(
+      vm::ref<iree_hal_allocator_t> hal_allocator) {
     size_t num_tensors = Size();
     if (num_tensors == 0) {
       return InvalidArgumentErrorBuilder(IREE_LOC) << "expected non-empty list";
     }
-    for (size_t i = 0; i < num_tensors; i++) {
-      if (!GetItem(i).get()) {
-        return InvalidArgumentErrorBuilder(IREE_LOC)
-               << "uninitialized element in list";
-      }
-    }
 
-    size_t rank = iree_hal_buffer_view_shape_rank(GetItem(0).get());
-    iree_hal_element_type_t type =
-        iree_hal_buffer_view_element_type(GetItem(0).get());
-    absl::InlinedVector<int32_t, 6> shape(rank);
-    IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(GetItem(0).get(), rank,
-                                                    shape.data(), nullptr));
+    // Validate that all buffers are of the right shape/type.
+    absl::Span<int32_t> shape(shape_);
+    iree_hal_element_type_t type(dtype_);
     for (size_t i = 0; i < num_tensors; i++) {
-      size_t element_rank = iree_hal_buffer_view_shape_rank(GetItem(i).get());
+      auto item = GetItem(i).get();
+      if (!item) continue;
+      size_t element_rank = iree_hal_buffer_view_shape_rank(item);
       absl::InlinedVector<int32_t, 6> element_shape(element_rank);
       IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
-          GetItem(i).get(), element_rank, element_shape.data(), nullptr));
+          item, element_rank, element_shape.data(), nullptr));
       if (absl::MakeSpan(shape) != absl::MakeSpan(element_shape) ||
-          iree_hal_buffer_view_element_type(GetItem(i).get()) != type) {
+          iree_hal_buffer_view_element_type(item) != type) {
         return InvalidArgumentErrorBuilder(IREE_LOC)
                << "stacking list with elements of different shapes or element "
-                  "types. Mismatch between element 0 and element "
-               << i;
+               << "types. Mismatch between element 0 and element " << i;
       }
     }
 
@@ -141,13 +152,12 @@
     for (int32_t dim : shape) {
       num_elements_per_tensor *= dim;
     }
-    size_t element_size = iree_hal_buffer_view_element_size(GetItem(0).get());
+
+    size_t element_size = iree_hal_element_byte_count(type);
     size_t num_result_elements = num_elements_per_tensor * num_tensors;
     size_t result_byte_size = num_result_elements * element_size;
-    iree_hal_allocator_t* hal_allocator = iree_hal_buffer_allocator(
-        iree_hal_buffer_view_buffer(GetItem(0).get()));
     IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
-        hal_allocator,
+        hal_allocator.get(),
         static_cast<iree_hal_memory_type_t>(
             IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
             IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
@@ -167,27 +177,28 @@
     return std::move(result_view);
   }
 
-  StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat() {
+  StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat(
+      vm::ref<iree_hal_allocator_t> hal_allocator) {
     size_t num_tensors = Size();
     if (num_tensors == 0) {
       return InvalidArgumentErrorBuilder(IREE_LOC) << "expected non-empty list";
     }
-    for (size_t i = 0; i < num_tensors; i++) {
-      if (!GetItem(i).get()) {
-        return InvalidArgumentErrorBuilder(IREE_LOC)
-               << "uninitialized element in list";
-      }
+
+    if (shape_.empty()) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "concat requires list elements of rank greater than zero";
     }
 
     size_t rank = iree_hal_buffer_view_shape_rank(GetItem(0).get());
-    iree_hal_element_type_t type =
-        iree_hal_buffer_view_element_type(GetItem(0).get());
+    iree_hal_element_type_t type = dtype_;
     absl::InlinedVector<int32_t, 6> shape(rank);
     IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(GetItem(0).get(), rank,
                                                     shape.data(), nullptr));
-    size_t num_rows = 0;
+    const size_t num_rows = num_tensors * shape[0];
     for (size_t i = 0; i < num_tensors; i++) {
-      size_t element_rank = iree_hal_buffer_view_shape_rank(GetItem(i).get());
+      auto item = GetItem(i).get();
+      if (!item) continue;
+      size_t element_rank = iree_hal_buffer_view_shape_rank(item);
       if (element_rank < 1) {
         return InvalidArgumentErrorBuilder(IREE_LOC)
                << "stacking rank must be greater than zero." << i;
@@ -196,7 +207,6 @@
       absl::InlinedVector<int32_t, 6> element_shape(element_rank);
       IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
           GetItem(i).get(), element_rank, element_shape.data(), nullptr));
-      num_rows += element_shape.front();
 
       if (absl::MakeSpan(shape).subspan(1) !=
               absl::MakeSpan(element_shape).subspan(1) ||
@@ -216,10 +226,8 @@
     size_t element_size = iree_hal_buffer_view_element_size(GetItem(0).get());
     size_t num_result_elements = num_elements_per_row * num_rows;
     size_t result_byte_size = num_result_elements * element_size;
-    iree_hal_allocator_t* hal_allocator = iree_hal_buffer_allocator(
-        iree_hal_buffer_view_buffer(GetItem(0).get()));
     IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
-        hal_allocator,
+        hal_allocator.get(),
         static_cast<iree_hal_memory_type_t>(
             IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
             IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
@@ -261,9 +269,19 @@
     // in the compiler at which point there will be no "stack" function inside
     // this module at all.
     size_t num_tensors = Size();
-    size_t offset = 0;
+    size_t element_byte_size = iree_hal_element_byte_count(dtype_);
+    for (auto dim : shape_) element_byte_size *= dim;
     for (size_t i = 0; i < num_tensors; i++) {
       iree_hal_buffer_view_t* tensor = GetItem(i).get();
+
+      auto block_begin = result_mapping.contents.data + i * element_byte_size;
+      auto block_size = element_byte_size;
+
+      if (!tensor) {
+        memset(block_begin, 0, block_size);
+        continue;
+      }
+
       iree_hal_buffer_t* tensor_buffer = iree_hal_buffer_view_buffer(tensor);
       iree_hal_mapped_memory_t tensor_mapping;
       iree_device_size_t tensor_byte_size =
@@ -273,9 +291,7 @@
           iree_hal_buffer_map(tensor_buffer, IREE_HAL_MEMORY_ACCESS_READ, 0,
                               tensor_byte_size, &tensor_mapping));
 
-      memcpy(result_mapping.contents.data + offset,
-             tensor_mapping.contents.data, tensor_byte_size);
-      offset += tensor_byte_size;
+      memcpy(block_begin, tensor_mapping.contents.data, block_size);
 
       IREE_RETURN_IF_ERROR(
           iree_hal_buffer_unmap(tensor_buffer, &tensor_mapping));
@@ -285,6 +301,8 @@
   }
 
   std::vector<vm::ref<iree_hal_buffer_view_t>> list_;
+  std::vector<int32_t> shape_;
+  iree_hal_element_type_t dtype_;
 };
 }  // namespace
 
@@ -339,6 +357,35 @@
   return scalar;
 }
 
+static StatusOr<std::vector<int32_t>> ReadInt32VectorFromBufferView(
+    iree_hal_buffer_view_t* buffer_view) {
+  if (iree_hal_buffer_view_element_type(buffer_view) !=
+      IREE_HAL_ELEMENT_TYPE_SINT_32) {
+    return InvalidArgumentErrorBuilder(IREE_LOC) << "expected i32 buffer view";
+  }
+  if (iree_hal_buffer_view_shape_rank(buffer_view) != 1) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "expected rank-1 buffer view";
+  }
+
+  int32_t length;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_shape(
+      buffer_view, /*rank_capacity=*/1, &length, nullptr));
+
+  iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view);
+  iree_hal_mapped_memory_t mapped_memory;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ,
+                                           0, length * sizeof(int32_t),
+                                           &mapped_memory));
+
+  std::vector<int32_t> contents(
+      reinterpret_cast<int32_t*>(mapped_memory.contents.data),
+      reinterpret_cast<int32_t*>(mapped_memory.contents.data) + length);
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap(buffer, &mapped_memory));
+  return contents;
+}
+
 namespace {
 class TensorListModuleState final {
  public:
@@ -348,10 +395,11 @@
   // tensorlist.reserve(%element_shape, %num_elements) -> %list
   StatusOr<vm::ref<TensorList>> Reserve(
       vm::ref<iree_hal_buffer_view_t> element_shape,
-      vm::ref<iree_hal_buffer_view_t> num_elements_buf) {
-    // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
-    (void)element_shape;
-    TensorList* tensorlist = new TensorList;
+      vm::ref<iree_hal_buffer_view_t> num_elements_buf,
+      iree_hal_element_type_t element_type) {
+    IREE_ASSIGN_OR_RETURN(std::vector<int32_t> shape,
+                          ReadInt32VectorFromBufferView(element_shape.get()));
+    TensorList* tensorlist = new TensorList(shape, element_type);
     IREE_ASSIGN_OR_RETURN(int32_t num_elements, ReadInt32FromScalarBufferView(
                                                     num_elements_buf.get()));
     tensorlist->Resize(num_elements);
@@ -360,10 +408,8 @@
 
   // tensorlist.get_item(%list, %index, %element_shape) -> %item
   StatusOr<vm::ref<iree_hal_buffer_view_t>> GetItem(
-      vm::ref<TensorList> tensorlist, vm::ref<iree_hal_buffer_view_t> index_buf,
-      vm::ref<iree_hal_buffer_view_t> element_shape) {
-    // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
-    (void)element_shape;
+      vm::ref<TensorList> tensorlist,
+      vm::ref<iree_hal_buffer_view_t> index_buf) {
     IREE_ASSIGN_OR_RETURN(int32_t index,
                           ReadInt32FromScalarBufferView(index_buf.get()));
     return vm::retain_ref(tensorlist->GetItem(index).get());
@@ -373,35 +419,29 @@
   StatusOr<vm::ref<TensorList>> SetItem(
       vm::ref<TensorList> list, vm::ref<iree_hal_buffer_view_t> index_buf,
       vm::ref<iree_hal_buffer_view_t> item) {
-    TensorList* new_list = new TensorList;
     IREE_ASSIGN_OR_RETURN(int32_t index,
                           ReadInt32FromScalarBufferView(index_buf.get()));
-    new_list->CopyFrom(list);
+    TensorList* new_list = new TensorList(list);
     new_list->SetItem(index, vm::retain_ref(item));
     return new_list;
   }
 
   // tensorlist.from_tensor(%tensor, %element_shape) -> %list
   StatusOr<vm::ref<TensorList>> FromTensor(
-      vm::ref<iree_hal_buffer_view_t> tensor,
-      vm::ref<iree_hal_buffer_view_t> element_shape) {
-    // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
-    (void)element_shape;
+      vm::ref<iree_hal_buffer_view_t> tensor) {
     return TensorList::FromTensor(tensor);
   }
 
   // tensorlist.concat(%list) -> %list
-  StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat(vm::ref<TensorList> list) {
-    return list->Concat();
+  StatusOr<vm::ref<iree_hal_buffer_view_t>> Concat(
+      vm::ref<iree_hal_allocator_t> allocator, vm::ref<TensorList> list) {
+    return list->Concat(allocator);
   }
 
   // tensorlist.stack(%list, %element_shape, %num_elements) -> %list
   StatusOr<vm::ref<iree_hal_buffer_view_t>> Stack(
-      vm::ref<TensorList> list,
-      vm::ref<iree_hal_buffer_view_t> element_shape_buffer_view,
+      vm::ref<iree_hal_allocator_t> allocator, vm::ref<TensorList> list,
       vm::ref<iree_hal_buffer_view_t> num_elements_buffer_view) {
-    // TODO(silvasean): Emulate element shape and dtype tracking in TensorList.
-    (void)element_shape_buffer_view;
     IREE_ASSIGN_OR_RETURN(
         int32_t num_elements,
         ReadInt32FromScalarBufferView(num_elements_buffer_view.get()));
@@ -410,7 +450,7 @@
             << "num_elements arg to tensorlist.stack doesn't match the list "
                 "size";
     }
-    return list->Stack();
+    return list->Stack(allocator);
   }
 };
 }  // namespace
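
For reference, the copy-out loop above packs list element i at byte offset i * block_size, where block_size is the tracked element shape's volume times the element byte width, and never-set (null) slots are zero-filled rather than copied. A minimal standalone sketch of that bookkeeping, using plain C++ buffers in place of the IREE HAL mappings (the helper names here are illustrative only, not part of the module):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Byte size of one list element: element byte width times the product of the
// tracked element shape dimensions.
static size_t ElementByteSize(const std::vector<int32_t>& shape,
                              size_t element_byte_count) {
  size_t size = element_byte_count;
  for (int32_t dim : shape) size *= dim;
  return size;
}

// Packs `elements` back to back into `result`; a null entry stands for a list
// slot that was never set and is written as zeros instead of being copied.
static void PackElements(const std::vector<const uint8_t*>& elements,
                         const std::vector<int32_t>& shape,
                         size_t element_byte_count, uint8_t* result) {
  const size_t block_size = ElementByteSize(shape, element_byte_count);
  for (size_t i = 0; i < elements.size(); ++i) {
    uint8_t* block_begin = result + i * block_size;
    if (!elements[i]) {
      std::memset(block_begin, 0, block_size);
    } else {
      std::memcpy(block_begin, elements[i], block_size);
    }
  }
}
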
diff --git a/iree/modules/tensorlist/tensorlist_test.cc b/iree/modules/tensorlist/tensorlist_test.cc
index 3293d57..8eba8e9 100644
--- a/iree/modules/tensorlist/tensorlist_test.cc
+++ b/iree/modules/tensorlist/tensorlist_test.cc
@@ -96,6 +96,91 @@
     return function;
   }
 
+  void Invoke(absl::string_view function_name,
+              absl::Span<const float> input_values,
+              absl::Span<const int32_t> input_shape,
+              absl::Span<const float> expected_values,
+              absl::Span<const int32_t> expected_shape) {
+    vm::ref<iree_hal_buffer_view_t> input_buffer_view;
+    CreateBufferView(input_values, input_shape, device_, &input_buffer_view);
+
+    // Pass in the tensor as a HAL buffer view.
+    vm::ref<iree_vm_list_t> inputs;
+    IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
+                                       iree_allocator_system(), &inputs));
+    iree_vm_ref_t input_buffer_view_ref =
+        iree_hal_buffer_view_move_ref(input_buffer_view.get());
+    IREE_ASSERT_OK(
+        iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
+
+    // Prepare outputs list to accept the results from the invocation.
+    vm::ref<iree_vm_list_t> outputs;
+    IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
+                                       iree_allocator_system(), &outputs));
+
+    // Synchronously invoke the function.
+    IREE_ASSERT_OK(iree_vm_invoke(context_, LookupFunction(function_name),
+                                  /*policy=*/nullptr, inputs.get(),
+                                  outputs.get(), iree_allocator_system()));
+
+    auto* returned_buffer_view =
+        reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
+            outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
+
+    absl::InlinedVector<int32_t, 5> returned_shape(
+        iree_hal_buffer_view_shape_rank(returned_buffer_view));
+    iree_hal_buffer_view_shape(returned_buffer_view, returned_shape.size(),
+                               returned_shape.data(), nullptr);
+
+    EXPECT_EQ(returned_shape, expected_shape);
+
+    iree_hal_buffer_t* returned_buffer =
+        iree_hal_buffer_view_buffer(returned_buffer_view);
+    ASSERT_NE(returned_buffer, nullptr);
+
+    iree_hal_mapped_memory_t mapped_memory;
+    IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
+                                       IREE_HAL_MEMORY_ACCESS_READ, 0,
+                                       IREE_WHOLE_BUFFER, &mapped_memory));
+    for (int i = 0; i < expected_values.size(); i++) {
+      EXPECT_EQ(reinterpret_cast<float*>(mapped_memory.contents.data)[i],
+                expected_values[i]);
+    }
+
+    IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+  }
+
+  void CreateBufferView(absl::Span<const float> contents,
+                        absl::Span<const int32_t> shape,
+                        iree_hal_device_t* device,
+                        iree_hal_buffer_view_t** out_buffer_view) {
+    size_t num_elements = 1;
+    for (int32_t dim : shape) {
+      num_elements *= dim;
+    }
+    ASSERT_EQ(contents.size(), num_elements);
+    vm::ref<iree_hal_buffer_t> buffer;
+    iree_hal_allocator_t* allocator = iree_hal_device_allocator(device);
+    IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+        allocator,
+        static_cast<iree_hal_memory_type_t>(
+            IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+            IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
+        IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(float), &buffer));
+    iree_hal_mapped_memory_t mapped_memory;
+    IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(),
+                                       IREE_HAL_MEMORY_ACCESS_WRITE, 0,
+                                       IREE_WHOLE_BUFFER, &mapped_memory));
+    memcpy(mapped_memory.contents.data,
+           static_cast<const void*>(contents.data()),
+           mapped_memory.contents.data_length);
+    IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory));
+    IREE_ASSERT_OK(iree_hal_buffer_view_create(
+        buffer.get(), shape.data(), shape.size(),
+        IREE_HAL_ELEMENT_TYPE_FLOAT_32, iree_allocator_system(),
+        &*out_buffer_view));
+  }
+
   iree_hal_device_t* device_ = nullptr;
   iree_vm_instance_t* instance_ = nullptr;
   iree_vm_context_t* context_ = nullptr;
@@ -104,176 +189,53 @@
   iree_vm_module_t* hal_module_ = nullptr;
 };
 
-void CreateBufferView(absl::Span<float> contents,
-                      absl::Span<const int32_t> shape,
-                      iree_hal_device_t* device,
-                      iree_hal_buffer_view_t** out_buffer_view) {
-  size_t num_elements = 1;
-  for (int32_t dim : shape) {
-    num_elements *= dim;
-  }
-  ASSERT_EQ(contents.size(), num_elements);
-  vm::ref<iree_hal_buffer_t> buffer;
-  iree_hal_allocator_t* allocator = iree_hal_device_allocator(device);
-  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
-      allocator,
-      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
-                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
-      IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(float), &buffer));
-  iree_hal_mapped_memory_t mapped_memory;
-  IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), IREE_HAL_MEMORY_ACCESS_WRITE,
-                                     0, IREE_WHOLE_BUFFER, &mapped_memory));
-  memcpy(mapped_memory.contents.data, static_cast<void*>(contents.data()),
-         mapped_memory.contents.data_length);
-  IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory));
-  IREE_ASSERT_OK(iree_hal_buffer_view_create(
-      buffer.get(), shape.data(), shape.size(), IREE_HAL_ELEMENT_TYPE_FLOAT_32,
-      iree_allocator_system(), &*out_buffer_view));
-}
-
 TEST_F(TensorListModulesTest, IdentityThroughSetItemGetItem) {
   // Allocate the buffer we'll be passing through.
-  static float kBufferContents[1] = {42.0f};
-  absl::InlinedVector<int32_t, 4> shape;
-  vm::ref<iree_hal_buffer_view_t> input_buffer_view;
-  CreateBufferView(kBufferContents, shape, device_, &input_buffer_view);
+  std::vector<float> input = {42.0f};
+  std::vector<int32_t> input_shape = {};
+  Invoke("identity_through_set_item_get_item", input, input_shape, input,
+         input_shape);
+}
 
-  // Pass in the tensor as a HAL buffer view.
-  vm::ref<iree_vm_list_t> inputs;
-  IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
-                                     iree_allocator_system(), &inputs));
-  iree_vm_ref_t input_buffer_view_ref =
-      iree_hal_buffer_view_move_ref(input_buffer_view.get());
-  IREE_ASSERT_OK(
-      iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
-
-  // Prepare outputs list to accept the results from the invocation.
-  vm::ref<iree_vm_list_t> outputs;
-  IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
-                                     iree_allocator_system(), &outputs));
-
-  // Synchronously invoke the function.
-  IREE_ASSERT_OK(iree_vm_invoke(
-      context_, LookupFunction("identity_through_set_item_get_item"),
-      /*policy=*/nullptr, inputs.get(), outputs.get(),
-      iree_allocator_system()));
-
-  auto* returned_buffer_view =
-      reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
-          outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
-  ASSERT_NE(nullptr, returned_buffer_view);
-  iree_hal_buffer_t* returned_buffer =
-      iree_hal_buffer_view_buffer(returned_buffer_view);
-  ASSERT_NE(nullptr, returned_buffer);
-
-  iree_hal_mapped_memory_t mapped_memory;
-  IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
-                                     IREE_HAL_MEMORY_ACCESS_READ, 0,
-                                     IREE_WHOLE_BUFFER, &mapped_memory));
-  EXPECT_EQ(reinterpret_cast<float*>(mapped_memory.contents.data)[0],
-            kBufferContents[0]);
-  IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+TEST_F(TensorListModulesTest, IdentityThroughSetItemGetItem2D) {
+  // Allocate the buffer we'll be passing through.
+  std::vector<float> input = {42.0f};
+  std::vector<int32_t> input_shape = {1, 1};
+  Invoke("identity_through_set_item_get_item", input, input_shape, input,
+         input_shape);
 }
 
 TEST_F(TensorListModulesTest, IdentityThroughConcat) {
   // Allocate the buffer we'll be passing through.
-  static float kBufferContents[4] = {42.0f, 43.0f, 44.0f, 45.0f};
-  absl::InlinedVector<int32_t, 4> shape = {4, 1};
-  vm::ref<iree_hal_buffer_view_t> input_buffer_view;
-  CreateBufferView(kBufferContents, shape, device_, &input_buffer_view);
+  std::vector<float> input = {42.0f, 43.0f, 44.0f, 45.0f};
+  absl::InlinedVector<int32_t, 4> input_shape = {4, 1};
+  absl::InlinedVector<int32_t, 4> expected_shape = {4};
+  Invoke("identity_through_concat", input, input_shape, input, expected_shape);
+}
 
-  // Pass in the tensor as a HAL buffer view.
-  vm::ref<iree_vm_list_t> inputs;
-  IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
-                                     iree_allocator_system(), &inputs));
-  iree_vm_ref_t input_buffer_view_ref =
-      iree_hal_buffer_view_move_ref(input_buffer_view.get());
-  IREE_ASSERT_OK(
-      iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
-
-  // Prepare outputs list to accept the results from the invocation.
-  vm::ref<iree_vm_list_t> outputs;
-  IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
-                                     iree_allocator_system(), &outputs));
-
-  // Synchronously invoke the function.
-  IREE_ASSERT_OK(iree_vm_invoke(context_,
-                                LookupFunction("identity_through_concat"),
-                                /*policy=*/nullptr, inputs.get(), outputs.get(),
-                                iree_allocator_system()));
-
-  auto* returned_buffer_view =
-      reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
-          outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
-  ASSERT_NE(nullptr, returned_buffer_view);
-  iree_hal_buffer_t* returned_buffer =
-      iree_hal_buffer_view_buffer(returned_buffer_view);
-  ASSERT_NE(nullptr, returned_buffer);
-
-  // Dimemsionality is reduced by 1.
-  auto returned_rank = iree_hal_buffer_view_shape_rank(returned_buffer_view);
-  ASSERT_EQ(returned_rank, shape.size() - 1);
-
-  iree_hal_dim_t returned_shape[1];
-  iree_hal_buffer_view_shape(returned_buffer_view, 1, returned_shape,
-                             &returned_rank);
-  EXPECT_EQ(returned_shape[0], shape[0] * shape[1]);
-
-  iree_hal_mapped_memory_t mapped_memory;
-  IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
-                                     IREE_HAL_MEMORY_ACCESS_READ, 0,
-                                     IREE_WHOLE_BUFFER, &mapped_memory));
-  EXPECT_EQ(std::memcmp(mapped_memory.contents.data,
-                        static_cast<void*>(&kBufferContents[0]),
-                        mapped_memory.contents.data_length),
-            0);
-  IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+TEST_F(TensorListModulesTest, ConcatAppendsEmpty) {
+  // Allocate the buffer we'll be passing through.
+  std::vector<float> input = {42.0f};
+  absl::InlinedVector<int32_t, 4> input_shape = {1};
+  std::vector<float> expected = {42.0f, 0.0f};
+  absl::InlinedVector<int32_t, 4> expected_shape = {2};
+  Invoke("concat_appends_empty", input, input_shape, expected, expected_shape);
 }
 
 TEST_F(TensorListModulesTest, IdentityThroughStack) {
   // Allocate the buffer we'll be passing through.
-  static float kBufferContents[2] = {42.0f, 43.0f};
-  absl::InlinedVector<int32_t, 4> shape = {2, 1};
-  vm::ref<iree_hal_buffer_view_t> input_buffer_view;
-  CreateBufferView(kBufferContents, shape, device_, &input_buffer_view);
+  std::vector<float> input = {42.0f, 43.0f};
+  absl::InlinedVector<int32_t, 4> input_shape = {2, 1};
+  Invoke("identity_through_stack", input, input_shape, input, input_shape);
+}
 
-  // Pass in the tensor as a HAL buffer view.
-  vm::ref<iree_vm_list_t> inputs;
-  IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
-                                     iree_allocator_system(), &inputs));
-  iree_vm_ref_t input_buffer_view_ref =
-      iree_hal_buffer_view_move_ref(input_buffer_view.get());
-  IREE_ASSERT_OK(
-      iree_vm_list_push_ref_retain(inputs.get(), &input_buffer_view_ref));
-
-  // Prepare outputs list to accept the results from the invocation.
-  vm::ref<iree_vm_list_t> outputs;
-  IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
-                                     iree_allocator_system(), &outputs));
-
-  // Synchronously invoke the function.
-  IREE_ASSERT_OK(iree_vm_invoke(context_,
-                                LookupFunction("identity_through_stack"),
-                                /*policy=*/nullptr, inputs.get(), outputs.get(),
-                                iree_allocator_system()));
-
-  auto* returned_buffer_view =
-      reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
-          outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
-  ASSERT_NE(nullptr, returned_buffer_view);
-  iree_hal_buffer_t* returned_buffer =
-      iree_hal_buffer_view_buffer(returned_buffer_view);
-  ASSERT_NE(nullptr, returned_buffer);
-
-  iree_hal_mapped_memory_t mapped_memory;
-  IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer,
-                                     IREE_HAL_MEMORY_ACCESS_READ, 0,
-                                     IREE_WHOLE_BUFFER, &mapped_memory));
-  EXPECT_EQ(std::memcmp(mapped_memory.contents.data,
-                        static_cast<void*>(&kBufferContents[0]),
-                        mapped_memory.contents.data_length),
-            0);
-  IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory));
+TEST_F(TensorListModulesTest, StackAppendsEmpty) {
+  // Allocate the buffer we'll be passing through.
+  std::vector<float> input = {42.0f};
+  absl::InlinedVector<int32_t, 4> input_shape = {};
+  std::vector<float> expected = {42.0f, 0.0f};
+  absl::InlinedVector<int32_t, 4> expected_shape = {2};
+  Invoke("stack_appends_empty", input, input_shape, expected, expected_shape);
 }
 
 }  // namespace
diff --git a/iree/modules/tensorlist/tensorlist_test.mlir b/iree/modules/tensorlist/tensorlist_test.mlir
index dfe10a4..6425b1b 100644
--- a/iree/modules/tensorlist/tensorlist_test.mlir
+++ b/iree/modules/tensorlist/tensorlist_test.mlir
@@ -4,27 +4,62 @@
   %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor<i32>
   %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
   %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
-  %3 = "tensorlist.Reserve"(%1, %0) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
   %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
-  %5 = "tensorlist.GetItem"(%4, %2, %1) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+  %5 = "tensorlist.GetItem"(%4, %2) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
   return %5 : !hal.buffer_view
 }
 
+func @identity_through_set_item_get_item_2D(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
+  %dev = hal.ex.shared_device : !hal.device
+  %allocator = hal.device.allocator %dev : !hal.allocator
+  %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor<i32>
+  %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[1, 1]> : tensor<2xi32>
+  %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
+  %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %stacked = "tensorlist.Stack"(%allocator, %4, %0) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
+  return %stacked : !hal.buffer_view
+}
+
 func @identity_through_concat(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
   %dev = hal.ex.shared_device : !hal.device
   %allocator = hal.device.allocator %dev : !hal.allocator
   %element_shape = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
-  %list = "tensorlist.FromTensor"(%arg0, %element_shape) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
-  %concat = "tensorlist.Concat"(%list) : (!tensorlist.list) -> !hal.buffer_view
+  %list = "tensorlist.FromTensor"(%arg0) : (!hal.buffer_view) -> !tensorlist.list
+  %concat = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
+  return %concat : !hal.buffer_view
+}
+
+func @concat_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
+  %dev = hal.ex.shared_device : !hal.device
+  %allocator = hal.device.allocator %dev : !hal.allocator
+  %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor<i32>
+  %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[1]> : tensor<1xi32>
+  %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
+  %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %concat = "tensorlist.Concat"(%allocator, %4) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view
   return %concat : !hal.buffer_view
 }
 
 func @identity_through_stack(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
   %dev = hal.ex.shared_device : !hal.device
   %allocator = hal.device.allocator %dev : !hal.allocator
-  %element_shape = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
   %num_elements = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor<i32>
-  %list = "tensorlist.FromTensor"(%arg0, %element_shape) : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
-  %stacked = "tensorlist.Stack"(%list, %element_shape, %num_elements) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !hal.buffer_view
+  %list = "tensorlist.FromTensor"(%arg0) : (!hal.buffer_view) -> !tensorlist.list
+  %stacked = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
+  return %stacked : !hal.buffer_view
+}
+
+func @stack_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} {
+  %dev = hal.ex.shared_device : !hal.device
+  %allocator = hal.device.allocator %dev : !hal.allocator
+  %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor<i32>
+  %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32>
+  %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor<i32>
+  %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list
+  %stacked = "tensorlist.Stack"(%allocator, %4, %0) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view
   return %stacked : !hal.buffer_view
 }
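
A note on the element_type = 50331680 attribute used throughout these tests: assuming the IREE HAL element-type packing of (numerical_type << 24) | bit_count, with 0x03 denoting IEEE float, the value decodes to a 32-bit float element type (IREE_HAL_ELEMENT_TYPE_FLOAT_32). A quick standalone check of that arithmetic:

#include <cstdint>
#include <cstdio>

// Assumed packing: high byte holds the numerical type, low bits the bit count.
constexpr uint32_t PackElementType(uint32_t numerical_type, uint32_t bit_count) {
  return (numerical_type << 24) | bit_count;
}

int main() {
  constexpr uint32_t kFloatIeee = 0x03;  // assumed float numerical-type tag
  static_assert(PackElementType(kFloatIeee, 32) == 50331680u,
                "f32 element type should encode to 50331680");
  std::printf("%u\n", PackElementType(kFloatIeee, 32));  // prints 50331680
  return 0;
}
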
diff --git a/iree/test/e2e/models/fullyconnected.mlir b/iree/test/e2e/models/fullyconnected.mlir
index a7f4cd4..81a5dee 100644
--- a/iree/test/e2e/models/fullyconnected.mlir
+++ b/iree/test/e2e/models/fullyconnected.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-run-mlir -export-all %s -iree-hal-target-backends=vmla -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s
-// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s)
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=llvm-ir -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" -iree-enable-matmul-fusion | IreeFileCheck %s)
 // RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all %s -iree-hal-target-backends=vulkan-spirv -function-input="1x5xf32=1,-2,-3,4,-5" -function-input="1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1" | IreeFileCheck %s)
 
 // CHECK-LABEL: EXEC @main
diff --git a/iree/test/e2e/regression/matmul_add.mlir b/iree/test/e2e/regression/matmul_add.mlir
new file mode 100644
index 0000000..1584692
--- /dev/null
+++ b/iree/test/e2e/regression/matmul_add.mlir
@@ -0,0 +1,26 @@
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -iree-enable-matmul-fusion %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -iree-enable-matmul-fusion %s | IreeFileCheck %s)
+
+func @matmul_add() -> tensor<2x4xf32> {
+  %0 = iree.unfoldable_constant dense<[
+    [1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0],
+    [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]]> : tensor<2x8xf32>
+  %1 = iree.unfoldable_constant dense<[
+    [ 1.0,  2.0,  3.0,  4.0],
+    [ 5.0,  6.0,  7.0,  8.0],
+    [ 9.0, 10.0, 11.0, 12.0],
+    [13.0, 14.0, 15.0, 16.0],
+    [17.0, 18.0, 19.0, 20.0],
+    [21.0, 22.0, 23.0, 24.0],
+    [25.0, 26.0, 27.0, 28.0],
+    [29.0, 30.0, 31.0, 32.0]]> : tensor<8x4xf32>
+  %2 = iree.unfoldable_constant dense<[
+    [1.0, 2.0, 3.0, 4.0],
+    [5.0, 6.0, 7.0, 8.0]]> : tensor<2x4xf32>
+  %3 = "mhlo.dot"(%0, %1) : (tensor<2x8xf32>, tensor<8x4xf32>) -> tensor<2x4xf32>
+  %4 = mhlo.add %3, %2 : tensor<2x4xf32>
+  return %4 : tensor<2x4xf32>
+}
+
+// CHECK: EXEC @matmul_add
+// CHECK: 2x4xf32=[709 746 783 820][1673 1774 1875 1976]
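
The expected values follow directly from the constants above: each output element is the dot product of a row of the 2x8 operand with a column of the 8x4 operand plus the matching bias entry, e.g. 1*1 + 2*5 + 3*9 + 4*13 + 5*17 + 6*21 + 7*25 + 8*29 = 708, plus bias 1, gives 709. A short standalone check of the full 2x4 result:

#include <cstdio>

int main() {
  const float a[2][8] = {{1, 2, 3, 4, 5, 6, 7, 8},
                         {9, 10, 11, 12, 13, 14, 15, 16}};
  const float b[8][4] = {{1, 2, 3, 4},     {5, 6, 7, 8},     {9, 10, 11, 12},
                         {13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24},
                         {25, 26, 27, 28}, {29, 30, 31, 32}};
  const float bias[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 4; ++j) {
      float acc = bias[i][j];
      for (int k = 0; k < 8; ++k) acc += a[i][k] * b[k][j];
      std::printf("%g ", acc);  // expected: 709 746 783 820 / 1673 1774 1875 1976
    }
    std::printf("\n");
  }
  return 0;
}
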
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index adde1a1..b340804 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -57,3 +57,16 @@
     driver = "vulkan",
     target_backend = "vulkan-spirv",
 )
+
+iree_check_single_backend_test_suite(
+    name = "check_vulkan-spirv_vulkan_vectorized_conv",
+    srcs = [
+        "vectorized_conv.mlir",
+    ],
+    compiler_flags = [
+        "-iree-spirv-enable-vectorization",
+        "-iree-vulkan-target-triple=valhall-g77-unknown-android10",
+    ],
+    driver = "vulkan",
+    target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index a6126da..0cc1b30 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -56,3 +56,17 @@
   COMPILER_FLAGS
     "-iree-spirv-enable-memref-vectorization"
 )
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_vulkan-spirv_vulkan_vectorized_conv
+  SRCS
+    "vectorized_conv.mlir"
+  TARGET_BACKEND
+    vulkan-spirv
+  DRIVER
+    vulkan
+  COMPILER_FLAGS
+    "-iree-spirv-enable-vectorization"
+    "-iree-vulkan-target-triple=valhall-g77-unknown-android10"
+)
diff --git a/iree/test/e2e/vulkan_specific/vectorized_conv.mlir b/iree/test/e2e/vulkan_specific/vectorized_conv.mlir
new file mode 100644
index 0000000..39ad667
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/vectorized_conv.mlir
@@ -0,0 +1,70 @@
+func @conv() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<
+     [[[[6.0, 7.5, 0.0, 1.5],
+        [1.5, 3.5, 4.5, 2.0],
+        [3.0, 6.0, 0.5, 3.0]],
+       [[3.5, 7.0, 2.5, 6.5],
+        [4.0, 4.5, 8.0, 2.5],
+        [7.5, 7.5, 0.0, 1.5]],
+       [[7.0, 3.5, 0.0, 0.5],
+        [4.5, 0.0, 5.0, 1.5],
+        [5.5, 1.0, 0.0, 0.0]]]]>
+    : tensor<1x3x3x4xf32>
+  %filter = iree.unfoldable_constant dense<
+      [[[[2.0, 2.5, 2.5, 3.0, 4.0, 2.0, 0.5, 2.0, 4.5, 5.0, 5.0, 4.0, 0.5, 0.5, 3.5, 4.5,
+          4.5, 1.5, 3.0, 3.5, 1.0, 0.0, 1.5, 2.5, 4.5, 5.0, 2.0, 2.0, 3.0, 2.0, 2.0, 1.5],
+         [2.0, 2.0, 4.0, 2.0, 1.5, 5.0, 3.5, 2.5, 2.5, 0.0, 0.5, 2.5, 4.5, 1.5, 0.0, 2.5,
+          0.0, 0.5, 1.0, 2.0, 1.0, 0.0, 1.5, 1.0, 5.0, 0.0, 3.5, 2.5, 4.5, 0.0, 5.0, 1.0],
+         [5.0, 3.5, 1.0, 4.5, 1.0, 1.5, 1.5, 1.0, 1.5, 2.0, 0.5, 1.0, 4.5, 5.0, 0.5, 2.0,
+          5.0, 3.0, 4.0, 1.0, 1.5, 0.0, 0.0, 3.0, 0.0, 3.0, 1.5, 5.0, 1.5, 4.0, 4.0, 4.0],
+         [1.0, 1.5, 1.0, 0.0, 4.0, 4.0, 1.5, 4.0, 5.0, 1.0, 4.0, 2.0, 1.5, 0.0, 2.0, 1.5,
+          3.0, 4.5, 4.0, 0.0, 4.0, 2.5, 4.5, 0.0, 4.5, 3.0, 2.5, 1.5, 0.5, 4.0, 0.0, 2.0]],
+        [[4.5, 3.0, 2.5, 3.5, 4.0, 4.0, 4.5, 1.0, 4.0, 3.0, 3.0, 4.5, 0.5, 3.0, 4.0, 4.0,
+          1.5, 1.0, 1.5, 5.0, 3.0, 1.5, 3.0, 2.5, 3.5, 0.0, 4.0, 2.0, 5.0, 3.0, 2.5, 4.0],
+         [1.0, 1.5, 4.5, 3.5, 2.5, 1.5, 2.0, 2.5, 1.5, 1.5, 3.5, 4.5, 4.5, 4.5, 3.5, 1.5,
+          5.0, 1.0, 1.5, 4.5, 5.0, 3.5, 3.5, 2.5, 0.5, 1.0, 1.0, 4.0, 0.5, 2.5, 4.0, 2.0],
+         [0.0, 1.0, 2.5, 2.5, 0.0, 4.0, 0.5, 0.5, 0.0, 1.5, 4.0, 4.0, 2.0, 2.0, 0.0, 4.5,
+          1.5, 3.5, 1.5, 1.0, 0.5, 0.5, 1.0, 0.5, 2.0, 1.0, 2.5, 2.5, 2.5, 1.0, 2.5, 3.5],
+         [3.5, 3.0, 0.5, 3.0, 3.5, 1.0, 1.5, 0.5, 4.5, 2.5, 4.5, 4.5, 1.0, 0.0, 4.5, 0.5,
+          4.5, 5.0, 0.0, 3.0, 0.0, 5.0, 2.0, 4.0, 2.0, 1.5, 1.5, 4.0, 4.0, 3.5, 0.0, 1.5]]],
+       [[[4.0, 3.5, 3.5, 5.0, 0.5, 4.0, 2.0, 3.5, 0.0, 2.0, 4.5, 0.0, 5.0, 3.0, 2.0, 1.0,
+          2.0, 3.0, 1.5, 5.0, 1.5, 3.5, 4.0, 2.5, 0.0, 4.0, 2.5, 2.0, 3.5, 5.0, 5.0, 2.0],
+         [0.5, 1.5, 1.5, 4.5, 1.0, 2.5, 1.0, 1.5, 2.5, 5.0, 3.5, 1.0, 3.5, 0.5, 3.0, 5.0,
+          2.5, 0.0, 0.0, 5.0, 1.5, 5.0, 0.5, 5.0, 4.5, 4.5, 3.0, 3.0, 3.5, 4.0, 4.0, 3.5],
+         [0.0, 4.0, 3.0, 4.0, 4.5, 4.0, 1.5, 3.0, 0.5, 3.5, 2.0, 4.5, 1.0, 0.0, 4.0, 1.0,
+          3.5, 4.0, 2.0, 2.0, 0.5, 3.5, 3.0, 4.5, 2.0, 0.5, 2.5, 4.5, 3.5, 0.5, 1.5, 2.5],
+         [3.5, 1.5, 3.0, 3.0, 3.5, 4.5, 0.5, 4.5, 3.0, 0.0, 1.5, 4.0, 2.0, 0.5, 2.0, 2.5,
+          0.0, 1.5, 5.0, 0.5, 2.0, 2.0, 2.0, 0.0, 0.0, 5.0, 4.0, 2.0, 3.0, 4.5, 1.5, 1.5]],
+        [[1.0, 0.5, 5.0, 1.0, 0.5, 1.5, 2.0, 5.0, 0.5, 0.5, 0.0, 3.5, 4.0, 5.0, 2.0, 1.5,
+          2.5, 3.0, 1.5, 1.0, 4.5, 4.0, 0.5, 2.0, 5.0, 0.0, 4.0, 1.5, 4.5, 2.5, 2.5, 0.5],
+         [3.5, 4.0, 3.0, 2.0, 3.5, 1.5, 2.5, 1.5, 3.0, 2.0, 3.5, 1.5, 0.0, 2.5, 4.5, 1.5,
+          3.5, 2.5, 2.5, 4.0, 0.0, 4.0, 1.5, 3.0, 4.5, 5.0, 1.5, 1.0, 3.5, 0.0, 1.5, 5.0],
+         [0.0, 1.5, 3.0, 0.5, 4.5, 1.0, 4.5, 2.0, 4.5, 0.5, 1.5, 1.0, 2.0, 4.5, 3.5, 2.0,
+          4.5, 2.0, 0.5, 1.0, 3.5, 1.0, 1.5, 4.5, 5.0, 3.5, 5.0, 3.0, 3.0, 1.0, 5.0, 1.5],
+         [3.0, 0.0, 5.0, 4.0, 0.0, 5.0, 3.5, 3.0, 2.5, 4.5, 3.0, 2.5, 1.0, 3.5, 0.5, 4.5,
+          1.0, 1.0, 2.5, 3.0, 2.0, 1.0, 1.0, 0.5, 0.0, 4.5, 0.0, 1.0, 4.0, 1.5, 5.0, 0.0]]]]>
+    : tensor<2x2x4x32xf32>
+
+    %0 = "mhlo.convolution"(%input, %filter) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<0> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x3x3x4xf32>, tensor<2x2x4x32xf32>) -> tensor<1x2x2x32xf32>
+
+   check.expect_almost_eq_const(%0, dense<
+     [[[[113.25, 127.0, 198.0, 173.25, 159.5, 190.75, 135.5, 160.0,
+         169.5, 130.0, 173.75, 174.5, 158.5, 136.75, 159.75, 177.75,
+         164.5, 122.25, 116.0, 168.0, 124.75, 144.0, 113.5, 159.0,
+         208.0, 186.5, 190.5, 158.5, 213.75, 140.5, 206.75, 135.25],
+        [129.75, 147.25, 181.25, 181.75, 142.5, 161.75, 117.75, 153.25,
+         119.5, 128.75, 149.25, 171.0, 152.5, 142.5, 166.0, 122.25,
+         177.75, 142.75, 116.5, 170.0, 117.5, 176.75, 116.75, 162.25,
+         161.25, 135.0, 145.5, 163.25, 190.5, 138.25, 162.5, 146.75]],
+       [[111.75, 115.75, 173.5, 158.25, 122.5, 187.25, 129.0, 142.5,
+         142.25, 109.0, 175.75, 158.5, 172.75, 146.25, 122.25, 157.25,
+         157.5, 141.25, 104.25, 151.25, 136.25, 122.0, 127.75, 125.75,
+         180.5, 131.25, 168.75, 151.5, 180.75, 152.75, 193.5, 128.75],
+        [138.25, 133.75, 157.5, 168.5, 131.0, 149.75, 115.25, 130.75,
+         114.5, 107.25, 127.75, 163.75, 153.5, 149.25, 133.5, 114.0,
+         164.75, 120.75, 116.0, 149.5, 127.5, 113.5, 116.0, 129.75,
+         126.75, 94.25, 135.0, 157.75, 158.75, 142.0, 158.75, 126.25]]]]>
+        : tensor<1x2x2x32xf32>) : tensor<1x2x2x32xf32>
+   return
+}
+
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index de1c9b9..0f63c8c 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -14,13 +14,9 @@
 
 # bazel_to_cmake: DO NOT EDIT (Special logic is used throughout this file)
 
-# This need to come first so targets in the android/ directory can depend on it.
-# TODO(#3317): this seems to indicate an issue somewhere regarding dynamic
-# library dependency management.
-add_subdirectory(utils)
-
 add_subdirectory(android)
 add_subdirectory(test)
+add_subdirectory(utils)
 
 # Enable compiler targets based on options.
 set(IREE_COMPILER_TARGETS "")
@@ -53,6 +49,9 @@
     MLIREmitC
     MLIRTargetCpp
   )
+  set(IREE_EMITC_INCLUDES
+    "${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include"
+  )
 endif()
 
 iree_cc_binary(
@@ -208,6 +207,8 @@
       iree::compiler::Conversion::init_conversions
       iree::compiler::Conversion::HLOToLinalg
       iree::compiler::Dialect::HAL::Conversion::Passes
+    INCLUDES
+      "${IREE_EMITC_INCLUDES}"
     PUBLIC
   )
 
@@ -295,6 +296,8 @@
       iree::compiler::Dialect::VM::Target::init_targets
       iree::compiler::Translation::IREEVM
       ${IREE_TRANSLATE_CONDITIONAL_DEPS}
+    INCLUDES
+      "${IREE_EMITC_INCLUDES}"
     PUBLIC
   )
 
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index bb5241f..2a9ef1b 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -14,27 +14,12 @@
 
 #include <string.h>
 
+#include "iree/base/math.h"
 #include "iree/base/tracing.h"
 #include "iree/vm/api.h"
 #include "iree/vm/bytecode_dispatch_util.h"
 
 //===----------------------------------------------------------------------===//
-// Math utilities, kept here to limit dependencies
-//===----------------------------------------------------------------------===//
-
-// Rounds up the value to the nearest power of 2 (if not already a power of 2).
-static inline uint32_t iree_math_round_up_to_pow2_u32(uint32_t n) {
-  n--;
-  n |= n >> 1;
-  n |= n >> 2;
-  n |= n >> 4;
-  n |= n >> 8;
-  n |= n >> 16;
-  n++;
-  return n;
-}
-
-//===----------------------------------------------------------------------===//
 // Register remapping utilities
 //===----------------------------------------------------------------------===//
 
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 2440ab7..21d9259 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -41,8 +41,6 @@
         '//integrations/tensorflow/e2e:e2e_tests',
     'mobile_bert_squad_tests':
         '//integrations/tensorflow/e2e:mobile_bert_squad_tests',
-    'keras_tests':
-        '//integrations/tensorflow/e2e/keras:keras_tests',
     'layers_tests':
         '//integrations/tensorflow/e2e/keras/layers:layers_tests',
     'layers_dynamic_batch_tests':
@@ -54,7 +52,7 @@
     'keyword_spotting_internal_streaming_tests':
         '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
     'imagenet_non_hermetic_tests':
-        '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+        '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
     'slim_vision_tests':
         '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
 }
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 3b5509c..fba79b2 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -60,7 +60,7 @@
         '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests',
     ],
     'vision_coverage': [
-        '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests',
+        '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests',
         '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
     ],
 }
@@ -116,7 +116,7 @@
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         f'End to end tests of {KWS_LINK} models in internal streaming mode',
     # vision_coverage
-    '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+    '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
         'End to end tests of tf.keras.applications vision models on Imagenet',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
         'End to end tests of TensorFlow slim vision models',
@@ -162,7 +162,7 @@
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         'model',
     # vision_coverage
-    '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+    '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
         'model',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
         'model',
@@ -193,7 +193,7 @@
     '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
         'keyword_spotting_streaming_test',
     # vision_coverage
-    '//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
+    '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests':
         'vision_model_test',
     '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
         'slim_vision_model_test',
diff --git a/third_party/googletest b/third_party/googletest
index f2fb48c..b1fbd33 160000
--- a/third_party/googletest
+++ b/third_party/googletest
@@ -1 +1 @@
-Subproject commit f2fb48c3b3d79a75a88a99fba6576b25d42ec528
+Subproject commit b1fbd33c06cdb0024c67733c6fdec2009d17b384
diff --git a/third_party/tracy b/third_party/tracy
index d7059ec..9c3dac3 160000
--- a/third_party/tracy
+++ b/third_party/tracy
@@ -1 +1 @@
-Subproject commit d7059eca6351546d1f51e248fc75e49dfeee709e
+Subproject commit 9c3dac3ed2bd647b8d63f197fed058fee97a7e1e