Synchronize submodules with LLVM at llvm/llvm-project@6a7a2ee8161d

Updates LLVM dependencies to match
[6a7a2ee8161d](https://github.com/llvm/llvm-project/commit/6a7a2ee8161d):
- TensorFlow to
  [4f8db85aa9ab](https://github.com/tensorflow/tensorflow/commit/4f8db85aa9ab)
- MLIR-HLO to
  [c65dc7a455dc](https://github.com/tensorflow/mlir-hlo/commit/c65dc7a455dc)

`./scripts/git/update_to_llvm_syncpoint.py`

PiperOrigin-RevId: 387433820
diff --git a/.gitignore b/.gitignore
index 4f52c93..6fe9bf2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,3 +58,6 @@
 # Generated documentation files
 mkdocs/site/
 docs/website/site/
+
+# Temporary files
+iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.ll
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c466088..e4df9d5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -322,6 +322,23 @@
 endif()
 
 #-------------------------------------------------------------------------------
+# Check if git submodules have been initialized.
+# This will only run if python3 is available.
+#-------------------------------------------------------------------------------
+
+find_package(Python3 COMPONENTS Interpreter QUIET)
+if(Python3_FOUND)
+  execute_process(
+    COMMAND ${Python3_EXECUTABLE} scripts/git/check_submodule_init.py
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    RESULT_VARIABLE ret
+  )
+  if(NOT ret EQUAL "0")
+    message(FATAL_ERROR "check_submodule_init.py failed, see the logs above")
+  endif()
+endif()
+
+#-------------------------------------------------------------------------------
 # MLIR/LLVM Dependency
 # We treat the LLVM dependency specially because we support several different
 # ways to use it:
diff --git a/bindings/python/iree/compiler/xla.py b/bindings/python/iree/compiler/xla.py
index eb04231..a0b0578 100644
--- a/bindings/python/iree/compiler/xla.py
+++ b/bindings/python/iree/compiler/xla.py
@@ -74,6 +74,7 @@
                import_format: Union[ImportFormat,
                                     str] = ImportFormat.BINARY_PROTO,
                import_extra_args: Sequence[str] = (),
+               save_temp_mhlo_input: Optional[str] = None,
                save_temp_iree_input: Optional[str] = None,
                **kwargs):
     """Initialize options from keywords.
@@ -87,6 +88,7 @@
     self.import_only = import_only
     self.import_format = ImportFormat.parse(import_format)
     self.import_extra_args = import_extra_args
+    self.save_temp_mhlo_input = save_temp_mhlo_input
     self.save_temp_iree_input = save_temp_iree_input
 
 
@@ -121,6 +123,10 @@
     cl.append("--mlir-print-op-generic")
 
   # Save temps flags.
+  save_mhlo_input = tfs.alloc_optional("tf-mhlo.mlir",
+                                       export_as=options.save_temp_mhlo_input)
+  if save_mhlo_input:
+    cl.append(f"--save-temp-mhlo-input={save_mhlo_input}")
   iree_input = tfs.alloc_optional("xla-iree-input.mlir",
                                   export_as=options.save_temp_iree_input)
   if iree_input:
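
A usage sketch for the new `save_temp_mhlo_input` option, assuming
`compile_file` in `iree.compiler.xla` accepts these keywords like the other
frontends (the paths and target backend here are hypothetical):

```python
from iree.compiler import xla as iree_xla

binary = iree_xla.compile_file(
    "/tmp/xla_computation.pb",  # hypothetical serialized HLO input
    target_backends=["vmvx"],
    save_temp_mhlo_input="/tmp/tf-mhlo.mlir",
    save_temp_iree_input="/tmp/xla-iree-input.mlir")
```
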
diff --git a/build_tools/buildkite/cmake/build_configurations.yml b/build_tools/buildkite/cmake/build_configurations.yml
index fb79ee9..426e5fb 100644
--- a/build_tools/buildkite/cmake/build_configurations.yml
+++ b/build_tools/buildkite/cmake/build_configurations.yml
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 steps:
-  - label: "Build with tracing enabled"
+  - label: ":zap: Build with tracing enabled"
     commands:
       - "./scripts/git/submodule_versions.py init"
       - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 ./build_tools/cmake/build_tracing.sh"
@@ -14,7 +14,7 @@
     agents:
       - "queue=build"
 
-  - label: "Build the runtime only"
+  - label: ":hammer_and_wrench: Build the runtime only"
     commands:
       - "./scripts/git/submodule_versions.py init"
       - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 ./build_tools/cmake/build_runtime.sh"
@@ -22,3 +22,27 @@
       IREE_DOCKER_WORKDIR: "/usr/src/github/iree"
     agents:
       - "queue=build"
+
+  - label: ":linux: Build host install"
+    key: "build-host-install"
+    commands:
+      - "./scripts/git/submodule_versions.py init"
+      - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 ./build_tools/cmake/build_host_install.sh"
+      - "tar -czvf build-artifacts.tgz build-host/install"
+    artifact_paths: "build-artifacts.tgz"
+    env:
+      IREE_DOCKER_WORKDIR: "/usr/src/github/iree"
+    agents:
+      - "queue=build"
+
+  - label: ":webassembly: Build WebAssembly runtime with Emscripten"
+    depends_on: "build-host-install"
+    commands:
+      - "buildkite-agent artifact download --step build-host-install build-artifacts.tgz ./"
+      - "tar xzf build-artifacts.tgz"
+      - "./scripts/git/submodule_versions.py init"
+      - "docker run --user=$(id -u):$(id -g) --volume=\\$PWD:\\$IREE_DOCKER_WORKDIR --workdir=\\$IREE_DOCKER_WORKDIR --rm gcr.io/iree-oss/cmake-emscripten@sha256:8acad361d23cb586187c2ea29df3a1ab301b5283c3648beb328681d69ecd0ab0 ./build_tools/cmake/build_runtime_emscripten.sh"
+    env:
+      IREE_DOCKER_WORKDIR: "/usr/src/github/iree"
+    agents:
+      - "queue=build"
diff --git a/build_tools/cmake/build_host_install.sh b/build_tools/cmake/build_host_install.sh
new file mode 100755
index 0000000..b5f8130
--- /dev/null
+++ b/build_tools/cmake/build_host_install.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Builds host binaries (compiler tools and other utilities) and installs them
+# in build-host/install. Designed for CI, but can be run manually. This uses
+# previously cached build results and does not clear build directories.
+
+set -x
+set -e
+
+CMAKE_BIN=${CMAKE_BIN:-$(which cmake)}
+"${CMAKE_BIN?}" --version
+ninja --version
+
+ROOT_DIR=$(git rev-parse --show-toplevel)
+cd ${ROOT_DIR?}
+
+if [ -d "build-host" ]
+then
+  echo "build-host directory already exists. Will use cached results there."
+else
+  echo "build-host directory does not already exist. Creating a new one."
+  mkdir build-host
+fi
+cd build-host
+
+# Configure, build, install.
+"${CMAKE_BIN?}" -G Ninja .. \
+  -DCMAKE_INSTALL_PREFIX=./install \
+  -DIREE_BUILD_TESTS=OFF \
+  -DIREE_BUILD_SAMPLES=OFF
+"${CMAKE_BIN?}" --build . --target install
diff --git a/build_tools/cmake/build_runtime_emscripten.sh b/build_tools/cmake/build_runtime_emscripten.sh
new file mode 100755
index 0000000..ca2abc3
--- /dev/null
+++ b/build_tools/cmake/build_runtime_emscripten.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Cross-compile IREE's runtime through Emscripten to WebAssembly with CMake.
+# Designed for CI, but can be run manually. This uses previously cached build
+# results and does not clear build directories.
+#
+# Host binaries (e.g. compiler tools) should already be built at
+# ./build-host/install. Emscripten binaries (e.g. .wasm and .js files) will be
+# built in ./build-emscripten/.
+
+set -x
+set -e
+
+if ! command -v emcmake &> /dev/null
+then
+    echo "'emcmake' not found, setup environment according to https://emscripten.org/docs/getting_started/downloads.html"
+    exit
+fi
+
+CMAKE_BIN=${CMAKE_BIN:-$(which cmake)}
+"${CMAKE_BIN?}" --version
+ninja --version
+
+ROOT_DIR=$(git rev-parse --show-toplevel)
+cd ${ROOT_DIR?}
+
+if [ -d "build-emscripten" ]
+then
+  echo "build-emscripten directory already exists. Will use cached results there."
+else
+  echo "build-emscripten directory does not already exist. Creating a new one."
+  mkdir build-emscripten
+fi
+cd build-emscripten
+
+# Configure using Emscripten's CMake wrapper, then build.
+emcmake "${CMAKE_BIN?}" -G Ninja .. \
+  -DIREE_HOST_BINARY_ROOT=$PWD/../build-host/install \
+  -DIREE_HAL_DRIVERS_TO_BUILD=VMVX\;DyLib \
+  -DIREE_BUILD_COMPILER=OFF \
+  -DIREE_ENABLE_MLIR=OFF \
+  -DIREE_BUILD_TESTS=OFF \
+  -DIREE_BUILD_SAMPLES=ON
+
+# TODO(scotttodd): expand this list of targets
+"${CMAKE_BIN?}" --build . --target iree_samples_simple_embedding_simple_embedding_vmvx_sync
diff --git a/build_tools/docker/cmake-emscripten/Dockerfile b/build_tools/docker/cmake-emscripten/Dockerfile
new file mode 100644
index 0000000..4f3f600
--- /dev/null
+++ b/build_tools/docker/cmake-emscripten/Dockerfile
@@ -0,0 +1,44 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# An image for building IREE through Emscripten using CMake.
+
+FROM gcr.io/iree-oss/cmake@sha256:9d9953acf5ca0cf1ff3e8de32f10f24dfab1c4e8ec5d1fc047f556024ee4bed6 AS final
+
+# See also
+#   * https://github.com/emscripten-core/emsdk/blob/main/docker/Dockerfile
+#   * https://hub.docker.com/r/emscripten/emsdk
+
+ARG EMSDK_COMMIT=5c0e31a03a72136ccaa761c408baf5a640f942ec
+ARG SDK_VERSION=2.0.25
+
+# Follow https://emscripten.org/docs/getting_started/downloads.html.
+RUN git clone https://github.com/emscripten-core/emsdk
+RUN cd emsdk && git checkout "${EMSDK_COMMIT?}" && \
+    ./emsdk install ${SDK_VERSION?} && \
+    ./emsdk activate ${SDK_VERSION?}
+
+# Set some environment variables for Emscripten to use.
+ENV EMSDK=/emsdk
+ENV EM_DATA=${EMSDK}/.data
+ENV EM_CONFIG=${EMSDK}/.emscripten
+ENV EM_CACHE=${EM_DATA}/cache
+ENV EM_PORTS=${EM_DATA}/ports
+# Emscripten writes into its cache location (outside of the CMake build
+# directory).
+# We can either
+#   (A) Grant broad write permissions to the cache directory to be able to run
+#       our scripts under different users.
+#   (B) Mount a user home directory when using the image.
+# Since (A) requires less configuration, we'll do that. If multiple tools would
+# want a user directory (like Bazel), we should switch to (B).
+# See https://github.com/emscripten-core/emsdk/issues/535
+RUN mkdir -p ${EM_CACHE} && chmod -R 777 ${EM_CACHE}
+
+# Normally we'd run `source emsdk_env.sh`, but that doesn't integrate with
+# Docker's environment properties model. Instead, we directly extend the path
+# to include the directories suggested by `emsdk activate`.
+ENV PATH="$PWD/emsdk:$PWD/emsdk/node/14.15.5_64bit/bin:$PWD/emsdk/upstream/emscripten:$PATH"
diff --git a/build_tools/docker/manage_images.py b/build_tools/docker/manage_images.py
index 9e9bfd8..ce61e2f 100755
--- a/build_tools/docker/manage_images.py
+++ b/build_tools/docker/manage_images.py
@@ -46,6 +46,7 @@
     'bazel': ['base', 'util'],
     'cmake': ['base', 'util'],
     'cmake-android': ['cmake-python', 'util'],
+    'cmake-emscripten': ['cmake'],
     'cmake-python': ['cmake'],
     'cmake-python-vulkan': ['cmake-python', 'vulkan'],
     'cmake-python-swiftshader': ['cmake-python-vulkan', 'swiftshader'],
diff --git a/build_tools/docker/prod_digests.txt b/build_tools/docker/prod_digests.txt
index f5e1f59..dd0a5b7 100644
--- a/build_tools/docker/prod_digests.txt
+++ b/build_tools/docker/prod_digests.txt
@@ -17,3 +17,4 @@
 gcr.io/iree-oss/cmake-riscv@sha256:95489593bc9b0cd325ce9c1a32b47389c01b174a5b8190a16d937d2e8828d384
 gcr.io/iree-oss/cmake-bazel-frontends-android@sha256:1392e3a27cddbdc597817168fb61e125bbdcbfd9076eff9d70bd8012b0a0c5ba
 gcr.io/iree-oss/samples@sha256:be5465585706b620d6c722caa6237eafdfaa8dd11ce20db0981b979f2d3387b3
+gcr.io/iree-oss/cmake-emscripten@sha256:8acad361d23cb586187c2ea29df3a1ab301b5283c3648beb328681d69ecd0ab0
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh
index bd48c3f..1f5b793 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader-asan/build.sh
@@ -88,6 +88,7 @@
 fi
 if [[ "${IREE_CUDA_DISABLE?}" == 1 ]]; then
   label_exclude_args+=("^driver=cuda$")
+  label_exclude_args+=("^uses_cuda_runtime$")
 fi
 if [[ "${IREE_VULKAN_F16_DISABLE?}" == 1 ]]; then
   label_exclude_args+=("^vulkan_uses_vk_khr_shader_float16_int8$")
diff --git a/docs/developers/best_practices.md b/docs/developers/best_practices.md
new file mode 100644
index 0000000..d3d9c91
--- /dev/null
+++ b/docs/developers/best_practices.md
@@ -0,0 +1,63 @@
+# IREE Best Practices
+
+This page contains a list of best practices for getting the most out of IREE,
+spanning model authoring, ahead-of-time compilation, and runtime use. Treat
+these as a collection of ideas to consider or areas to start benchmarking when
+working on your own applications.
+
+## Introduction
+
+Common themes include:
+
+* Give the compiler as much information as possible
+* Give the compiler opportunities to batch work together or defer computation
+* Keep compute devices saturated with work through pipelining
+* Use dense math where possible, particularly for inner loop bodies
+* Limit synchronization points between devices like CPUs and GPUs
+* Profile early and often, using the right tools for each level of granularity
+
+## Practices for model authoring
+
+### Track state within your model when possible
+
+If your model is stateful, prefer to store that state directly within your
+program rather than externalizing it through arguments and return values. By
+keeping state inside your program, the compiler is better able to reason about
+it and function calls will have lower overhead.
+
+If you do externalize state, try to pack that state into a limited number of
+arguments.
+
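+As a minimal sketch of this pattern in TensorFlow (the `Counter` module here
+is hypothetical, not taken from the samples):
+
+```python
+import tensorflow as tf
+
+class Counter(tf.Module):
+
+  def __init__(self):
+    # State lives inside the program instead of in caller-managed arguments.
+    self.value = tf.Variable(0, trainable=False)
+
+  @tf.function(input_signature=[])
+  def increment(self):
+    self.value.assign_add(1)
+    return self.value
+```
+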
+See the
+[variables and state](https://github.com/google/iree/tree/main/iree/samples/variables_and_state)
+sample for further guidance on tracking and using state.
+
+### Limit uses of dynamic shapes
+
+While IREE aims to support general use of dynamic shapes, it is better able to
+optimize parts of programs where shapes are static. Slowly varying dimensions
+like batch index or timestamp are safer uses of dynamic shapes than
+faster-varying dimensions like the x/y/channel dimensions of images.
+
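+As an illustrative sketch in TensorFlow, a function signature can leave only
+the slowly varying batch dimension dynamic while keeping the image dimensions
+static (the `predict` function below is hypothetical):
+
+```python
+import tensorflow as tf
+
+class Classifier(tf.Module):
+
+  @tf.function(input_signature=[
+      # Dynamic batch size; static 224x224x3 image dimensions.
+      tf.TensorSpec([None, 224, 224, 3], tf.float32),
+  ])
+  def predict(self, images):
+    # Placeholder body standing in for a real model.
+    return tf.reduce_mean(images, axis=[1, 2, 3])
+```
+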
+See the
+[dynamic shapes](https://github.com/google/iree/tree/main/iree/samples/dynamic_shapes)
+sample for further guidance on using dynamic shapes.
+
+## Practices for compilation settings
+
+TODO: mention parameters to tune
+
+TODO: which compiler targets to use (try both CUDA and Vulkan?)
+
+TODO: use the most specific LLVM target triple you can?
+
+## Practices for runtime use
+
+### Do the minimum amount of work: cache queries and reuse buffers
+
+Try to front-load queries, particularly queries using strings that look up into
+maps like `iree_runtime_session_call_by_name`, so that hot sections of code are
+doing the minimum amount of work: routing inputs through buffers, scheduling
+runtime calls, and routing outputs through other buffers.
+
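+As a sketch of the idea using the Python runtime bindings (assuming the
+`iree.runtime` API; the module contents and `predict` entry point are
+hypothetical):
+
+```python
+import iree.runtime as ireert
+
+# One-time setup: `compiled_flatbuffer` is assumed to be the bytes produced by
+# the IREE compiler for a program exporting a `predict` function.
+config = ireert.Config("vmvx")
+ctx = ireert.SystemContext(config=config)
+ctx.add_vm_module(ireert.VmModule.from_flatbuffer(compiled_flatbuffer))
+
+# Front-load the string-based lookup once, outside of the hot loop.
+predict = ctx.modules.module["predict"]
+
+for batch in batches:
+  outputs = predict(batch)  # reuses the cached function handle
+```
+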
+TODO: profile numbers
diff --git a/docs/developers/get_started/getting_started_emscripten.md b/docs/developers/get_started/getting_started_emscripten.md
new file mode 100644
index 0000000..c5aefb7
--- /dev/null
+++ b/docs/developers/get_started/getting_started_emscripten.md
@@ -0,0 +1,67 @@
+# Getting Started With Emscripten
+
+[Emscripten](https://emscripten.org/index.html) is a complete compiler
+toolchain to WebAssembly, using LLVM, with a special focus on speed, size, and
+the Web platform. Emscripten can be used to compile parts of IREE to
+[WebAssembly](https://webassembly.org/) for execution within web browsers or
+other Wasm runtimes.
+
+## Status
+
+IREE's _runtime_ can be compiled through Emscripten in some limited
+configurations. More of the runtime will be supported over time.
+
+IREE's _compiler_ can be compiled through Emscripten with local changes. More
+work is needed for this to be generally supported.
+
+## Prerequisites
+
+Read https://emscripten.org/docs/getting_started/downloads.html and run
+
+```shell
+./emsdk install latest
+./emsdk activate latest
+source ./emsdk_env.sh
+```
+
+## Building IREE's Runtime with Emscripten
+
+### Host Configuration
+
+Build and install at least the compiler tools on your host machine, or install
+them from a binary distribution:
+
+```shell
+$ cmake -G Ninja -B ../iree-build-host/ \
+    -DCMAKE_C_COMPILER=clang \
+    -DCMAKE_CXX_COMPILER=clang++ \
+    -DCMAKE_INSTALL_PREFIX=../iree-build-host/install \
+    .
+$ cmake --build ../iree-build-host/ --target install
+```
+
+### Target Configuration
+
+```shell
+$ emcmake cmake -G Ninja -B ../iree-build-emscripten/ \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DIREE_HOST_BINARY_ROOT=$(realpath ../iree-build-host/install) \
+  -DIREE_BUILD_TESTS=OFF \
+  -DIREE_BUILD_COMPILER=OFF \
+  -DIREE_ENABLE_MLIR=OFF \
+  .
+```
+
+Build:
+
+```shell
+cmake --build ../iree-build-emscripten/ \
+  --target iree_samples_simple_embedding_simple_embedding_vmvx_sync
+```
+
+### Load into a WebAssembly Environment
+
+Copy the outputs from the build process (e.g. `simple_embedding_vmvx_sync.js`
+and `simple_embedding_vmvx_sync.wasm`) into your application and follow
+instructions at either https://webassembly.org/getting-started/developers-guide/
+or https://developer.mozilla.org/en-US/docs/WebAssembly/Loading_and_running.
diff --git a/docs/website/docs/bindings/python.md b/docs/website/docs/bindings/python.md
index b6faade..f279ad6 100644
--- a/docs/website/docs/bindings/python.md
+++ b/docs/website/docs/bindings/python.md
@@ -10,11 +10,16 @@
 | `iree-tools-tf-snapshot`     | Tools for importing from [TensorFlow](https://www.tensorflow.org/)          |
 | `iree-tools-tflite-snapshot` | Tools for importing from [TensorFlow Lite](https://www.tensorflow.org/lite) |
 | `iree-tools-xla-snapshot`    | Tools for importing from [XLA](https://www.tensorflow.org/xla)              |
+| `iree-jax-snapshot`          | Tools for importing from [JAX](https://github.com/google/jax)               |
 
 Collectively, these packages allow for importing from frontends, compiling
 towards various targets, and executing compiled code on IREE's backends.
 
-<!-- TODO(??): Which package for JAX? -->
+!!! warning
+    The TensorFlow, TensorFlow Lite, and XLA packages are currently only
+    available on Linux and macOS. They are not available on Windows yet (see
+    [this issue](https://github.com/google/iree/issues/6417)).
+
 <!-- TODO(??): API references for packages/modules -->
 <!-- TODO(??): at least link to source code and sample Colab notebooks for now -->
 <!-- TODO(??): link to frontend docs -->
diff --git a/docs/website/docs/blog/2021-07-19-tflite-tosa-compilation-diagram.png b/docs/website/docs/blog/2021-07-19-tflite-tosa-compilation-diagram.png
new file mode 100644
index 0000000..11b5986
--- /dev/null
+++ b/docs/website/docs/blog/2021-07-19-tflite-tosa-compilation-diagram.png
Binary files differ
diff --git a/docs/website/docs/blog/2021-07-19-tflite-tosa.md b/docs/website/docs/blog/2021-07-19-tflite-tosa.md
new file mode 100644
index 0000000..90929b6
--- /dev/null
+++ b/docs/website/docs/blog/2021-07-19-tflite-tosa.md
@@ -0,0 +1,44 @@
+ Monday, July 19, 2021<br>
+ By Rob Suderman and Jenni Kilduff
+
+## TFLite Support via TOSA
+
+IREE can now execute [TensorFlow Lite](https://www.tensorflow.org/lite)
+(TFLite) models through the use of
+[TOSA](https://developer.mlplatform.org/w/tosa/), an open standard of common
+tensor operations, and a part of [MLIR](https://mlir.llvm.org/) core. TOSA’s
+high-level representation of tensor operations provides a common front-end for
+ingesting models from different frameworks. In this case we ingest a TFLite
+flatbuffer and compile it to TOSA IR, which IREE takes as an input format to
+compile to its various backends.
+
+![Compilation diagram](./2021-07-19-tflite-tosa-compilation-diagram.png){ align=left }
+
+Using TFLite as a frontend for IREE provides an alternative ingestion method
+for existing models that could benefit from IREE’s design. Models already
+designed for on-device inference gain another path for execution without
+requiring any additional porting, while benefiting from IREE’s improvements in
+buffer management, work dispatch, and compact binary format. With continued
+improvements to IREE/MLIR’s compilation performance, more optimized versions
+can be compiled and distributed to target devices without an update to the
+client-side environment.
+
+Today, we have validated floating point support for a variety of models,
+including
+[mobilenet](https://tfhub.dev/s?deployment-format=lite&network-architecture=mobilenet,mobilenet-v2,mobilenet-v3,mobilenet-v1&q=mobilenet)
+(v1, v2, and v3) and
+[mobilebert](https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1).
+More work is in progress to support fully quantized models, TFLite’s hybrid
+quantization, and dynamic shapes.
+
+## Examples
+
+TFLite with IREE is available in Python and Java.  We have a
+[colab notebook](https://colab.sandbox.google.com/github/google/iree/blob/main/colab/tflite_text_classification.ipynb)
+that shows how to use IREE’s Python bindings and TFLite compiler tools to
+compile a pre-trained TFLite model from a flatbuffer and run it using IREE.  We
+also have an
+[Android Java app](https://github.com/not-jenni/iree-android-tflite-demo) that
+was forked from an existing TFLite demo app, swapping out the TFLite library
+for our own AAR.  More information on IREE’s TFLite frontend is available
+[here](../ml-frameworks/tensorflow-lite.md).
diff --git a/docs/website/docs/deployment-configurations/cpu-dylib.md b/docs/website/docs/deployment-configurations/cpu-dylib.md
index 706d039..8bad3be 100644
--- a/docs/website/docs/deployment-configurations/cpu-dylib.md
+++ b/docs/website/docs/deployment-configurations/cpu-dylib.md
@@ -131,11 +131,11 @@
 
 <!-- TODO(??): troubleshooting -->
 
-[android-cc]: /building-from-source/android/
-[get-started]: /building-from-source/getting-started/
+[android-cc]: ../building-from-source/android.md
+[get-started]: ../building-from-source/getting-started.md
 [iree-releases]: https://github.com/google/iree/releases/
 [llvm]: https://llvm.org/
 [mlir]: https://mlir.llvm.org/
 [tf-hub-mobilenetv2]: https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification
-[tf-import]: /frontends/tensorflow/
-[tflite-import]: /frontends/tensorflow-lite/
+[tf-import]: ../ml-frameworks/tensorflow.md
+[tflite-import]: ../ml-frameworks/tensorflow-lite.md
diff --git a/docs/website/docs/deployment-configurations/gpu-vulkan.md b/docs/website/docs/deployment-configurations/gpu-vulkan.md
index d86de5c..a973bc3 100644
--- a/docs/website/docs/deployment-configurations/gpu-vulkan.md
+++ b/docs/website/docs/deployment-configurations/gpu-vulkan.md
@@ -177,13 +177,13 @@
 
 <!-- TODO(??): troubleshooting -->
 
-[android-cc]: /building-from-source/android/
-[get-started]: /building-from-source/getting-started/
+[android-cc]: ../building-from-source/android.md
+[get-started]: ../building-from-source/getting-started.md
 [iree-releases]: https://github.com/google/iree/releases/
 [mlir]: https://mlir.llvm.org/
 [spirv]: https://www.khronos.org/registry/spir-v/
 [tf-hub-mobilenetv2]: https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification
-[tf-import]: /frontends/tensorflow/
-[tflite-import]: /frontends/tensorflow-lite/
+[tf-import]: ../ml-frameworks/tensorflow.md
+[tflite-import]: ../ml-frameworks/tensorflow-lite.md
 [vulkan]: https://www.khronos.org/vulkan/
 [vulkan-sdk]: https://vulkan.lunarg.com/sdk/home/
diff --git a/docs/website/docs/ml-frameworks/tensorflow-lite.md b/docs/website/docs/ml-frameworks/tensorflow-lite.md
index b91ac3f..e807555 100644
--- a/docs/website/docs/ml-frameworks/tensorflow-lite.md
+++ b/docs/website/docs/ml-frameworks/tensorflow-lite.md
@@ -24,6 +24,11 @@
   -f https://github.com/google/iree/releases
 ```
 
+!!! warning
+    The TensorFlow Lite package is currently only available on Linux and macOS.
+    It is not available on Windows yet (see
+    [this issue](https://github.com/google/iree/issues/6417)).
+
 ## Importing models
 
 First, import the TFLite model to TOSA MLIR:
diff --git a/docs/website/docs/ml-frameworks/tensorflow.md b/docs/website/docs/ml-frameworks/tensorflow.md
index 837f8ec..9e5df96 100644
--- a/docs/website/docs/ml-frameworks/tensorflow.md
+++ b/docs/website/docs/ml-frameworks/tensorflow.md
@@ -27,6 +27,11 @@
   -f https://github.com/google/iree/releases
 ```
 
+!!! warning
+    The TensorFlow package is currently only available on Linux and macOS. It
+    is not available on Windows yet (see
+    [this issue](https://github.com/google/iree/issues/6417)).
+
 ## Importing models
 
 IREE compilers transform a model into its final deployable format in several
diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml
index a6dca81..c9253f3 100644
--- a/docs/website/mkdocs.yml
+++ b/docs/website/mkdocs.yml
@@ -73,6 +73,7 @@
 markdown_extensions:
   - abbr
   - admonition
+  - attr_list
   - footnotes
   - meta
   - pymdownx.details
@@ -115,3 +116,5 @@
       - TensorFlow Lite: 'bindings/tensorflow-lite.md'
   - 'Community':
       - Projects: 'community/projects.md'
+  - 'Blog':
+      - TFLite Support via TOSA: 'blog/2021-07-19-tflite-tosa.md'
diff --git a/experimental/rocm/descriptor_set_layout.c b/experimental/rocm/descriptor_set_layout.c
index 69fb4af..c9e985e 100644
--- a/experimental/rocm/descriptor_set_layout.c
+++ b/experimental/rocm/descriptor_set_layout.c
@@ -13,49 +13,58 @@
 
 typedef struct iree_hal_rocm_descriptor_set_layout_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context;
+  iree_hal_rocm_context_wrapper_t* context;
+  iree_host_size_t binding_count;
 } iree_hal_rocm_descriptor_set_layout_t;
 
 extern const iree_hal_descriptor_set_layout_vtable_t
     iree_hal_rocm_descriptor_set_layout_vtable;
 
-static iree_hal_rocm_descriptor_set_layout_t *
+static iree_hal_rocm_descriptor_set_layout_t*
 iree_hal_rocm_descriptor_set_layout_cast(
-    iree_hal_descriptor_set_layout_t *base_value) {
+    iree_hal_descriptor_set_layout_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_descriptor_set_layout_vtable);
-  return (iree_hal_rocm_descriptor_set_layout_t *)base_value;
+  return (iree_hal_rocm_descriptor_set_layout_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_descriptor_set_layout_create(
-    iree_hal_rocm_context_wrapper_t *context,
+    iree_hal_rocm_context_wrapper_t* context,
     iree_hal_descriptor_set_layout_usage_type_t usage_type,
     iree_host_size_t binding_count,
-    const iree_hal_descriptor_set_layout_binding_t *bindings,
-    iree_hal_descriptor_set_layout_t **out_descriptor_set_layout) {
+    const iree_hal_descriptor_set_layout_binding_t* bindings,
+    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
   IREE_ASSERT_ARGUMENT(context);
   IREE_ASSERT_ARGUMENT(!binding_count || bindings);
   IREE_ASSERT_ARGUMENT(out_descriptor_set_layout);
   *out_descriptor_set_layout = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout = NULL;
+  iree_hal_rocm_descriptor_set_layout_t* descriptor_set_layout = NULL;
   iree_status_t status = iree_allocator_malloc(context->host_allocator,
                                                sizeof(*descriptor_set_layout),
-                                               (void **)&descriptor_set_layout);
+                                               (void**)&descriptor_set_layout);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_descriptor_set_layout_vtable,
                                  &descriptor_set_layout->resource);
     descriptor_set_layout->context = context;
+    descriptor_set_layout->binding_count = binding_count;
     *out_descriptor_set_layout =
-        (iree_hal_descriptor_set_layout_t *)descriptor_set_layout;
+        (iree_hal_descriptor_set_layout_t*)descriptor_set_layout;
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
+iree_host_size_t iree_hal_rocm_descriptor_set_layout_binding_count(
+    iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) {
+  iree_hal_rocm_descriptor_set_layout_t* descriptor_set_layout =
+      iree_hal_rocm_descriptor_set_layout_cast(base_descriptor_set_layout);
+  return descriptor_set_layout->binding_count;
+}
+
 static void iree_hal_rocm_descriptor_set_layout_destroy(
-    iree_hal_descriptor_set_layout_t *base_descriptor_set_layout) {
-  iree_hal_rocm_descriptor_set_layout_t *descriptor_set_layout =
+    iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) {
+  iree_hal_rocm_descriptor_set_layout_t* descriptor_set_layout =
       iree_hal_rocm_descriptor_set_layout_cast(base_descriptor_set_layout);
   iree_allocator_t host_allocator =
       descriptor_set_layout->context->host_allocator;
diff --git a/experimental/rocm/descriptor_set_layout.h b/experimental/rocm/descriptor_set_layout.h
index fb07b76..9d58acd 100644
--- a/experimental/rocm/descriptor_set_layout.h
+++ b/experimental/rocm/descriptor_set_layout.h
@@ -16,11 +16,15 @@
 #endif  // __cplusplus
 
 iree_status_t iree_hal_rocm_descriptor_set_layout_create(
-    iree_hal_rocm_context_wrapper_t *context,
+    iree_hal_rocm_context_wrapper_t* context,
     iree_hal_descriptor_set_layout_usage_type_t usage_type,
     iree_host_size_t binding_count,
-    const iree_hal_descriptor_set_layout_binding_t *bindings,
-    iree_hal_descriptor_set_layout_t **out_descriptor_set_layout);
+    const iree_hal_descriptor_set_layout_binding_t* bindings,
+    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout);
+
+// Return the binding count for the given descriptor set layout.
+iree_host_size_t iree_hal_rocm_descriptor_set_layout_binding_count(
+    iree_hal_descriptor_set_layout_t* descriptor_set_layout);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index 3d8706a..db28d91 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -11,6 +11,7 @@
 #include <stdint.h>
 
 #include "experimental/rocm/dynamic_symbols.h"
+#include "experimental/rocm/executable_layout.h"
 #include "experimental/rocm/native_executable.h"
 #include "experimental/rocm/rocm_buffer.h"
 #include "experimental/rocm/status_util.h"
@@ -44,7 +45,7 @@
   return (iree_hal_rocm_direct_command_buffer_t*)base_value;
 }
 
-iree_status_t iree_hal_rocm_direct_command_buffer_allocate(
+iree_status_t iree_hal_rocm_direct_command_buffer_create(
     iree_hal_rocm_context_wrapper_t* context,
     iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
@@ -283,6 +284,8 @@
     const iree_hal_descriptor_set_binding_t* bindings) {
   iree_hal_rocm_direct_command_buffer_t* command_buffer =
       iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+  iree_host_size_t base_binding =
+      iree_hal_rocm_base_binding_index(executable_layout, set);
   // Convention with the compiler side: we map bindings to kernel arguments.
   // We compact the bindings to get a dense set of arguments and keep them
   // ordered based on the binding index.
@@ -303,7 +306,8 @@
         iree_hal_rocm_buffer_device_pointer(
             iree_hal_buffer_allocated_buffer(binding.buffer)) +
         iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
-    *((hipDeviceptr_t*)command_buffer->current_descriptor[i]) = device_ptr;
+    *((hipDeviceptr_t*)command_buffer->current_descriptor[i + base_binding]) =
+        device_ptr;
   }
   return iree_ok_status();
 }
diff --git a/experimental/rocm/direct_command_buffer.h b/experimental/rocm/direct_command_buffer.h
index 0145b15..bd665bf 100644
--- a/experimental/rocm/direct_command_buffer.h
+++ b/experimental/rocm/direct_command_buffer.h
@@ -26,16 +26,16 @@
   unsigned int blockDimX;
   unsigned int blockDimY;
   unsigned int blockDimZ;
-  void **kernelParams;
+  void** kernelParams;
 } hip_launch_params;
 
 // Creates a rocm direct command buffer.
-iree_status_t iree_hal_rocm_direct_command_buffer_allocate(
-    iree_hal_rocm_context_wrapper_t *context,
+iree_status_t iree_hal_rocm_direct_command_buffer_create(
+    iree_hal_rocm_context_wrapper_t* context,
     iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
     iree_hal_queue_affinity_t queue_affinity,
-    iree_hal_command_buffer_t **out_command_buffer);
+    iree_hal_command_buffer_t** out_command_buffer);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h
index fe6cfd5..aa78c5a 100644
--- a/experimental/rocm/dynamic_symbol_tables.h
+++ b/experimental/rocm/dynamic_symbol_tables.h
@@ -30,6 +30,7 @@
 RC_PFN_DECL(hipMemcpyAsync, void *, const void *, size_t, hipMemcpyKind,
             hipStream_t)
 RC_PFN_DECL(hipMalloc, void **, size_t)
+RC_PFN_DECL(hipMallocManaged, hipDeviceptr_t *, size_t, unsigned int)
 RC_PFN_DECL(hipFree, void *)
 RC_PFN_DECL(hipHostFree, void *)
 RC_PFN_DECL(hipMemAllocHost, void **, size_t, unsigned int)
diff --git a/experimental/rocm/dynamic_symbols.c b/experimental/rocm/dynamic_symbols.c
index d676838..ad7ca6d 100644
--- a/experimental/rocm/dynamic_symbols.c
+++ b/experimental/rocm/dynamic_symbols.c
@@ -12,7 +12,7 @@
 #include "iree/base/target_platform.h"
 #include "iree/base/tracing.h"
 
-static const char *kROCMLoaderSearchNames[] = {
+static const char* kROCMLoaderSearchNames[] = {
 #if defined(IREE_PLATFORM_WINDOWS)
     "amdhip64.dll",
 #else
@@ -21,12 +21,12 @@
 };
 
 static iree_status_t iree_hal_rocm_dynamic_symbols_resolve_all(
-    iree_hal_rocm_dynamic_symbols_t *syms) {
-#define RC_PFN_DECL(rocmSymbolName, ...)                               \
-  {                                                                    \
-    static const char *kName = #rocmSymbolName;                        \
-    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(           \
-        syms->loader_library, kName, (void **)&syms->rocmSymbolName)); \
+    iree_hal_rocm_dynamic_symbols_t* syms) {
+#define RC_PFN_DECL(rocmSymbolName, ...)                              \
+  {                                                                   \
+    static const char* kName = #rocmSymbolName;                       \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(          \
+        syms->loader_library, kName, (void**)&syms->rocmSymbolName)); \
   }
 #define RC_PFN_STR_DECL(rocmSymbolName, ...) RC_PFN_DECL(rocmSymbolName, ...)
 #include "experimental/rocm/dynamic_symbol_tables.h"  // IWYU pragma: keep
@@ -36,7 +36,7 @@
 }
 
 iree_status_t iree_hal_rocm_dynamic_symbols_initialize(
-    iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t *out_syms) {
+    iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t* out_syms) {
   IREE_TRACE_ZONE_BEGIN(z0);
   memset(out_syms, 0, sizeof(*out_syms));
   iree_status_t status = iree_dynamic_library_load_from_files(
@@ -59,7 +59,7 @@
 }
 
 void iree_hal_rocm_dynamic_symbols_deinitialize(
-    iree_hal_rocm_dynamic_symbols_t *syms) {
+    iree_hal_rocm_dynamic_symbols_t* syms) {
   IREE_TRACE_ZONE_BEGIN(z0);
   iree_dynamic_library_release(syms->loader_library);
   memset(syms, 0, sizeof(*syms));
diff --git a/experimental/rocm/dynamic_symbols.h b/experimental/rocm/dynamic_symbols.h
index 00a58b0..9cbd774 100644
--- a/experimental/rocm/dynamic_symbols.h
+++ b/experimental/rocm/dynamic_symbols.h
@@ -20,12 +20,12 @@
 // any of the symbols are not available. The function signatures match
 // the declarations in `hipruntime.h`.
 typedef struct iree_hal_rocm_dynamic_symbols_t {
-  iree_dynamic_library_t *loader_library;
+  iree_dynamic_library_t* loader_library;
 
 #define RC_PFN_DECL(rocmSymbolName, ...) \
   hipError_t (*rocmSymbolName)(__VA_ARGS__);
 #define RC_PFN_STR_DECL(rocmSymbolName, ...) \
-  const char *(*rocmSymbolName)(__VA_ARGS__);
+  const char* (*rocmSymbolName)(__VA_ARGS__);
 #include "experimental/rocm/dynamic_symbol_tables.h"  // IWYU pragma: export
 #undef RC_PFN_DECL
 #undef RC_PFN_STR_DECL
@@ -35,13 +35,13 @@
 // iree_hal_rocm_dynamic_symbols_deinitialize must be used to release the
 // library resources.
 iree_status_t iree_hal_rocm_dynamic_symbols_initialize(
-    iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t *out_syms);
+    iree_allocator_t allocator, iree_hal_rocm_dynamic_symbols_t* out_syms);
 
 // Deinitializes |syms| by unloading the backing library. All function pointers
 // will be invalidated. They _may_ still work if there are other reasons the
 // library remains loaded, so be careful.
 void iree_hal_rocm_dynamic_symbols_deinitialize(
-    iree_hal_rocm_dynamic_symbols_t *syms);
+    iree_hal_rocm_dynamic_symbols_t* syms);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/event_semaphore.c b/experimental/rocm/event_semaphore.c
index 7fe6860..d3085fd 100644
--- a/experimental/rocm/event_semaphore.c
+++ b/experimental/rocm/event_semaphore.c
@@ -13,34 +13,34 @@
 
 typedef struct iree_hal_rocm_semaphore_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context;
+  iree_hal_rocm_context_wrapper_t* context;
   uint64_t initial_value;
 } iree_hal_rocm_semaphore_t;
 
 extern const iree_hal_semaphore_vtable_t iree_hal_rocm_semaphore_vtable;
 
-static iree_hal_rocm_semaphore_t *iree_hal_rocm_semaphore_cast(
-    iree_hal_semaphore_t *base_value) {
+static iree_hal_rocm_semaphore_t* iree_hal_rocm_semaphore_cast(
+    iree_hal_semaphore_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_semaphore_vtable);
-  return (iree_hal_rocm_semaphore_t *)base_value;
+  return (iree_hal_rocm_semaphore_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_semaphore_create(
-    iree_hal_rocm_context_wrapper_t *context, uint64_t initial_value,
-    iree_hal_semaphore_t **out_semaphore) {
+    iree_hal_rocm_context_wrapper_t* context, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore) {
   IREE_ASSERT_ARGUMENT(context);
   IREE_ASSERT_ARGUMENT(out_semaphore);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_hal_rocm_semaphore_t *semaphore = NULL;
+  iree_hal_rocm_semaphore_t* semaphore = NULL;
   iree_status_t status = iree_allocator_malloc(
-      context->host_allocator, sizeof(*semaphore), (void **)&semaphore);
+      context->host_allocator, sizeof(*semaphore), (void**)&semaphore);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_semaphore_vtable,
                                  &semaphore->resource);
     semaphore->context = context;
     semaphore->initial_value = initial_value;
-    *out_semaphore = (iree_hal_semaphore_t *)semaphore;
+    *out_semaphore = (iree_hal_semaphore_t*)semaphore;
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -48,8 +48,8 @@
 }
 
 static void iree_hal_rocm_semaphore_destroy(
-    iree_hal_semaphore_t *base_semaphore) {
-  iree_hal_rocm_semaphore_t *semaphore =
+    iree_hal_semaphore_t* base_semaphore) {
+  iree_hal_rocm_semaphore_t* semaphore =
       iree_hal_rocm_semaphore_cast(base_semaphore);
   iree_allocator_t host_allocator = semaphore->context->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -60,24 +60,24 @@
 }
 
 static iree_status_t iree_hal_rocm_semaphore_query(
-    iree_hal_semaphore_t *base_semaphore, uint64_t *out_value) {
+    iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) {
   // TODO: Support semaphores completely.
   *out_value = 0;
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "not implemented on ROCm");
 }
 
 static iree_status_t iree_hal_rocm_semaphore_signal(
-    iree_hal_semaphore_t *base_semaphore, uint64_t new_value) {
+    iree_hal_semaphore_t* base_semaphore, uint64_t new_value) {
   // TODO: Support semaphores completely. Return OK currently as everything is
   // synchronized for each submit to allow things to run.
   return iree_ok_status();
 }
 
-static void iree_hal_rocm_semaphore_fail(iree_hal_semaphore_t *base_semaphore,
+static void iree_hal_rocm_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
                                          iree_status_t status) {}
 
 static iree_status_t iree_hal_rocm_semaphore_wait(
-    iree_hal_semaphore_t *base_semaphore, uint64_t value,
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
     iree_timeout_t timeout) {
   // TODO: Support semaphores completely. Return OK currently as everything is
   // synchronized for each submit to allow things to run.
diff --git a/experimental/rocm/event_semaphore.h b/experimental/rocm/event_semaphore.h
index 5f32492..9e79aa8 100644
--- a/experimental/rocm/event_semaphore.h
+++ b/experimental/rocm/event_semaphore.h
@@ -20,8 +20,8 @@
 
 // Creates a ROCm semaphore.
 iree_status_t iree_hal_rocm_semaphore_create(
-    iree_hal_rocm_context_wrapper_t *context, uint64_t initial_value,
-    iree_hal_semaphore_t **out_semaphore);
+    iree_hal_rocm_context_wrapper_t* context, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/executable_layout.c b/experimental/rocm/executable_layout.c
index 6960cd2..f293921 100644
--- a/experimental/rocm/executable_layout.c
+++ b/experimental/rocm/executable_layout.c
@@ -8,30 +8,31 @@
 
 #include <stddef.h>
 
+#include "experimental/rocm/descriptor_set_layout.h"
 #include "iree/base/api.h"
 #include "iree/base/tracing.h"
 
 typedef struct iree_hal_rocm_executable_layout_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context;
+  iree_hal_rocm_context_wrapper_t* context;
   iree_host_size_t set_layout_count;
-  iree_hal_descriptor_set_layout_t *set_layouts[];
+  iree_hal_descriptor_set_layout_t* set_layouts[];
 } iree_hal_rocm_executable_layout_t;
 
 extern const iree_hal_executable_layout_vtable_t
     iree_hal_rocm_executable_layout_vtable;
 
-static iree_hal_rocm_executable_layout_t *iree_hal_rocm_executable_layout_cast(
-    iree_hal_executable_layout_t *base_value) {
+static iree_hal_rocm_executable_layout_t* iree_hal_rocm_executable_layout_cast(
+    iree_hal_executable_layout_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_executable_layout_vtable);
-  return (iree_hal_rocm_executable_layout_t *)base_value;
+  return (iree_hal_rocm_executable_layout_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_executable_layout_create(
-    iree_hal_rocm_context_wrapper_t *context, iree_host_size_t set_layout_count,
-    iree_hal_descriptor_set_layout_t **set_layouts,
+    iree_hal_rocm_context_wrapper_t* context, iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t** set_layouts,
     iree_host_size_t push_constant_count,
-    iree_hal_executable_layout_t **out_executable_layout) {
+    iree_hal_executable_layout_t** out_executable_layout) {
   IREE_ASSERT_ARGUMENT(context);
   IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts);
   IREE_ASSERT_ARGUMENT(out_executable_layout);
@@ -40,12 +41,12 @@
   // Currently the executable layout doesn't do anything.
   // TODO: Handle creating the argument layout at that time, handling both
   // push constants and buffers.
-  iree_hal_rocm_executable_layout_t *executable_layout = NULL;
+  iree_hal_rocm_executable_layout_t* executable_layout = NULL;
   iree_host_size_t total_size =
       sizeof(*executable_layout) +
       set_layout_count * sizeof(*executable_layout->set_layouts);
   iree_status_t status = iree_allocator_malloc(
-      context->host_allocator, total_size, (void **)&executable_layout);
+      context->host_allocator, total_size, (void**)&executable_layout);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_executable_layout_vtable,
                                  &executable_layout->resource);
@@ -55,15 +56,15 @@
       executable_layout->set_layouts[i] = set_layouts[i];
       iree_hal_descriptor_set_layout_retain(set_layouts[i]);
     }
-    *out_executable_layout = (iree_hal_executable_layout_t *)executable_layout;
+    *out_executable_layout = (iree_hal_executable_layout_t*)executable_layout;
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
 static void iree_hal_rocm_executable_layout_destroy(
-    iree_hal_executable_layout_t *base_executable_layout) {
-  iree_hal_rocm_executable_layout_t *executable_layout =
+    iree_hal_executable_layout_t* base_executable_layout) {
+  iree_hal_rocm_executable_layout_t* executable_layout =
       iree_hal_rocm_executable_layout_cast(base_executable_layout);
   iree_allocator_t host_allocator = executable_layout->context->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -76,6 +77,20 @@
   IREE_TRACE_ZONE_END(z0);
 }
 
+iree_host_size_t iree_hal_rocm_base_binding_index(
+    iree_hal_executable_layout_t* base_executable_layout, uint32_t set) {
+  iree_hal_rocm_executable_layout_t* executable_layout =
+      iree_hal_rocm_executable_layout_cast(base_executable_layout);
+  iree_host_size_t base_binding = 0;
+  for (iree_host_size_t i = 0; i < set; ++i) {
+    iree_host_size_t binding_count =
+        iree_hal_rocm_descriptor_set_layout_binding_count(
+            executable_layout->set_layouts[i]);
+    base_binding += binding_count;
+  }
+  return base_binding;
+}
+
 const iree_hal_executable_layout_vtable_t
     iree_hal_rocm_executable_layout_vtable = {
         .destroy = iree_hal_rocm_executable_layout_destroy,
diff --git a/experimental/rocm/executable_layout.h b/experimental/rocm/executable_layout.h
index cd0fccf..8287cb8 100644
--- a/experimental/rocm/executable_layout.h
+++ b/experimental/rocm/executable_layout.h
@@ -17,10 +17,14 @@
 
 // Creates the kernel arguments.
 iree_status_t iree_hal_rocm_executable_layout_create(
-    iree_hal_rocm_context_wrapper_t *context, iree_host_size_t set_layout_count,
-    iree_hal_descriptor_set_layout_t **set_layouts,
+    iree_hal_rocm_context_wrapper_t* context, iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t** set_layouts,
     iree_host_size_t push_constant_count,
-    iree_hal_executable_layout_t **out_executable_layout);
+    iree_hal_executable_layout_t** out_executable_layout);
+
+// Return the base binding index for the given set.
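+// For example, with set layouts holding {3, 2, 4} bindings, set 2 begins at
+// base binding index 5 (3 + 2).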
+iree_host_size_t iree_hal_rocm_base_binding_index(
+    iree_hal_executable_layout_t* executable_layout, uint32_t set);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/native_executable.c b/experimental/rocm/native_executable.c
index f72bd4f..680f07f 100644
--- a/experimental/rocm/native_executable.c
+++ b/experimental/rocm/native_executable.c
@@ -27,7 +27,7 @@
 
 typedef struct iree_hal_rocm_native_executable_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context;
+  iree_hal_rocm_context_wrapper_t* context;
   iree_host_size_t entry_count;
   hipModule_t module;
   iree_hal_rocm_native_executable_function_t entry_functions[];
@@ -36,23 +36,23 @@
 extern const iree_hal_executable_vtable_t
     iree_hal_rocm_native_executable_vtable;
 
-static iree_hal_rocm_native_executable_t *iree_hal_rocm_native_executable_cast(
-    iree_hal_executable_t *base_value) {
+static iree_hal_rocm_native_executable_t* iree_hal_rocm_native_executable_cast(
+    iree_hal_executable_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_native_executable_vtable);
-  return (iree_hal_rocm_native_executable_t *)base_value;
+  return (iree_hal_rocm_native_executable_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_native_executable_create(
-    iree_hal_rocm_context_wrapper_t *context,
-    const iree_hal_executable_spec_t *executable_spec,
-    iree_hal_executable_t **out_executable) {
+    iree_hal_rocm_context_wrapper_t* context,
+    const iree_hal_executable_spec_t* executable_spec,
+    iree_hal_executable_t** out_executable) {
   IREE_ASSERT_ARGUMENT(context);
   IREE_ASSERT_ARGUMENT(executable_spec);
   IREE_ASSERT_ARGUMENT(out_executable);
   *out_executable = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_hal_rocm_native_executable_t *executable = NULL;
+  iree_hal_rocm_native_executable_t* executable = NULL;
 
   // TODO: Verify the flat buffer.
   iree_ROCMExecutableDef_table_t executable_def =
@@ -69,8 +69,8 @@
   iree_host_size_t total_size =
       sizeof(*executable) +
       entry_count * sizeof(iree_hal_rocm_native_executable_function_t);
-  iree_status_t status = iree_allocator_malloc(
-      context->host_allocator, total_size, (void **)&executable);
+  iree_status_t status = iree_allocator_malloc(context->host_allocator,
+                                               total_size, (void**)&executable);
   hipModule_t module = NULL;
   ROCM_RETURN_IF_ERROR(context->syms,
                        hipModuleLoadDataEx(&module, hsaco_image, 0, NULL, NULL),
@@ -78,7 +78,7 @@
 
   for (iree_host_size_t i = 0; i < entry_count; i++) {
     hipFunction_t function = NULL;
-    const char *entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
+    const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
     ROCM_RETURN_IF_ERROR(context->syms,
                          hipModuleGetFunction(&function, module, entry_name),
                          "hipModuleGetFunction");
@@ -92,22 +92,22 @@
                                &executable->resource);
   executable->module = module;
   executable->context = context;
-  *out_executable = (iree_hal_executable_t *)executable;
+  *out_executable = (iree_hal_executable_t*)executable;
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
 }
 
 hipFunction_t iree_hal_rocm_native_executable_for_entry_point(
-    iree_hal_executable_t *base_executable, int32_t entry_point) {
-  iree_hal_rocm_native_executable_t *executable =
+    iree_hal_executable_t* base_executable, int32_t entry_point) {
+  iree_hal_rocm_native_executable_t* executable =
       iree_hal_rocm_native_executable_cast(base_executable);
   return executable->entry_functions[entry_point].rocm_function;
 }
 
 iree_status_t iree_hal_rocm_native_executable_block_size(
-    iree_hal_executable_t *base_executable, int32_t entry_point, uint32_t *x,
-    uint32_t *y, uint32_t *z) {
-  iree_hal_rocm_native_executable_t *executable =
+    iree_hal_executable_t* base_executable, int32_t entry_point, uint32_t* x,
+    uint32_t* y, uint32_t* z) {
+  iree_hal_rocm_native_executable_t* executable =
       iree_hal_rocm_native_executable_cast(base_executable);
   *x = executable->entry_functions[entry_point].block_size_x;
   *y = executable->entry_functions[entry_point].block_size_y;
@@ -116,8 +116,8 @@
 }
 
 static void iree_hal_rocm_native_executable_destroy(
-    iree_hal_executable_t *base_executable) {
-  iree_hal_rocm_native_executable_t *executable =
+    iree_hal_executable_t* base_executable) {
+  iree_hal_rocm_native_executable_t* executable =
       iree_hal_rocm_native_executable_cast(base_executable);
   iree_allocator_t host_allocator = executable->context->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/experimental/rocm/native_executable.h b/experimental/rocm/native_executable.h
index 7a9229a..bc671bf 100644
--- a/experimental/rocm/native_executable.h
+++ b/experimental/rocm/native_executable.h
@@ -21,17 +21,17 @@
 // Creates an executable from a HSACO module. The module may contain several
 // kernels that can be extracted along with the associated block size.
 iree_status_t iree_hal_rocm_native_executable_create(
-    iree_hal_rocm_context_wrapper_t *context,
-    const iree_hal_executable_spec_t *executable_spec,
-    iree_hal_executable_t **out_executable);
+    iree_hal_rocm_context_wrapper_t* context,
+    const iree_hal_executable_spec_t* executable_spec,
+    iree_hal_executable_t** out_executable);
 
 hipFunction_t iree_hal_rocm_native_executable_for_entry_point(
-    iree_hal_executable_t *executable, int32_t entry_point);
+    iree_hal_executable_t* executable, int32_t entry_point);
 
 // Return the block size of the given |entry_point| within the executable.
 iree_status_t iree_hal_rocm_native_executable_block_size(
-    iree_hal_executable_t *executable, int32_t entry_point, uint32_t *x,
-    uint32_t *y, uint32_t *z);
+    iree_hal_executable_t* executable, int32_t entry_point, uint32_t* x,
+    uint32_t* y, uint32_t* z);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/nop_executable_cache.c b/experimental/rocm/nop_executable_cache.c
index 6592727..54bde49 100644
--- a/experimental/rocm/nop_executable_cache.c
+++ b/experimental/rocm/nop_executable_cache.c
@@ -15,44 +15,44 @@
 
 typedef struct iree_hal_rocm_nop_executable_cache_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context;
+  iree_hal_rocm_context_wrapper_t* context;
 } iree_hal_rocm_nop_executable_cache_t;
 
 extern const iree_hal_executable_cache_vtable_t
     iree_hal_rocm_nop_executable_cache_vtable;
 
-static iree_hal_rocm_nop_executable_cache_t *
+static iree_hal_rocm_nop_executable_cache_t*
 iree_hal_rocm_nop_executable_cache_cast(
-    iree_hal_executable_cache_t *base_value) {
+    iree_hal_executable_cache_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_nop_executable_cache_vtable);
-  return (iree_hal_rocm_nop_executable_cache_t *)base_value;
+  return (iree_hal_rocm_nop_executable_cache_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_nop_executable_cache_create(
-    iree_hal_rocm_context_wrapper_t *context, iree_string_view_t identifier,
-    iree_hal_executable_cache_t **out_executable_cache) {
+    iree_hal_rocm_context_wrapper_t* context, iree_string_view_t identifier,
+    iree_hal_executable_cache_t** out_executable_cache) {
   IREE_ASSERT_ARGUMENT(out_executable_cache);
   *out_executable_cache = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_hal_rocm_nop_executable_cache_t *executable_cache = NULL;
+  iree_hal_rocm_nop_executable_cache_t* executable_cache = NULL;
   iree_status_t status =
       iree_allocator_malloc(context->host_allocator, sizeof(*executable_cache),
-                            (void **)&executable_cache);
+                            (void**)&executable_cache);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_nop_executable_cache_vtable,
                                  &executable_cache->resource);
     executable_cache->context = context;
 
-    *out_executable_cache = (iree_hal_executable_cache_t *)executable_cache;
+    *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache;
   }
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
 static void iree_hal_rocm_nop_executable_cache_destroy(
-    iree_hal_executable_cache_t *base_executable_cache) {
-  iree_hal_rocm_nop_executable_cache_t *executable_cache =
+    iree_hal_executable_cache_t* base_executable_cache) {
+  iree_hal_rocm_nop_executable_cache_t* executable_cache =
       iree_hal_rocm_nop_executable_cache_cast(base_executable_cache);
   iree_allocator_t host_allocator = executable_cache->context->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -63,7 +63,7 @@
 }
 
 static bool iree_hal_rocm_nop_executable_cache_can_prepare_format(
-    iree_hal_executable_cache_t *base_executable_cache,
+    iree_hal_executable_cache_t* base_executable_cache,
     iree_hal_executable_caching_mode_t caching_mode,
     iree_string_view_t executable_format) {
   return iree_string_view_equal(executable_format,
@@ -71,10 +71,10 @@
 }
 
 static iree_status_t iree_hal_rocm_nop_executable_cache_prepare_executable(
-    iree_hal_executable_cache_t *base_executable_cache,
-    const iree_hal_executable_spec_t *executable_spec,
-    iree_hal_executable_t **out_executable) {
-  iree_hal_rocm_nop_executable_cache_t *executable_cache =
+    iree_hal_executable_cache_t* base_executable_cache,
+    const iree_hal_executable_spec_t* executable_spec,
+    iree_hal_executable_t** out_executable) {
+  iree_hal_rocm_nop_executable_cache_t* executable_cache =
       iree_hal_rocm_nop_executable_cache_cast(base_executable_cache);
   return iree_hal_rocm_native_executable_create(
       executable_cache->context, executable_spec, out_executable);
diff --git a/experimental/rocm/nop_executable_cache.h b/experimental/rocm/nop_executable_cache.h
index d1b2fc1..a057466 100644
--- a/experimental/rocm/nop_executable_cache.h
+++ b/experimental/rocm/nop_executable_cache.h
@@ -19,8 +19,8 @@
 // This is useful to isolate pipeline caching behavior and verify compilation
 // behavior.
 iree_status_t iree_hal_rocm_nop_executable_cache_create(
-    iree_hal_rocm_context_wrapper_t *context, iree_string_view_t identifier,
-    iree_hal_executable_cache_t **out_executable_cache);
+    iree_hal_rocm_context_wrapper_t* context, iree_string_view_t identifier,
+    iree_hal_executable_cache_t** out_executable_cache);
 
 #ifdef __cplusplus
 }  // extern "C"
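// Example (illustrative sketch, not part of this change): creating and
// releasing the no-op cache; the "nop" identifier is arbitrary.
//  iree_hal_executable_cache_t* cache = NULL;
//  IREE_RETURN_IF_ERROR(iree_hal_rocm_nop_executable_cache_create(
//      context, iree_make_cstring_view("nop"), &cache));
//  // ... prepare executables through the cache ...
//  iree_hal_executable_cache_release(cache);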
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c
index cccdc9f..4f9742f 100644
--- a/experimental/rocm/rocm_allocator.c
+++ b/experimental/rocm/rocm_allocator.c
@@ -16,30 +16,30 @@
 
 typedef struct iree_hal_rocm_allocator_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context;
+  iree_hal_rocm_context_wrapper_t* context;
 } iree_hal_rocm_allocator_t;
 
 extern const iree_hal_allocator_vtable_t iree_hal_rocm_allocator_vtable;
 
-static iree_hal_rocm_allocator_t *iree_hal_rocm_allocator_cast(
-    iree_hal_allocator_t *base_value) {
+static iree_hal_rocm_allocator_t* iree_hal_rocm_allocator_cast(
+    iree_hal_allocator_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_allocator_vtable);
-  return (iree_hal_rocm_allocator_t *)base_value;
+  return (iree_hal_rocm_allocator_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_allocator_create(
-    iree_hal_rocm_context_wrapper_t *context,
-    iree_hal_allocator_t **out_allocator) {
+    iree_hal_rocm_context_wrapper_t* context,
+    iree_hal_allocator_t** out_allocator) {
   IREE_ASSERT_ARGUMENT(context);
   IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_rocm_allocator_t *allocator = NULL;
+  iree_hal_rocm_allocator_t* allocator = NULL;
   iree_status_t status = iree_allocator_malloc(
-      context->host_allocator, sizeof(*allocator), (void **)&allocator);
+      context->host_allocator, sizeof(*allocator), (void**)&allocator);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable,
                                  &allocator->resource);
     allocator->context = context;
-    *out_allocator = (iree_hal_allocator_t *)allocator;
+    *out_allocator = (iree_hal_allocator_t*)allocator;
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -47,8 +47,8 @@
 }
 
 static void iree_hal_rocm_allocator_destroy(
-    iree_hal_allocator_t *base_allocator) {
-  iree_hal_rocm_allocator_t *allocator =
+    iree_hal_allocator_t* base_allocator) {
+  iree_hal_rocm_allocator_t* allocator =
       iree_hal_rocm_allocator_cast(base_allocator);
   iree_allocator_t host_allocator = allocator->context->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -59,15 +59,15 @@
 }
 
 static iree_allocator_t iree_hal_rocm_allocator_host_allocator(
-    const iree_hal_allocator_t *base_allocator) {
-  iree_hal_rocm_allocator_t *allocator =
-      (iree_hal_rocm_allocator_t *)base_allocator;
+    const iree_hal_allocator_t* base_allocator) {
+  iree_hal_rocm_allocator_t* allocator =
+      (iree_hal_rocm_allocator_t*)base_allocator;
   return allocator->context->host_allocator;
 }
 
 static iree_hal_buffer_compatibility_t
 iree_hal_rocm_allocator_query_buffer_compatibility(
-    iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
     iree_hal_buffer_usage_t allowed_usage,
     iree_hal_buffer_usage_t intended_usage,
     iree_device_size_t allocation_size) {
@@ -96,19 +96,31 @@
 }
 
 static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
-    iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
     iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size,
-    iree_hal_buffer_t **out_buffer) {
-  iree_hal_rocm_allocator_t *allocator =
+    iree_hal_buffer_t** out_buffer) {
+  iree_hal_rocm_allocator_t* allocator =
       iree_hal_rocm_allocator_cast(base_allocator);
   // Guard against the corner case where the requested buffer size is 0. The
   // application is unlikely to do anything when requesting a 0-byte buffer,
   // but it can happen in real-world use cases, so we should at least not crash.
   if (allocation_size == 0) allocation_size = 4;
   iree_status_t status;
-  void *host_ptr = NULL;
+  void* host_ptr = NULL;
   hipDeviceptr_t device_ptr = 0;
-  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
+    // Device local case.
+    if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+      status = ROCM_RESULT_TO_STATUS(
+          allocator->context->syms,
+          hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal));
+      host_ptr = (void*)device_ptr;
+    } else {
+      // Device only.
+      status = ROCM_RESULT_TO_STATUS(allocator->context->syms,
+                                     hipMalloc(&device_ptr, allocation_size));
+    }
+  } else {
     unsigned int flags = hipHostMallocMapped;
     if (!iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)) {
       flags |= hipHostMallocWriteCombined;
@@ -121,14 +133,11 @@
           allocator->context->syms,
           hipHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0));
     }
-  } else {
-    status = ROCM_RESULT_TO_STATUS(allocator->context->syms,
-                                   hipMalloc(&device_ptr, allocation_size));
   }
 
   if (iree_status_is_ok(status)) {
     status = iree_hal_rocm_buffer_wrap(
-        (iree_hal_allocator_t *)allocator, memory_type,
+        (iree_hal_allocator_t*)allocator, memory_type,
         IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size,
         /*byte_offset=*/0,
         /*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer);
@@ -140,23 +149,24 @@
   return status;
 }
 
-void iree_hal_rocm_allocator_free(iree_hal_allocator_t *base_allocator,
-                                  hipDeviceptr_t device_ptr, void *host_ptr,
+void iree_hal_rocm_allocator_free(iree_hal_allocator_t* base_allocator,
+                                  hipDeviceptr_t device_ptr, void* host_ptr,
                                   iree_hal_memory_type_t memory_type) {
-  iree_hal_rocm_allocator_t *allocator =
+  iree_hal_rocm_allocator_t* allocator =
       iree_hal_rocm_allocator_cast(base_allocator);
-  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
-    ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr));
-  } else {
+  if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
     ROCM_IGNORE_ERROR(allocator->context->syms, hipFree(device_ptr));
+  } else {
+    // Host local.
+    ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr));
   }
 }
 
 static iree_status_t iree_hal_rocm_allocator_wrap_buffer(
-    iree_hal_allocator_t *base_allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
     iree_hal_memory_access_t allowed_access,
     iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data,
-    iree_allocator_t data_allocator, iree_hal_buffer_t **out_buffer) {
+    iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) {
   return iree_make_status(IREE_STATUS_UNAVAILABLE,
                           "wrapping of external buffers not supported");
 }
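// Summary (illustrative, not part of this change) of how the updated
// allocate/free paths now dispatch on memory type, assuming the HIP symbols
// used above are loaded:
//  DEVICE_LOCAL | HOST_VISIBLE -> hipMallocManaged           / hipFree
//  DEVICE_LOCAL only           -> hipMalloc                  / hipFree
//  host-local (all others)     -> hipHostMalloc (+ mapping)  / hipHostFree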
diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h
index a47480c..37d6710 100644
--- a/experimental/rocm/rocm_allocator.h
+++ b/experimental/rocm/rocm_allocator.h
@@ -18,12 +18,12 @@
 
 // Create a ROCM allocator.
 iree_status_t iree_hal_rocm_allocator_create(
-    iree_hal_rocm_context_wrapper_t *context,
-    iree_hal_allocator_t **out_allocator);
+    iree_hal_rocm_context_wrapper_t* context,
+    iree_hal_allocator_t** out_allocator);
 
 // Free an allocation represented by the given device or host pointer.
-void iree_hal_rocm_allocator_free(iree_hal_allocator_t *allocator,
-                                  hipDeviceptr_t device_ptr, void *host_ptr,
+void iree_hal_rocm_allocator_free(iree_hal_allocator_t* allocator,
+                                  hipDeviceptr_t device_ptr, void* host_ptr,
                                   iree_hal_memory_type_t memory_type);
 
 #ifdef __cplusplus
diff --git a/experimental/rocm/rocm_buffer.c b/experimental/rocm/rocm_buffer.c
index 195e897..8822cb5 100644
--- a/experimental/rocm/rocm_buffer.c
+++ b/experimental/rocm/rocm_buffer.c
@@ -16,32 +16,32 @@
 
 typedef struct iree_hal_rocm_buffer_t {
   iree_hal_buffer_t base;
-  void *host_ptr;
+  void* host_ptr;
   hipDeviceptr_t device_ptr;
 } iree_hal_rocm_buffer_t;
 
 extern const iree_hal_buffer_vtable_t iree_hal_rocm_buffer_vtable;
 
-static iree_hal_rocm_buffer_t *iree_hal_rocm_buffer_cast(
-    iree_hal_buffer_t *base_value) {
+static iree_hal_rocm_buffer_t* iree_hal_rocm_buffer_cast(
+    iree_hal_buffer_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_buffer_vtable);
-  return (iree_hal_rocm_buffer_t *)base_value;
+  return (iree_hal_rocm_buffer_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_buffer_wrap(
-    iree_hal_allocator_t *allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
     iree_hal_memory_access_t allowed_access,
     iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
     iree_device_size_t byte_offset, iree_device_size_t byte_length,
-    hipDeviceptr_t device_ptr, void *host_ptr, iree_hal_buffer_t **out_buffer) {
+    hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer) {
   IREE_ASSERT_ARGUMENT(allocator);
   IREE_ASSERT_ARGUMENT(out_buffer);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_hal_rocm_buffer_t *buffer = NULL;
+  iree_hal_rocm_buffer_t* buffer = NULL;
   iree_status_t status =
       iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
-                            sizeof(*buffer), (void **)&buffer);
+                            sizeof(*buffer), (void**)&buffer);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_buffer_vtable,
                                  &buffer->base.resource);
@@ -59,11 +59,11 @@
   }
 
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
-static void iree_hal_rocm_buffer_destroy(iree_hal_buffer_t *base_buffer) {
-  iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer);
+static void iree_hal_rocm_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+  iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer);
   iree_allocator_t host_allocator =
       iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -76,11 +76,11 @@
 }
 
 static iree_status_t iree_hal_rocm_buffer_map_range(
-    iree_hal_buffer_t *base_buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
     iree_hal_memory_access_t memory_access,
     iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
-    void **out_data_ptr) {
-  iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer);
+    void** out_data_ptr) {
+  iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer);
 
   if (!iree_all_bits_set(buffer->base.memory_type,
                          IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
@@ -88,14 +88,14 @@
                             "trying to map memory not host visible");
   }
 
-  uint8_t *data_ptr = (uint8_t *)(buffer->host_ptr) + local_byte_offset;
+  uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset;
   // If we mapped for discard, scribble over the bytes. This is not mandated
   // behavior, but it will make debugging issues easier. Alternatively, for
   // heap buffers we could reallocate them such that ASAN yells, but that
   // would only work if the entire buffer was discarded.
 #ifndef NDEBUG
   if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
-    memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
+    memset(data_ptr, 0xCD, local_byte_length);
   }
 #endif  // !NDEBUG
   *out_data_ptr = data_ptr;
@@ -103,28 +103,28 @@
 }
 
 static void iree_hal_rocm_buffer_unmap_range(
-    iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset,
-    iree_device_size_t local_byte_length, void *data_ptr) {
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length, void* data_ptr) {
   // nothing to do.
 }
 
 static iree_status_t iree_hal_rocm_buffer_invalidate_range(
-    iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset,
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
     iree_device_size_t local_byte_length) {
   // Nothing to do.
   return iree_ok_status();
 }
 
 static iree_status_t iree_hal_rocm_buffer_flush_range(
-    iree_hal_buffer_t *base_buffer, iree_device_size_t local_byte_offset,
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
     iree_device_size_t local_byte_length) {
   // Nothing to do.
   return iree_ok_status();
 }
 
 hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(
-    iree_hal_buffer_t *base_buffer) {
-  iree_hal_rocm_buffer_t *buffer = iree_hal_rocm_buffer_cast(base_buffer);
+    iree_hal_buffer_t* base_buffer) {
+  iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer);
   return buffer->device_ptr;
 }
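// Note on the map_range fix above (sketch, not part of this change): data_ptr
// already includes local_byte_offset, so the old memset scribbled starting at
//  host_ptr + 2 * local_byte_offset
// and could run past the mapped window; the discard pattern now lands exactly
// on [local_byte_offset, local_byte_offset + local_byte_length).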
 
diff --git a/experimental/rocm/rocm_buffer.h b/experimental/rocm/rocm_buffer.h
index e898630..85cf294 100644
--- a/experimental/rocm/rocm_buffer.h
+++ b/experimental/rocm/rocm_buffer.h
@@ -17,16 +17,16 @@
 
 // Wraps a rocm allocation in an iree_hal_buffer_t.
 iree_status_t iree_hal_rocm_buffer_wrap(
-    iree_hal_allocator_t *allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
     iree_hal_memory_access_t allowed_access,
     iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
     iree_device_size_t byte_offset, iree_device_size_t byte_length,
-    hipDeviceptr_t device_ptr, void *host_ptr, iree_hal_buffer_t **out_buffer);
+    hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer);
 
 // Returns the rocm base pointer for the given |buffer|.
 // This is the entire allocated_buffer and must be offset by the buffer
 // byte_offset and byte_length when used.
-hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t *buffer);
+hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t* buffer);
 
 #ifdef __cplusplus
 }  // extern "C"
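// Example (illustrative, not part of this change): computing the effective
// device address for a buffer range by applying the byte offset as the
// comment above requires.
//  hipDeviceptr_t base = iree_hal_rocm_buffer_device_pointer(buffer);
//  hipDeviceptr_t addr = (hipDeviceptr_t)((uint8_t*)base +
//                                         iree_hal_buffer_byte_offset(buffer));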
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index cff4d89..fa5a75e 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -32,27 +32,27 @@
 
   // Optional driver that owns the ROCM symbols. We retain it for our lifetime
   // to ensure the symbols remain valid.
-  iree_hal_driver_t *driver;
+  iree_hal_driver_t* driver;
 
   hipDevice_t device;
 
   // TODO: support multiple streams.
   hipStream_t stream;
   iree_hal_rocm_context_wrapper_t context_wrapper;
-  iree_hal_allocator_t *device_allocator;
+  iree_hal_allocator_t* device_allocator;
 
 } iree_hal_rocm_device_t;
 
 extern const iree_hal_device_vtable_t iree_hal_rocm_device_vtable;
 
-static iree_hal_rocm_device_t *iree_hal_rocm_device_cast(
-    iree_hal_device_t *base_value) {
+static iree_hal_rocm_device_t* iree_hal_rocm_device_cast(
+    iree_hal_device_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_device_vtable);
-  return (iree_hal_rocm_device_t *)base_value;
+  return (iree_hal_rocm_device_t*)base_value;
 }
 
-static void iree_hal_rocm_device_destroy(iree_hal_device_t *base_device) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+static void iree_hal_rocm_device_destroy(iree_hal_device_t* base_device) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
   IREE_TRACE_ZONE_BEGIN(z0);
 
@@ -70,21 +70,21 @@
 }
 
 static iree_status_t iree_hal_rocm_device_create_internal(
-    iree_hal_driver_t *driver, iree_string_view_t identifier,
+    iree_hal_driver_t* driver, iree_string_view_t identifier,
     hipDevice_t rocm_device, hipStream_t stream, hipCtx_t context,
-    iree_hal_rocm_dynamic_symbols_t *syms, iree_allocator_t host_allocator,
-    iree_hal_device_t **out_device) {
-  iree_hal_rocm_device_t *device = NULL;
+    iree_hal_rocm_dynamic_symbols_t* syms, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  iree_hal_rocm_device_t* device = NULL;
   iree_host_size_t total_size = sizeof(*device) + identifier.size;
   IREE_RETURN_IF_ERROR(
-      iree_allocator_malloc(host_allocator, total_size, (void **)&device));
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device));
   memset(device, 0, total_size);
   iree_hal_resource_initialize(&iree_hal_rocm_device_vtable, &device->resource);
   device->driver = driver;
   iree_hal_driver_retain(device->driver);
-  uint8_t *buffer_ptr = (uint8_t *)device + sizeof(*device);
+  uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device);
   buffer_ptr += iree_string_view_append_to_buffer(
-      identifier, &device->identifier, (char *)buffer_ptr);
+      identifier, &device->identifier, (char*)buffer_ptr);
   device->device = rocm_device;
   device->stream = stream;
   device->context_wrapper.rocm_context = context;
@@ -93,19 +93,19 @@
   iree_status_t status = iree_hal_rocm_allocator_create(
       &device->context_wrapper, &device->device_allocator);
   if (iree_status_is_ok(status)) {
-    *out_device = (iree_hal_device_t *)device;
+    *out_device = (iree_hal_device_t*)device;
   } else {
-    iree_hal_device_release((iree_hal_device_t *)device);
+    iree_hal_device_release((iree_hal_device_t*)device);
   }
   return status;
 }
 
-iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t *driver,
+iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver,
                                           iree_string_view_t identifier,
-                                          iree_hal_rocm_dynamic_symbols_t *syms,
+                                          iree_hal_rocm_dynamic_symbols_t* syms,
                                           hipDevice_t device,
                                           iree_allocator_t host_allocator,
-                                          iree_hal_device_t **out_device) {
+                                          iree_hal_device_t** out_device) {
   IREE_TRACE_ZONE_BEGIN(z0);
   hipCtx_t context;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
@@ -130,26 +130,26 @@
 }
 
 static iree_string_view_t iree_hal_rocm_device_id(
-    iree_hal_device_t *base_device) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_device_t* base_device) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return device->identifier;
 }
 
 static iree_allocator_t iree_hal_rocm_device_host_allocator(
-    iree_hal_device_t *base_device) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_device_t* base_device) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return device->context_wrapper.host_allocator;
 }
 
-static iree_hal_allocator_t *iree_hal_rocm_device_allocator(
-    iree_hal_device_t *base_device) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+static iree_hal_allocator_t* iree_hal_rocm_device_allocator(
+    iree_hal_device_t* base_device) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return device->device_allocator;
 }
 
 static iree_status_t iree_hal_rocm_device_query_i32(
-    iree_hal_device_t *base_device, iree_string_view_t category,
-    iree_string_view_t key, int32_t *out_value) {
+    iree_hal_device_t* base_device, iree_string_view_t category,
+    iree_string_view_t key, int32_t* out_value) {
   // iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   *out_value = 0;
 
@@ -169,77 +169,77 @@
 }
 
 static iree_status_t iree_hal_rocm_device_create_command_buffer(
-    iree_hal_device_t *base_device, iree_hal_command_buffer_mode_t mode,
+    iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
     iree_hal_queue_affinity_t queue_affinity,
-    iree_hal_command_buffer_t **out_command_buffer) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
-  return iree_hal_rocm_direct_command_buffer_allocate(
+    iree_hal_command_buffer_t** out_command_buffer) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
+  return iree_hal_rocm_direct_command_buffer_create(
       &device->context_wrapper, mode, command_categories, queue_affinity,
       out_command_buffer);
 }
 
 static iree_status_t iree_hal_rocm_device_create_descriptor_set(
-    iree_hal_device_t *base_device,
-    iree_hal_descriptor_set_layout_t *set_layout,
+    iree_hal_device_t* base_device,
+    iree_hal_descriptor_set_layout_t* set_layout,
     iree_host_size_t binding_count,
-    const iree_hal_descriptor_set_binding_t *bindings,
-    iree_hal_descriptor_set_t **out_descriptor_set) {
+    const iree_hal_descriptor_set_binding_t* bindings,
+    iree_hal_descriptor_set_t** out_descriptor_set) {
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                           "non-push descriptor sets still need work");
 }
 
 static iree_status_t iree_hal_rocm_device_create_descriptor_set_layout(
-    iree_hal_device_t *base_device,
+    iree_hal_device_t* base_device,
     iree_hal_descriptor_set_layout_usage_type_t usage_type,
     iree_host_size_t binding_count,
-    const iree_hal_descriptor_set_layout_binding_t *bindings,
-    iree_hal_descriptor_set_layout_t **out_descriptor_set_layout) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    const iree_hal_descriptor_set_layout_binding_t* bindings,
+    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return iree_hal_rocm_descriptor_set_layout_create(
       &device->context_wrapper, usage_type, binding_count, bindings,
       out_descriptor_set_layout);
 }
 
 static iree_status_t iree_hal_rocm_device_create_event(
-    iree_hal_device_t *base_device, iree_hal_event_t **out_event) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_device_t* base_device, iree_hal_event_t** out_event) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return iree_hal_rocm_event_create(&device->context_wrapper, out_event);
 }
 
 static iree_status_t iree_hal_rocm_device_create_executable_cache(
-    iree_hal_device_t *base_device, iree_string_view_t identifier,
-    iree_hal_executable_cache_t **out_executable_cache) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_device_t* base_device, iree_string_view_t identifier,
+    iree_hal_executable_cache_t** out_executable_cache) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return iree_hal_rocm_nop_executable_cache_create(
       &device->context_wrapper, identifier, out_executable_cache);
 }
 
 static iree_status_t iree_hal_rocm_device_create_executable_layout(
-    iree_hal_device_t *base_device, iree_host_size_t push_constants,
+    iree_hal_device_t* base_device, iree_host_size_t push_constants,
     iree_host_size_t set_layout_count,
-    iree_hal_descriptor_set_layout_t **set_layouts,
-    iree_hal_executable_layout_t **out_executable_layout) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_descriptor_set_layout_t** set_layouts,
+    iree_hal_executable_layout_t** out_executable_layout) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return iree_hal_rocm_executable_layout_create(
       &device->context_wrapper, set_layout_count, set_layouts, push_constants,
       out_executable_layout);
 }
 
 static iree_status_t iree_hal_rocm_device_create_semaphore(
-    iree_hal_device_t *base_device, uint64_t initial_value,
-    iree_hal_semaphore_t **out_semaphore) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_device_t* base_device, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   return iree_hal_rocm_semaphore_create(&device->context_wrapper, initial_value,
                                         out_semaphore);
 }
 
 static iree_status_t iree_hal_rocm_device_queue_submit(
-    iree_hal_device_t *base_device,
+    iree_hal_device_t* base_device,
     iree_hal_command_category_t command_categories,
     iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count,
-    const iree_hal_submission_batch_t *batches) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    const iree_hal_submission_batch_t* batches) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   // TODO(raikonenfnu): Once semaphores are implemented, wait on them here.
   // TODO(thomasraoux): Conservatively synchronize after every submit until we
   // support semaphores.
@@ -251,11 +251,11 @@
 }
 
 static iree_status_t iree_hal_rocm_device_submit_and_wait(
-    iree_hal_device_t *base_device,
+    iree_hal_device_t* base_device,
     iree_hal_command_category_t command_categories,
     iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count,
-    const iree_hal_submission_batch_t *batches,
-    iree_hal_semaphore_t *wait_semaphore, uint64_t wait_value,
+    const iree_hal_submission_batch_t* batches,
+    iree_hal_semaphore_t* wait_semaphore, uint64_t wait_value,
     iree_timeout_t timeout) {
   // Submit...
   IREE_RETURN_IF_ERROR(iree_hal_rocm_device_queue_submit(
@@ -266,15 +266,15 @@
 }
 
 static iree_status_t iree_hal_rocm_device_wait_semaphores(
-    iree_hal_device_t *base_device, iree_hal_wait_mode_t wait_mode,
-    const iree_hal_semaphore_list_t *semaphore_list, iree_timeout_t timeout) {
+    iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode,
+    const iree_hal_semaphore_list_t* semaphore_list, iree_timeout_t timeout) {
   return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                           "semaphore not implemented");
 }
 
 static iree_status_t iree_hal_rocm_device_wait_idle(
-    iree_hal_device_t *base_device, iree_timeout_t timeout) {
-  iree_hal_rocm_device_t *device = iree_hal_rocm_device_cast(base_device);
+    iree_hal_device_t* base_device, iree_timeout_t timeout) {
+  iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
   // Wait until the stream is done.
   // TODO(thomasraoux): HIP doesn't support a deadline for wait, figure out how
   // to handle it better.
diff --git a/experimental/rocm/rocm_device.h b/experimental/rocm/rocm_device.h
index 67e2f9b..083f4c7 100644
--- a/experimental/rocm/rocm_device.h
+++ b/experimental/rocm/rocm_device.h
@@ -17,12 +17,12 @@
 #endif  // __cplusplus
 
 // Creates a device that owns and manages its own hipContext.
-iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t *driver,
+iree_status_t iree_hal_rocm_device_create(iree_hal_driver_t* driver,
                                           iree_string_view_t identifier,
-                                          iree_hal_rocm_dynamic_symbols_t *syms,
+                                          iree_hal_rocm_dynamic_symbols_t* syms,
                                           hipDevice_t device,
                                           iree_allocator_t host_allocator,
-                                          iree_hal_device_t **out_device);
+                                          iree_hal_device_t** out_device);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/rocm_driver.c b/experimental/rocm/rocm_driver.c
index 8bd8ae6..808642f 100644
--- a/experimental/rocm/rocm_driver.c
+++ b/experimental/rocm/rocm_driver.c
@@ -32,44 +32,44 @@
 
 extern const iree_hal_driver_vtable_t iree_hal_rocm_driver_vtable;
 
-static iree_hal_rocm_driver_t *iree_hal_rocm_driver_cast(
-    iree_hal_driver_t *base_value) {
+static iree_hal_rocm_driver_t* iree_hal_rocm_driver_cast(
+    iree_hal_driver_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_driver_vtable);
-  return (iree_hal_rocm_driver_t *)base_value;
+  return (iree_hal_rocm_driver_t*)base_value;
 }
 
 IREE_API_EXPORT void iree_hal_rocm_driver_options_initialize(
-    iree_hal_rocm_driver_options_t *out_options) {
+    iree_hal_rocm_driver_options_t* out_options) {
   memset(out_options, 0, sizeof(*out_options));
   out_options->default_device_index = 0;
 }
 
 static iree_status_t iree_hal_rocm_driver_create_internal(
     iree_string_view_t identifier,
-    const iree_hal_rocm_driver_options_t *options,
-    iree_allocator_t host_allocator, iree_hal_driver_t **out_driver) {
-  iree_hal_rocm_driver_t *driver = NULL;
+    const iree_hal_rocm_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  iree_hal_rocm_driver_t* driver = NULL;
   iree_host_size_t total_size = sizeof(*driver) + identifier.size;
   IREE_RETURN_IF_ERROR(
-      iree_allocator_malloc(host_allocator, total_size, (void **)&driver));
+      iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
   iree_hal_resource_initialize(&iree_hal_rocm_driver_vtable, &driver->resource);
   driver->host_allocator = host_allocator;
   iree_string_view_append_to_buffer(
       identifier, &driver->identifier,
-      (char *)driver + total_size - identifier.size);
+      (char*)driver + total_size - identifier.size);
   driver->default_device_index = options->default_device_index;
   iree_status_t status =
       iree_hal_rocm_dynamic_symbols_initialize(host_allocator, &driver->syms);
   if (iree_status_is_ok(status)) {
-    *out_driver = (iree_hal_driver_t *)driver;
+    *out_driver = (iree_hal_driver_t*)driver;
   } else {
-    iree_hal_driver_release((iree_hal_driver_t *)driver);
+    iree_hal_driver_release((iree_hal_driver_t*)driver);
   }
   return status;
 }
 
-static void iree_hal_rocm_driver_destroy(iree_hal_driver_t *base_driver) {
-  iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver);
+static void iree_hal_rocm_driver_destroy(iree_hal_driver_t* base_driver) {
+  iree_hal_rocm_driver_t* driver = iree_hal_rocm_driver_cast(base_driver);
   iree_allocator_t host_allocator = driver->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
 
@@ -81,8 +81,8 @@
 
 IREE_API_EXPORT iree_status_t iree_hal_rocm_driver_create(
     iree_string_view_t identifier,
-    const iree_hal_rocm_driver_options_t *options,
-    iree_allocator_t host_allocator, iree_hal_driver_t **out_driver) {
+    const iree_hal_rocm_driver_options_t* options,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
   IREE_ASSERT_ARGUMENT(options);
   IREE_ASSERT_ARGUMENT(out_driver);
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -97,9 +97,9 @@
 // Populates device information from the given ROCM physical device handle.
 // |out_device_info| must point to valid memory; additional data will be
 // appended to |buffer_ptr| and the new pointer returned.
-static uint8_t *iree_hal_rocm_populate_device_info(
-    hipDevice_t device, iree_hal_rocm_dynamic_symbols_t *syms,
-    uint8_t *buffer_ptr, iree_hal_device_info_t *out_device_info) {
+static uint8_t* iree_hal_rocm_populate_device_info(
+    hipDevice_t device, iree_hal_rocm_dynamic_symbols_t* syms,
+    uint8_t* buffer_ptr, iree_hal_device_info_t* out_device_info) {
   char device_name[IREE_MAX_ROCM_DEVICE_NAME_LENGTH];
   ROCM_IGNORE_ERROR(syms,
                     hipDeviceGetName(device_name, sizeof(device_name), device));
@@ -109,31 +109,31 @@
   iree_string_view_t device_name_string =
       iree_make_string_view(device_name, strlen(device_name));
   buffer_ptr += iree_string_view_append_to_buffer(
-      device_name_string, &out_device_info->name, (char *)buffer_ptr);
+      device_name_string, &out_device_info->name, (char*)buffer_ptr);
   return buffer_ptr;
 }
 
 static iree_status_t iree_hal_rocm_driver_query_available_devices(
-    iree_hal_driver_t *base_driver, iree_allocator_t host_allocator,
-    iree_hal_device_info_t **out_device_infos,
-    iree_host_size_t *out_device_info_count) {
-  iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver);
+    iree_hal_driver_t* base_driver, iree_allocator_t host_allocator,
+    iree_hal_device_info_t** out_device_infos,
+    iree_host_size_t* out_device_info_count) {
+  iree_hal_rocm_driver_t* driver = iree_hal_rocm_driver_cast(base_driver);
   // Query the number of available ROCM devices.
   int device_count = 0;
   ROCM_RETURN_IF_ERROR(&driver->syms, hipGetDeviceCount(&device_count),
                        "hipGetDeviceCount");
 
   // Allocate the return infos and populate with the devices.
-  iree_hal_device_info_t *device_infos = NULL;
+  iree_hal_device_info_t* device_infos = NULL;
   iree_host_size_t total_size = device_count * sizeof(iree_hal_device_info_t);
   for (iree_host_size_t i = 0; i < device_count; ++i) {
     total_size += IREE_MAX_ROCM_DEVICE_NAME_LENGTH * sizeof(char);
   }
   iree_status_t status =
-      iree_allocator_malloc(host_allocator, total_size, (void **)&device_infos);
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
   if (iree_status_is_ok(status)) {
-    uint8_t *buffer_ptr =
-        (uint8_t *)device_infos + device_count * sizeof(iree_hal_device_info_t);
+    uint8_t* buffer_ptr =
+        (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t);
     for (iree_host_size_t i = 0; i < device_count; ++i) {
       hipDevice_t device;
       iree_status_t status = ROCM_RESULT_TO_STATUS(
@@ -153,8 +153,8 @@
 }
 
 static iree_status_t iree_hal_rocm_driver_select_default_device(
-    iree_hal_rocm_dynamic_symbols_t *syms, int default_device_index,
-    iree_allocator_t host_allocator, hipDevice_t *out_device) {
+    iree_hal_rocm_dynamic_symbols_t* syms, int default_device_index,
+    iree_allocator_t host_allocator, hipDevice_t* out_device) {
   int device_count = 0;
   ROCM_RETURN_IF_ERROR(syms, hipGetDeviceCount(&device_count),
                        "hipGetDeviceCount");
@@ -173,9 +173,9 @@
 }
 
 static iree_status_t iree_hal_rocm_driver_create_device(
-    iree_hal_driver_t *base_driver, iree_hal_device_id_t device_id,
-    iree_allocator_t host_allocator, iree_hal_device_t **out_device) {
-  iree_hal_rocm_driver_t *driver = iree_hal_rocm_driver_cast(base_driver);
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+  iree_hal_rocm_driver_t* driver = iree_hal_rocm_driver_cast(base_driver);
   IREE_TRACE_ZONE_BEGIN(z0);
 
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
diff --git a/experimental/rocm/rocm_event.c b/experimental/rocm/rocm_event.c
index 48b20c3..945a5f7 100644
--- a/experimental/rocm/rocm_event.c
+++ b/experimental/rocm/rocm_event.c
@@ -14,40 +14,40 @@
 // Dummy events for now; they don't do anything.
 typedef struct iree_hal_rocm_event_t {
   iree_hal_resource_t resource;
-  iree_hal_rocm_context_wrapper_t *context_wrapper;
+  iree_hal_rocm_context_wrapper_t* context_wrapper;
 } iree_hal_rocm_event_t;
 
 extern const iree_hal_event_vtable_t iree_hal_rocm_event_vtable;
 
-static iree_hal_rocm_event_t *iree_hal_rocm_event_cast(
-    iree_hal_event_t *base_value) {
+static iree_hal_rocm_event_t* iree_hal_rocm_event_cast(
+    iree_hal_event_t* base_value) {
   IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_rocm_event_vtable);
-  return (iree_hal_rocm_event_t *)base_value;
+  return (iree_hal_rocm_event_t*)base_value;
 }
 
 iree_status_t iree_hal_rocm_event_create(
-    iree_hal_rocm_context_wrapper_t *context_wrapper,
-    iree_hal_event_t **out_event) {
+    iree_hal_rocm_context_wrapper_t* context_wrapper,
+    iree_hal_event_t** out_event) {
   IREE_ASSERT_ARGUMENT(context_wrapper);
   IREE_ASSERT_ARGUMENT(out_event);
   *out_event = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_hal_rocm_event_t *event = NULL;
+  iree_hal_rocm_event_t* event = NULL;
   iree_status_t status = iree_allocator_malloc(context_wrapper->host_allocator,
-                                               sizeof(*event), (void **)&event);
+                                               sizeof(*event), (void**)&event);
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_rocm_event_vtable, &event->resource);
     event->context_wrapper = context_wrapper;
-    *out_event = (iree_hal_event_t *)event;
+    *out_event = (iree_hal_event_t*)event;
   }
 
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
-static void iree_hal_rocm_event_destroy(iree_hal_event_t *base_event) {
-  iree_hal_rocm_event_t *event = iree_hal_rocm_event_cast(base_event);
+static void iree_hal_rocm_event_destroy(iree_hal_event_t* base_event) {
+  iree_hal_rocm_event_t* event = iree_hal_rocm_event_cast(base_event);
   iree_allocator_t host_allocator = event->context_wrapper->host_allocator;
   IREE_TRACE_ZONE_BEGIN(z0);
 
diff --git a/experimental/rocm/rocm_event.h b/experimental/rocm/rocm_event.h
index 73041a4..0bac1a2 100644
--- a/experimental/rocm/rocm_event.h
+++ b/experimental/rocm/rocm_event.h
@@ -21,8 +21,8 @@
 // command buffer we will add the appropriate edges to enforce the right
 // synchronization.
 iree_status_t iree_hal_rocm_event_create(
-    iree_hal_rocm_context_wrapper_t *context_wrapper,
-    iree_hal_event_t **out_event);
+    iree_hal_rocm_context_wrapper_t* context_wrapper,
+    iree_hal_event_t** out_event);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/experimental/rocm/status_util.h b/experimental/rocm/status_util.h
index b459591..0f6fcc5 100644
--- a/experimental/rocm/status_util.h
+++ b/experimental/rocm/status_util.h
@@ -44,7 +44,7 @@
 
 // Converts a hipError_t to a Status object.
 iree_status_t iree_hal_rocm_result_to_status(
-    iree_hal_rocm_dynamic_symbols_t *syms, hipError_t result, const char *file,
+    iree_hal_rocm_dynamic_symbols_t* syms, hipError_t result, const char* file,
     uint32_t line);
 
 #ifdef __cplusplus
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
index feec3ff..ee95f1e 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tf-main.cpp
@@ -224,8 +224,8 @@
 
     iree_integrations::TF::buildTFImportPassPipeline(pm);
     if (failed(pm.run(*module))) {
-      llvm::errs()
-          << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
+      llvm::errs() << "Running iree-import-tf TF import pass pipeline failed "
+                      "(see diagnostics)\n";
       return 2;
     }
     if (!saveTempMidLevelImport.empty()) {
@@ -237,8 +237,8 @@
     applyPassManagerCLOptions(pm);
     iree_integrations::MHLO::buildMHLOImportPassPipeline(pm);
     if (failed(pm.run(*module))) {
-      llvm::errs()
-          << "Running iree-import-tf pass pipeline failed (see diagnostics)\n";
+      llvm::errs() << "Running iree-import-tf MHLO import pass pipeline failed "
+                      "(see diagnostics)\n";
       return 2;
     }
   }
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 2e11d7f..08f270a 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -100,6 +100,10 @@
   static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
                                              cl::value_desc("filename"),
                                              cl::init("-"));
+  static llvm::cl::opt<std::string> saveTempMhloInput(
+      "save-temp-mhlo-input",
+      llvm::cl::desc("Save the MHLO pipeline input IR to this file"),
+      llvm::cl::init(""));
   static llvm::cl::opt<std::string> saveTempIreeImport(
       "save-temp-iree-input",
       llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
@@ -251,6 +255,11 @@
     return success();
   };
 
+  // Save temp output.
+  if (!saveTempMhloInput.empty()) {
+    if (failed(saveToFile(saveTempMhloInput))) return 10;
+  }
+
   // Run passes.
   PassManager pm(&context, PassManager::Nesting::Implicit);
   applyPassManagerCLOptions(pm);
@@ -264,8 +273,8 @@
       iree_integrations::MHLO::createEmitDefaultIREEABIPass());
 
   if (failed(pm.run(*module))) {
-    llvm::errs()
-        << "Running iree-xla-import pass pipeline failed (see diagnostics)\n";
+    llvm::errs() << "Running iree-xla-import MHLO import pass pipeline failed "
+                    "(see diagnostics)\n";
     return 2;
   }
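// Example invocation exercising the new flag (binary name and paths are
// illustrative, not part of this change); the MHLO IR is written before the
// import pipeline runs, mirroring --save-temp-iree-input for the output:
//  iree-import-xla model.pb --save-temp-mhlo-input=/tmp/model.mhlo.mlir \
//      -o /tmp/model.mlir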
 
diff --git a/iree/base/alignment.h b/iree/base/alignment.h
index 0f5953a..5229d97 100644
--- a/iree/base/alignment.h
+++ b/iree/base/alignment.h
@@ -11,6 +11,8 @@
 #define IREE_BASE_ALIGNMENT_H_
 
 #include <stddef.h>
+#include <stdint.h>
+#include <string.h>
 
 #include "iree/base/config.h"
 #include "iree/base/target_platform.h"
@@ -55,6 +57,178 @@
   return (value + (alignment - 1)) & ~(alignment - 1);
 }
 
+// Returns the size of a struct padded out to iree_max_align_t.
+// This must be used when performing manual trailing allocation packing to
+// ensure the alignment requirements of the trailing data are satisfied.
+//
+// NOTE: do not use this if using VLAs (`struct { int trailing[]; }`) - those
+// must precisely follow the normal sizeof(t) as the compiler does the padding
+// for you.
+//
+// Example:
+//  some_buffer_ptr_t* p = NULL;
+//  iree_host_size_t total_size = iree_sizeof_struct(*p) + extra_data_size;
+//  IREE_CHECK_OK(iree_allocator_malloc(allocator, total_size, (void**)&p));
+#define iree_sizeof_struct(t) iree_host_align(sizeof(t), iree_max_align_t)
+
+//===----------------------------------------------------------------------===//
+// Alignment-safe memory accesses
+//===----------------------------------------------------------------------===//
+
+// Map little-endian byte indices in memory to the host memory order indices.
+#if defined(IREE_ENDIANNESS_LITTLE)
+#define IREE_LE_IDX_1(i) (i)
+#define IREE_LE_IDX_2(i) (i)
+#define IREE_LE_IDX_4(i) (i)
+#define IREE_LE_IDX_8(i) (i)
+#else
+#define IREE_LE_IDX_1(i) (i)
+#define IREE_LE_IDX_2(i) (1 - (i))
+#define IREE_LE_IDX_4(i) (3 - (i))
+#define IREE_LE_IDX_8(i) (7 - (i))
+#endif  // IREE_ENDIANNESS_*
+
+#if IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED
+
+static inline uint8_t iree_unaligned_load_le_u8(const uint8_t* ptr) {
+  return *ptr;
+}
+static inline uint16_t iree_unaligned_load_le_u16(const uint16_t* ptr) {
+  const uint8_t* p = (const uint8_t*)ptr;
+  return ((uint16_t)p[IREE_LE_IDX_2(0)]) | ((uint16_t)p[IREE_LE_IDX_2(1)] << 8);
+}
+static inline uint32_t iree_unaligned_load_le_u32(const uint32_t* ptr) {
+  const uint8_t* p = (const uint8_t*)ptr;
+  return ((uint32_t)p[IREE_LE_IDX_4(0)]) |
+         ((uint32_t)p[IREE_LE_IDX_4(1)] << 8) |
+         ((uint32_t)p[IREE_LE_IDX_4(2)] << 16) |
+         ((uint32_t)p[IREE_LE_IDX_4(3)] << 24);
+}
+static inline uint64_t iree_unaligned_load_le_u64(const uint64_t* ptr) {
+  const uint8_t* p = (const uint8_t*)ptr;
+  return ((uint64_t)p[IREE_LE_IDX_8(0)]) |
+         ((uint64_t)p[IREE_LE_IDX_8(1)] << 8) |
+         ((uint64_t)p[IREE_LE_IDX_8(2)] << 16) |
+         ((uint64_t)p[IREE_LE_IDX_8(3)] << 24) |
+         ((uint64_t)p[IREE_LE_IDX_8(4)] << 32) |
+         ((uint64_t)p[IREE_LE_IDX_8(5)] << 40) |
+         ((uint64_t)p[IREE_LE_IDX_8(6)] << 48) |
+         ((uint64_t)p[IREE_LE_IDX_8(7)] << 56);
+}
+static inline float iree_unaligned_load_le_f32(const float* ptr) {
+  uint32_t uint_value = iree_unaligned_load_le_u32((const uint32_t*)ptr);
+  float value;
+  memcpy(&value, &uint_value, sizeof(value));
+  return value;
+}
+static inline double iree_unaligned_load_le_f64(const double* ptr) {
+  uint64_t uint_value = iree_unaligned_load_le_u64((const uint64_t*)ptr);
+  double value;
+  memcpy(&value, &uint_value, sizeof(value));
+  return value;
+}
+
+static inline void iree_unaligned_store_le_u8(uint8_t* ptr, uint8_t value) {
+  *ptr = value;
+}
+static inline void iree_unaligned_store_le_u16(uint16_t* ptr, uint16_t value) {
+  uint8_t* p = (uint8_t*)ptr;
+  p[IREE_LE_IDX_2(0)] = value;
+  p[IREE_LE_IDX_2(1)] = value >> 8;
+}
+static inline void iree_unaligned_store_le_u32(uint32_t* ptr, uint32_t value) {
+  uint8_t* p = (uint8_t*)ptr;
+  p[IREE_LE_IDX_4(0)] = value;
+  p[IREE_LE_IDX_4(1)] = value >> 8;
+  p[IREE_LE_IDX_4(2)] = value >> 16;
+  p[IREE_LE_IDX_4(3)] = value >> 24;
+}
+static inline void iree_unaligned_store_le_u64(uint64_t* ptr, uint64_t value) {
+  uint8_t* p = (uint8_t*)ptr;
+  p[IREE_LE_IDX_8(0)] = value;
+  p[IREE_LE_IDX_8(1)] = value >> 8;
+  p[IREE_LE_IDX_8(2)] = value >> 16;
+  p[IREE_LE_IDX_8(3)] = value >> 24;
+  p[IREE_LE_IDX_8(4)] = value >> 32;
+  p[IREE_LE_IDX_8(5)] = value >> 40;
+  p[IREE_LE_IDX_8(6)] = value >> 48;
+  p[IREE_LE_IDX_8(7)] = value >> 56;
+}
+static inline void iree_unaligned_store_le_f32(float* ptr, float value) {
+  uint32_t uint_value;
+  memcpy(&uint_value, &value, sizeof(value));
+  iree_unaligned_store_le_u32((uint32_t*)ptr, uint_value);
+}
+static inline void iree_unaligned_store_le_f64(double* ptr, double value) {
+  uint64_t uint_value;
+  memcpy(&uint_value, &value, sizeof(value));
+  iree_unaligned_store_le_u64((uint64_t*)ptr, uint_value);
+}
+
+#else
+
+#if defined(IREE_ENDIANNESS_LITTLE)
+
+#define iree_unaligned_load_le_u8(ptr) *(ptr)
+#define iree_unaligned_load_le_u16(ptr) *(ptr)
+#define iree_unaligned_load_le_u32(ptr) *(ptr)
+#define iree_unaligned_load_le_u64(ptr) *(ptr)
+#define iree_unaligned_load_le_f32(ptr) *(ptr)
+#define iree_unaligned_load_le_f64(ptr) *(ptr)
+
+#define iree_unaligned_store_le_u8(ptr, value) *(ptr) = (value)
+#define iree_unaligned_store_le_u16(ptr, value) *(ptr) = (value)
+#define iree_unaligned_store_le_u32(ptr, value) *(ptr) = (value)
+#define iree_unaligned_store_le_u64(ptr, value) *(ptr) = (value)
+#define iree_unaligned_store_le_f32(ptr, value) *(ptr) = (value)
+#define iree_unaligned_store_le_f64(ptr, value) *(ptr) = (value)
+
+#else
+
+#error "TODO(benvanik): little-endian load/store for big-endian archs"
+
+#endif  // IREE_ENDIANNESS_*
+
+#endif  // IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED
+
+// clang-format off
+
+// Dereferences |ptr| and returns the value.
+// Automatically handles unaligned accesses on architectures that may not
+// support them natively (or efficiently). Memory is treated as little-endian.
+#define iree_unaligned_load_le(ptr)                                               \
+  _Generic((ptr),                                                              \
+        int8_t*: iree_unaligned_load_le_u8((const uint8_t*)(ptr)),             \
+       uint8_t*: iree_unaligned_load_le_u8((const uint8_t*)(ptr)),             \
+       int16_t*: iree_unaligned_load_le_u16((const uint16_t*)(ptr)),           \
+      uint16_t*: iree_unaligned_load_le_u16((const uint16_t*)(ptr)),           \
+       int32_t*: iree_unaligned_load_le_u32((const uint32_t*)(ptr)),           \
+      uint32_t*: iree_unaligned_load_le_u32((const uint32_t*)(ptr)),           \
+       int64_t*: iree_unaligned_load_le_u64((const uint64_t*)(ptr)),           \
+      uint64_t*: iree_unaligned_load_le_u64((const uint64_t*)(ptr)),           \
+         float*: iree_unaligned_load_le_f32((const float*)(ptr)),              \
+        double*: iree_unaligned_load_le_f64((const double*)(ptr))              \
+  )
+
+// Dereferences |ptr| and writes the given |value|.
+// Automatically handles unaligned accesses on architectures that may not
+// support them natively (or efficiently). Memory is treated as little-endian.
+#define iree_unaligned_store(ptr, value)                                       \
+  _Generic((ptr),                                                              \
+        int8_t*: iree_unaligned_store_le_u8((uint8_t*)(ptr), value),           \
+       uint8_t*: iree_unaligned_store_le_u8((uint8_t*)(ptr), value),           \
+       int16_t*: iree_unaligned_store_le_u16((uint16_t*)(ptr), value),         \
+      uint16_t*: iree_unaligned_store_le_u16((uint16_t*)(ptr), value),         \
+       int32_t*: iree_unaligned_store_le_u32((uint32_t*)(ptr), value),         \
+      uint32_t*: iree_unaligned_store_le_u32((uint32_t*)(ptr), value),         \
+       int64_t*: iree_unaligned_store_le_u64((uint64_t*)(ptr), value),         \
+      uint64_t*: iree_unaligned_store_le_u64((uint64_t*)(ptr), value),         \
+         float*: iree_unaligned_store_le_f32((float*)(ptr), value),            \
+        double*: iree_unaligned_store_le_f64((double*)(ptr), value)            \
+  )
+
+// clang-format on
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
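// Example (illustrative sketch, not part of this change): decoding a
// little-endian u32 from an arbitrarily aligned byte stream. The same call
// compiles to a plain load on relaxed-alignment targets and to byte loads on
// strict-alignment ones.
//  static inline uint32_t read_le_u32(const uint8_t* bytes, size_t offset) {
//    return iree_unaligned_load_le_u32((const uint32_t*)(bytes + offset));
//  }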
diff --git a/iree/base/internal/threading_darwin.c b/iree/base/internal/threading_darwin.c
index 0d9f4fd..7f3bd00 100644
--- a/iree/base/internal/threading_darwin.c
+++ b/iree/base/internal/threading_darwin.c
@@ -82,7 +82,7 @@
   // (including the user-specified entry_arg).
   iree_thread_t* thread = NULL;
   iree_status_t status =
-      iree_allocator_malloc(allocator, sizeof(iree_thread_t), (void**)&thread);
+      iree_allocator_malloc(allocator, sizeof(*thread), (void**)&thread);
   if (!iree_status_is_ok(status)) {
     IREE_TRACE_ZONE_END(z0);
     return status;
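// The sizeof(*thread) spelling is the defensive idiom: the allocation size
// tracks whatever |thread| points at, so renaming the struct cannot silently
// desynchronize the malloc size. Sketch with a hypothetical type:
//  some_type_t* p = NULL;
//  iree_allocator_malloc(allocator, sizeof(*p), (void**)&p);
// The same change is applied to the pthreads and win32 ports below.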
diff --git a/iree/base/internal/threading_pthreads.c b/iree/base/internal/threading_pthreads.c
index ed7c59f..f3d66a8 100644
--- a/iree/base/internal/threading_pthreads.c
+++ b/iree/base/internal/threading_pthreads.c
@@ -124,7 +124,7 @@
   // (including the user-specified entry_arg).
   iree_thread_t* thread = NULL;
   iree_status_t status =
-      iree_allocator_malloc(allocator, sizeof(iree_thread_t), (void**)&thread);
+      iree_allocator_malloc(allocator, sizeof(*thread), (void**)&thread);
   if (!iree_status_is_ok(status)) {
     IREE_TRACE_ZONE_END(z0);
     return status;
diff --git a/iree/base/internal/threading_win32.c b/iree/base/internal/threading_win32.c
index f6e5d24..6e550e3 100644
--- a/iree/base/internal/threading_win32.c
+++ b/iree/base/internal/threading_win32.c
@@ -139,7 +139,7 @@
   // (including the user-specified entry_arg).
   iree_thread_t* thread = NULL;
   iree_status_t status =
-      iree_allocator_malloc(allocator, sizeof(iree_thread_t), (void**)&thread);
+      iree_allocator_malloc(allocator, sizeof(*thread), (void**)&thread);
   if (!iree_status_is_ok(status)) {
     IREE_TRACE_ZONE_END(z0);
     return status;
@@ -200,7 +200,11 @@
 
   iree_thread_resume(thread);
 
-  WaitForSingleObject(thread->handle, INFINITE);
+  if (thread->id != GetCurrentThreadId()) {
+    // Join with the thread. Since threads can delete themselves, we must
+    // ensure they don't try to join with themselves and deadlock.
+    WaitForSingleObject(thread->handle, INFINITE);
+  }
   CloseHandle(thread->handle);
   iree_thread_override_list_deinitialize(&thread->qos_override_list);
   iree_allocator_free(thread->allocator, thread);
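// Sketch (not part of this change) of the hazard the GetCurrentThreadId()
// check avoids: if the final release runs on the thread itself, an
// unconditional
//  WaitForSingleObject(thread->handle, INFINITE);
// would have the thread waiting on its own exit and deadlock; the check skips
// the join in exactly that case.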
diff --git a/iree/base/internal/wait_handle_poll.c b/iree/base/internal/wait_handle_poll.c
index 805a19d..5dd51fc 100644
--- a/iree/base/internal/wait_handle_poll.c
+++ b/iree/base/internal/wait_handle_poll.c
@@ -148,10 +148,10 @@
   }
 
   iree_host_size_t user_handle_list_size =
-      capacity * sizeof(iree_wait_handle_t);
+      capacity * iree_sizeof_struct(iree_wait_handle_t);
   iree_host_size_t poll_fd_list_size = capacity * sizeof(struct pollfd);
-  iree_host_size_t total_size =
-      sizeof(iree_wait_set_t) + user_handle_list_size + poll_fd_list_size;
+  iree_host_size_t total_size = iree_sizeof_struct(iree_wait_set_t) +
+                                user_handle_list_size + poll_fd_list_size;
 
   iree_wait_set_t* set = NULL;
   IREE_RETURN_IF_ERROR(
@@ -161,7 +161,8 @@
   iree_wait_set_clear(set);
 
   set->user_handles =
-      (iree_wait_handle_t*)((uint8_t*)set + sizeof(iree_wait_set_t));
+      (iree_wait_handle_t*)((uint8_t*)set +
+                            iree_sizeof_struct(iree_wait_set_t));
   set->poll_fds =
       (struct pollfd*)((uint8_t*)set->user_handles + user_handle_list_size);
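// Layout sketch (not part of this change) of the trailing allocation the
// iree_sizeof_struct switch preserves; each region now starts on a
// max-aligned boundary:
//  | iree_wait_set_t (padded) | user_handles[capacity] | poll_fds[capacity] |
//    ^-- iree_sizeof_struct(iree_wait_set_t) bytes from the base pointer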
 
diff --git a/iree/base/internal/wait_handle_win32.c b/iree/base/internal/wait_handle_win32.c
index 0a5bb56..af17e4a 100644
--- a/iree/base/internal/wait_handle_win32.c
+++ b/iree/base/internal/wait_handle_win32.c
@@ -147,8 +147,8 @@
   iree_host_size_t user_handle_list_size =
       capacity * sizeof(iree_wait_handle_t);
   iree_host_size_t native_handle_list_size = capacity * sizeof(HANDLE);
-  iree_host_size_t total_size =
-      sizeof(iree_wait_set_t) + user_handle_list_size + native_handle_list_size;
+  iree_host_size_t total_size = iree_sizeof_struct(iree_wait_set_t) +
+                                user_handle_list_size + native_handle_list_size;
 
   iree_wait_set_t* set = NULL;
   IREE_RETURN_IF_ERROR(
@@ -158,7 +158,8 @@
   iree_wait_set_clear(set);
 
   set->user_handles =
-      (iree_wait_handle_t*)((uint8_t*)set + sizeof(iree_wait_set_t));
+      (iree_wait_handle_t*)((uint8_t*)set +
+                            iree_sizeof_struct(iree_wait_set_t));
   set->native_handles =
       (HANDLE*)((uint8_t*)set->user_handles + user_handle_list_size);
 
diff --git a/iree/base/target_platform.h b/iree/base/target_platform.h
index 4277d71..8052967 100644
--- a/iree/base/target_platform.h
+++ b/iree/base/target_platform.h
@@ -30,6 +30,8 @@
 // IREE_ENDIANNESS_LITTLE
 // IREE_ENDIANNESS_BIG
 //
+// IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED (0/1)
+//
 // IREE_COMPILER_CLANG
 // IREE_COMPILER_GCC
 // IREE_COMPILER_GCC_COMPAT
@@ -138,6 +140,39 @@
 #endif  // __BYTE_ORDER__
 
 //==============================================================================
+// IREE_MEMORY_ACCESS_*
+//==============================================================================
+// Certain architectures have specific memory access requirements that demand
+// user-mode code changes to work at all or to run at reasonable performance.
+
+#if !defined(IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED)
+
+#if defined(IREE_ARCH_ARM_32) || defined(IREE_ARCH_ARM_64)
+
+// Armv6-M and Armv8-M (without the Main extension) do not support unaligned
+// access. The -munaligned-access and -mno-unaligned-access flags control this.
+// https://www.keil.com/support/man/docs/armclang_ref/armclang_ref_sam1444138667173.htm
+#if !defined(__ARM_FEATURE_UNALIGNED)
+#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 1
+#else
+#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 0
+#endif  // !__ARM_FEATURE_UNALIGNED
+
+#elif defined(IREE_ARCH_RISCV_32) || defined(IREE_ARCH_RISCV_64)
+
+// Though unaligned access is part of the base spec, it is allowed to be
+// implemented with trap handlers. Bare-metal systems likely won't have these
+// handlers, and even on systems that do (Linux) we don't want to be trapping
+// on every load/store.
+#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 1
+
+#else
+#define IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED 0
+#endif  // IREE_ARCH_*
+
+#endif  // !IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED
+
+//==============================================================================
 // IREE_COMPILER_*
 //==============================================================================
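// IREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED above is overridable from the build:
// a target the detection misses can force the byte-wise access path
// explicitly (hypothetical flag usage):
//  cc -DIREE_MEMORY_ACCESS_ALIGNMENT_REQUIRED=1 ...
// When left undefined, the arch-based default above applies.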
 
diff --git a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index b79eac9..7c3cbb7 100644
--- a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -13,6 +13,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <tuple>
+
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
@@ -51,15 +53,19 @@
 // Resource utilities
 //===----------------------------------------------------------------------===//
 
-/// Inserts a resource evariable of the given `type` at the beginning of
+/// Map from hal.interface.binding.subspan ops to their corresponding
+/// spv.GlobalVariable ops.
+using InterfaceResourceMap =
+    llvm::DenseMap<Operation *, spirv::GlobalVariableOp>;
+
+/// Creates a resource variable of the given `type` at the beginning of
 /// `moduleOp`'s block via `symbolTable` and bind it to `set` and `binding`.
-spirv::GlobalVariableOp insertResourceVariable(Location loc, Type type,
+spirv::GlobalVariableOp createResourceVariable(Location loc, Type type,
                                                unsigned set, unsigned binding,
                                                bool alias, ModuleOp moduleOp,
-                                               SymbolTable *symbolTable,
-                                               OpBuilder::Listener *listener) {
-  OpBuilder builder(moduleOp.getContext(), listener);
+                                               SymbolTable *symbolTable) {
   std::string name = llvm::formatv("__resource_var_{0}_{1}_", set, binding);
+  OpBuilder builder(moduleOp.getContext());
   auto variable =
       builder.create<spirv::GlobalVariableOp>(loc, type, name, set, binding);
   if (alias) variable->setAttr("aliased", builder.getUnitAttr());
@@ -74,35 +80,66 @@
   return {bindingOp.set().getSExtValue(), bindingOp.binding().getSExtValue()};
 }
 
-/// Returns the set of resources that should be marked as aliased in SPIR-V.
-llvm::DenseSet<Operation *> getAliasedResources(ModuleOp module) {
-  llvm::DenseSet<Operation *> aliasedResources;
+/// Scans all hal.interface.binding.subspan ops in `module`, creates their
+/// corresponding spv.GlobalVariables when needed, and returns the map.
+/// The created variables need to have their types fixed later.
+InterfaceResourceMap createResourceVariables(mlir::ModuleOp module) {
+  SymbolTable symbolTable(module);
+  InterfaceResourceMap interfaceToResourceVars;
 
-  for (FuncOp func : module.getOps<FuncOp>()) {
+  auto fns = llvm::to_vector<1>(module.getOps<FuncOp>());
+  for (FuncOp func : llvm::reverse(fns)) {
     // Collect all interface ops and their (set, binding) pairs in this
-    // function.
-    SmallVector<Operation *, 4> interfaceOps;
-    SmallVector<std::pair<uint32_t, uint32_t>, 4> setBindings;
-    llvm::DenseMap<std::pair<uint32_t, uint32_t>, unsigned> setBindingCount;
+    // function. Use SmallVector here for a deterministic order.
+    SmallVector<IREE::HAL::InterfaceBindingSubspanOp, 8> interfaceOps;
+    SmallVector<std::pair<uint32_t, uint32_t>, 8> setBindings;
+
+    // Use a map to see if we have different types for one (set, binding) pair,
+    // which will require creating multiple SPIR-V global variables.
+    llvm::DenseMap<std::pair<uint32_t, uint32_t>, llvm::DenseSet<Type>>
+        setBindingTypes;
+
     func.walk([&](Operation *op) {
-      if (isa<IREE::HAL::InterfaceBindingSubspanOp>(op)) {
-        interfaceOps.emplace_back(op);
-        setBindings.emplace_back(getInterfaceSetAndBinding(op));
-        ++setBindingCount[setBindings.back()];
-      }
+      auto interfaceOp = dyn_cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
+      if (!interfaceOp || interfaceOp.use_empty()) return;
+      interfaceOps.emplace_back(interfaceOp);
+      setBindings.emplace_back(getInterfaceSetAndBinding(interfaceOp));
+      setBindingTypes[setBindings.back()].insert(interfaceOp.getType());
     });
 
-    // Perform analysis to determine whether we need to mark the resource as
-    // alias. This should happen when we have multiple resources binding to the
-    // same (set, binding) pair and they are used in the same function.
-    for (unsigned i = 0; i < interfaceOps.size(); ++i) {
-      if (setBindingCount[setBindings[i]] > 1) {
-        aliasedResources.insert(interfaceOps[i]);
+    // Keep track of created SPIR-V global variables. This allows us to
+    // deduplicate when possible to reduce generated SPIR-V blob size.
+    llvm::DenseMap<std::tuple<uint32_t, uint32_t, Type>,
+                   spirv::GlobalVariableOp>
+        resourceVars;
+
+    for (int i = interfaceOps.size() - 1; i >= 0; --i) {
+      auto interfaceOp = interfaceOps[i];
+      const auto &setBinding = setBindings[i];
+
+      auto key = std::make_tuple(setBinding.first, setBinding.second,
+                                 interfaceOp.getType());
+      auto var = resourceVars.lookup(key);
+      if (!var) {
+        // If we have multiple SPIR-V global variables bound to the same (set,
+        // binding) pair and they are used in the same function, those variables
+        // need to have alias decoration.
+        bool alias = setBindingTypes[setBindings[i]].size() > 1;
+
+        // Use the interface op's type when creating the global variable. The
+        // type may not be final, but that is fine: this pass is the
+        // correctness boundary and the type is fixed up during conversion.
+        var = createResourceVariable(
+            interfaceOp.getLoc(), interfaceOp.getType(), setBinding.first,
+            setBinding.second, alias, module, &symbolTable);
+        resourceVars[key] = var;
       }
+
+      interfaceToResourceVars[interfaceOp] = var;
     }
   }
 
-  return aliasedResources;
+  return interfaceToResourceVars;
 }
 
 }  // namespace
@@ -169,17 +206,19 @@
     : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> {
   HALInterfaceBindingSubspanConverter(
       TypeConverter &typeConverter, MLIRContext *context,
-      const llvm::DenseSet<Operation *> &aliasedResources,
-      SymbolTable *symbolTable, PatternBenefit benefit = 1)
+      const InterfaceResourceMap &interfaceToResourceVars,
+      PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
-        aliasedResources(aliasedResources),
-        symbolTable(symbolTable) {}
+        interfaceToResourceVars(interfaceToResourceVars) {}
 
   LogicalResult matchAndRewrite(
       IREE::HAL::InterfaceBindingSubspanOp interfaceOp,
       ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    auto moduleOp = interfaceOp->template getParentOfType<ModuleOp>();
+    if (interfaceOp.use_empty()) {
+      rewriter.eraseOp(interfaceOp);
+      return success();
+    }
 
     Type resultType = interfaceOp.getOperation()->getResult(0).getType();
     Type convertedType = this->getTypeConverter()->convertType(resultType);
@@ -187,22 +226,18 @@
       return interfaceOp.emitError()
              << "failed to convert SPIR-V type: " << resultType;
     }
-    auto setAndBinding = getInterfaceSetAndBinding(interfaceOp.getOperation());
 
-    // We always create a new resource variable for the interface.
-    spirv::GlobalVariableOp varOp = insertResourceVariable(
-        interfaceOp.getLoc(), convertedType, setAndBinding.first,
-        setAndBinding.second,
-        aliasedResources.contains(interfaceOp.getOperation()), moduleOp,
-        symbolTable, rewriter.getListener());
+    auto varOp = interfaceToResourceVars.lookup(interfaceOp);
+    // Fix up the variable's type.
+    varOp.typeAttr(TypeAttr::get(convertedType));
 
     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(interfaceOp, varOp);
+
     return success();
   }
 
  private:
-  const llvm::DenseSet<Operation *> &aliasedResources;
-  SymbolTable *symbolTable;
+  const InterfaceResourceMap &interfaceToResourceVars;
 };
 
 /// Pattern to lower operations that become a no-ops at this level.
@@ -307,18 +342,12 @@
           IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>(
       typeConverter, context);
 
-  // Create a symbol table for the current module op. This must be done before
-  // starting conversion. The conversion framework has a deferred nature.
-  // Actions like replacing/deleting ops are just recorded instead of directly
-  // applied. So for function conversion, we will see both the original function
-  // and the new replacement function during conversion process. That causes an
-  // issue for creating symbol tables, which scans all symbols and errors out
-  // when duplicates are found. We should be fine here tough becuase we only use
-  // this symbol table to insert new SPIR-V global variables.
-  SymbolTable symbolTable(moduleOp);
-  auto aliasedResources = getAliasedResources(moduleOp);
-  patterns.insert<HALInterfaceBindingSubspanConverter>(
-      typeConverter, context, aliasedResources, &symbolTable);
+  // Perform a preliminary step to analyze all hal.interface.binding.subspan ops
+  // and create spv.GlobalVariables.
+  auto interfaceToResourceVars = createResourceVariables(moduleOp);
+  // The variables are then looked up when converting the subspan ops.
+  patterns.insert<HALInterfaceBindingSubspanConverter>(typeConverter, context,
+                                                       interfaceToResourceVars);
 
   /// Fold certain operations as no-ops:
   /// - linalg.reshape becomes a no-op since all memrefs are linearized in
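The variable deduplication above boils down to a lookup-or-create cache keyed by (set, binding, type). A distilled sketch of that pattern, with stand-in types replacing the MLIR `Type` and `spirv::GlobalVariableOp` the pass actually uses:

```cpp
#include <map>
#include <string>
#include <tuple>

// Stand-in for spirv::GlobalVariableOp in this sketch.
struct ResourceVar {
  std::string name;
  bool aliased = false;
};

// One variable per distinct (set, binding, type); several distinct types under
// one (set, binding) pair is what forces the `aliased` decoration.
using VarKey = std::tuple<unsigned, unsigned, std::string>;

ResourceVar& getOrCreateVar(std::map<VarKey, ResourceVar>& cache, unsigned set,
                            unsigned binding, const std::string& type,
                            bool aliased) {
  auto [it, inserted] = cache.try_emplace(VarKey{set, binding, type});
  if (inserted) {
    // (The real pass uniques colliding names via the module's symbol table.)
    it->second.name = "__resource_var_" + std::to_string(set) + "_" +
                      std::to_string(binding) + "_";
    it->second.aliased = aliased;
  }
  return it->second;
}
```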
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index abc817f..a7d78af 100644
--- a/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -24,20 +24,43 @@
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
   // CHECK-LABEL: spv.module
-  // CHECK: spv.GlobalVariable @__resource_var_3_4_ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-  // CHECK: spv.GlobalVariable @__resource_var_1_2__0 bind(1, 2) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-  // CHECK: spv.GlobalVariable @__resource_var_1_2_ bind(1, 2) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[ARG0:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[ARG1_0:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[ARG1_1:.+]] bind(1, 3) {aliased} : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[RET0:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
   // CHECK: spv.func @resource_bindings_in_same_entry_func()
   func @resource_bindings_in_same_entry_func() {
     %c0 = constant 0 : index
+
+    // Same type
+    // CHECK: spv.mlir.addressof @[[ARG0]]
+    // CHECK: spv.mlir.addressof @[[ARG0]]
     %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
     %1 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
-    %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32>
+
+    // Different type
+    // CHECK: spv.mlir.addressof @[[ARG1_0]]
+    // CHECK: spv.mlir.addressof @[[ARG1_1]]
+    %2 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4x4xf32>
+    %3 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4xvector<4xf32>>
+
+    // CHECK: spv.mlir.addressof @[[RET0]]
+    %4 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32>
+
+    %5 = memref.load %0[%c0, %c0] : memref<4x4xf32>
+    %6 = memref.load %1[%c0, %c0] : memref<4x4xf32>
+
+    %7 = memref.load %2[%c0, %c0] : memref<4x4xf32>
+    %8 = memref.load %3[%c0] : memref<4xvector<4xf32>>
+
+    %9 = memref.load %4[%c0, %c0] : memref<4x4xf32>
+
     return
   }
 
   hal.interface @io attributes {push_constants = 5 : index, sym_visibility = "private"} {
     hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=1, binding=3, type="StorageBuffer", access="Read"
     hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
   }
 }
@@ -46,18 +69,22 @@
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
   // CHECK-LABEL: spv.module
-  // CHECK: spv.GlobalVariable @[[FUNC2_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-  // CHECK: spv.GlobalVariable @[[FUNC2_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
-  // CHECK: spv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
   // CHECK: spv.GlobalVariable @[[FUNC1_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[FUNC1_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[FUNC2_ARG:.+]] bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.GlobalVariable @[[FUNC2_RET:.+]] bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
 
   // CHECK: spv.func @resource_bindings_in_entry_func1()
   func @resource_bindings_in_entry_func1() {
-    // CHECK: spv.mlir.addressof @[[FUNC1_ARG:.+]]
-    // CHECK: spv.mlir.addressof @[[FUNC1_RET:.+]]
+    // CHECK: spv.mlir.addressof @[[FUNC1_ARG]]
+    // CHECK: spv.mlir.addressof @[[FUNC1_RET]]
     %c0 = constant 0 : index
     %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
     %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4xvector<4xf32>>
+
+    %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
+    %3 = memref.load %1[%c0] : memref<4xvector<4xf32>>
+
     return
   }
 
@@ -66,8 +93,12 @@
     // CHECK: spv.mlir.addressof @[[FUNC2_ARG]]
     // CHECK: spv.mlir.addressof @[[FUNC2_RET]]
     %c0 = constant 0 : index
-    %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32>
-    %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32>
+    %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4x4xf32> // Same type as previous function
+    %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4x4xf32> // Different type from previous function
+
+    %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
+    %3 = memref.load %1[%c0, %c0] : memref<4x4xf32>
+
     return
   }
 
@@ -85,6 +116,11 @@
     %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<8x5xf32>
     %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<5xf32>
     %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<8x5xf32>
+
+    %3 = memref.load %0[%c0, %c0] : memref<8x5xf32>
+    %4 = memref.load %1[%c0] : memref<5xf32>
+    %5 = memref.load %2[%c0, %c0] : memref<8x5xf32>
+
     return
   }
   hal.interface @io attributes {sym_visibility = "private"} {
@@ -93,14 +129,17 @@
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
   }
 }
+
+// Explicitly check the variable symbols
+
 // CHECK-LABEL: spv.module
-//   CHECK-DAG:   spv.GlobalVariable @[[RET0:.+]] bind(0, 2)
-//   CHECK-DAG:   spv.GlobalVariable @[[ARG1:.+]] bind(0, 1)
-//   CHECK-DAG:   spv.GlobalVariable @[[ARG0:.+]] bind(0, 0)
+//       CHECK:   spv.GlobalVariable @__resource_var_0_0_ bind(0, 0)
+//       CHECK:   spv.GlobalVariable @__resource_var_0_1_ bind(0, 1)
+//       CHECK:   spv.GlobalVariable @__resource_var_0_2_ bind(0, 2)
 //       CHECK:   spv.func
-//   CHECK-DAG:   %{{.+}} = spv.mlir.addressof @[[RET0]]
-//   CHECK-DAG:   %{{.+}} = spv.mlir.addressof @[[ARG0]]
-//   CHECK-DAG:   %{{.+}} = spv.mlir.addressof @[[ARG1]]
+//       CHECK:   %{{.+}} = spv.mlir.addressof @__resource_var_0_0_
+//       CHECK:   %{{.+}} = spv.mlir.addressof @__resource_var_0_1_
+//       CHECK:   %{{.+}} = spv.mlir.addressof @__resource_var_0_2_
 
 // -----
 
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 8bf33ff..ce2f8cb 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -695,6 +695,43 @@
   }
 };
 
+// Replaces a `flow.tensor.load` of a `flow.tensor.splat` with the primitive
+// value that was splatted.
+struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
+  using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceOp =
+        dyn_cast_or_null<TensorSplatOp>(loadOp.source().getDefiningOp());
+
+    if (!sourceOp) return failure();
+
+    rewriter.replaceOp(loadOp, sourceOp.value());
+    return success();
+  }
+};
+
+struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> {
+  using OpRewritePattern<TensorSplatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorSplatOp splatOp,
+                                PatternRewriter &rewriter) const override {
+    if (!splatOp.result().hasOneUse()) return failure();
+
+    auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>(
+        splatOp.result().use_begin()->getOwner());
+    if (!reshapeOp) return failure();
+
+    rewriter.replaceOpWithNewOp<TensorSplatOp>(
+        reshapeOp, reshapeOp.result().getType(), splatOp.value(),
+        reshapeOp.result_dims());
+    rewriter.eraseOp(splatOp);
+
+    return success();
+  }
+};
+
 }  // namespace
 
 void TensorReshapeOp::getCanonicalizationPatterns(
@@ -702,6 +739,11 @@
   results.insert<FlattenTensorReshapeChain>(context);
 }
 
+void TensorLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<FoldSplatLoadIntoPrimitive>(context);
+}
+
 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute> operands) {
   if (auto source = operands[0].dyn_cast_or_null<ElementsAttr>()) {
     // Load directly from the constant source tensor.
@@ -739,6 +781,12 @@
   return {};
 }
 
+void TensorSplatOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  // TODO(benvanik): canonicalize splat+slice to smaller splat.
+  results.insert<FoldSplatReshapeIntoSplat>(context);
+}
+
 OpFoldResult TensorSplatOp::fold(ArrayRef<Attribute> operands) {
   if (operands.size() == 1 && operands.front()) {
     // Splat value is constant and we can fold the operation.
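As a usage sketch, the two new folding patterns are picked up by the standard canonicalizer; applying them directly would look roughly like this with the MLIR APIs of this vintage (assuming `funcOp` and `context` are in scope):

```cpp
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Gather the Flow tensor canonicalization patterns and run them greedily,
// mirroring what -canonicalize does internally.
OwningRewritePatternList patterns(context);
IREE::Flow::TensorLoadOp::getCanonicalizationPatterns(patterns, context);
IREE::Flow::TensorSplatOp::getCanonicalizationPatterns(patterns, context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
```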
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 3b14102..cd1b2f4 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -85,7 +85,7 @@
   }];
 
   let arguments = (ins
-    FLOW_VariableRefAttr:$variable
+    Arg<FLOW_VariableRefAttr, "", [MemRead]>:$variable
   );
   let results = (outs
     AnyType:$result
@@ -135,7 +135,7 @@
 
   let arguments = (ins
     AnyType:$value,
-    FLOW_VariableRefAttr:$variable
+    Arg<FLOW_VariableRefAttr, "", [MemWrite]>:$variable
   );
 
   let assemblyFormat = "$value `,` $variable attr-dict `:` type($value)";
@@ -822,6 +822,7 @@
 
   // TODO(benvanik): canonicalize to slice+load if dims are known.
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def FLOW_TensorStoreOp : FLOW_PureOp<"tensor.store", [
@@ -899,7 +900,7 @@
     bool isTransfer() { return true; }
   }];
 
-  // TODO(benvanik): canonicalize splat+slice to smaller splat.
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 62cab58..724b8e5 100644
--- a/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -411,3 +411,34 @@
   // CHECK: return %[[RESULT]]
   return %1 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @foldSplatLoadIntoPrimitive
+//  CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+func @foldSplatLoadIntoPrimitive(%arg0 : f32, %arg1 : index, %arg2 : index) -> f32 {
+  // CHECK-NEXT: return %[[arg0]] : f32
+  %0 = flow.tensor.splat %arg0 : tensor<4x4xf32>
+  %1 = flow.tensor.load %0[%arg1, %arg2] : tensor<4x4xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @foldSplatReshapeIntoSplat
+func @foldSplatReshapeIntoSplat(%arg0 : f32) -> tensor<16xf32> {
+  // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<16xf32>
+  // CHECK-NEXT: return %0 : tensor<16xf32>
+  %0 = flow.tensor.splat %arg0 : tensor<4x4xf32>
+  %1 = flow.tensor.reshape %0 : tensor<4x4xf32> -> tensor<16xf32>
+  return %1 : tensor<16xf32>
+}
+
+// CHECK-LABEL: @foldSplatReshapeIntoSplatDynamic
+func @foldSplatReshapeIntoSplatDynamic(%arg0 : f32, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+  // CHECK-NEXT: %0 = flow.tensor.splat %arg0 : tensor<?x?xf32>{%arg2, %arg3}
+  // CHECK-NEXT: return %0 : tensor<?x?xf32>
+  %0 = flow.tensor.splat %arg0 : tensor<?x4xf32>{%arg1}
+  %1 = flow.tensor.reshape %0 : tensor<?x4xf32>{%arg1} -> tensor<?x?xf32>{%arg2, %arg3}
+  return %1 : tensor<?x?xf32>
+}
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 167a506..8b26a4b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -53,6 +53,7 @@
         "Passes.cpp",
         "PromoteI1ToI8Pass.cpp",
         "PromoteTensorLoads.cpp",
+        "SimplifyVariableAccesses.cpp",
         "StripAndSplatConstantVariables.cpp",
         "TypeConverter.cpp",
     ],
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 57484a1..e8c00f3 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -50,6 +50,7 @@
     "Passes.cpp"
     "PromoteI1ToI8Pass.cpp"
     "PromoteTensorLoads.cpp"
+    "SimplifyVariableAccesses.cpp"
     "StripAndSplatConstantVariables.cpp"
     "TypeConverter.cpp"
   DEPS
diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
index 69249fc..81483a4 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
@@ -14,6 +14,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -55,17 +56,28 @@
 };
 
 // TODO(nicolasvasilache): Use some interface instead of op names directly.
-static bool hasDestructiveUpdateSubTensorUses(
-    BlockArgument arg, SpecialTerminatorOpCapture &capture) {
+static bool hasDestructiveUpdateUses(BlockArgument arg,
+                                     SpecialTerminatorOpCapture &capture) {
   SmallVector<Operation *> reads;
-  SmallVector<tensor::InsertSliceOp> writes;
+  SmallVector<Operation *> writes;
   for (OpOperand &u : arg.getUses()) {
-    if (auto subTensorInsertOp =
-            dyn_cast<tensor::InsertSliceOp>(u.getOwner())) {
-      writes.push_back(subTensorInsertOp);
-    } else {
-      reads.push_back(u.getOwner());
-    }
+    TypeSwitch<Operation *, void>(u.getOwner())
+        .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>(
+            [&](auto linalgLikeOp) {
+              if (linalgLikeOp.isOutputTensor(&u)) {
+                writes.push_back(linalgLikeOp);
+              } else {
+                reads.push_back(linalgLikeOp);
+              }
+            })
+        .Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp sliceOp) {
+          if (sliceOp.dest() == u.get()) {
+            writes.push_back(sliceOp);
+          } else {
+            reads.push_back(sliceOp);
+          }
+        })
+        .Default([&](Operation *op) { reads.push_back(op); });
   }
   // For now, only allow exactly a single SubTensorInsertOp that must be
   // dominated by all SubTensorOp.
@@ -77,7 +89,7 @@
     if (!domInfo.properlyDominates(read, writes.front())) {
       LLVM_DEBUG(llvm::dbgs() << "non-destructive use-def: " << *read
                               << " does not properly dominate "
-                              << *(writes.front().getOperation()) << "\n");
+                              << *(writes.front()) << "\n");
       return false;
     }
   }
@@ -133,8 +145,7 @@
     // Case 2: multiple uses from an scf::ForOp then this must be used only by
     // SubTensorOp / SubTensorInsertOp with proper dominance.
     if (!regionArg.hasOneUse()) {
-      if (!hasDestructiveUpdateSubTensorUses(regionArg, capture))
-        return nullptr;
+      if (!hasDestructiveUpdateUses(regionArg, capture)) return nullptr;
       return returnValue;
     }
 
@@ -145,8 +156,7 @@
     // Case 3a: Single use which is not an scf::ForOp, it may still be a
     // single SubTensor / SubTensorInsertOp.
     if (!innerForOp) {
-      if (!hasDestructiveUpdateSubTensorUses(regionArg, capture))
-        return nullptr;
+      if (!hasDestructiveUpdateUses(regionArg, capture)) return nullptr;
       return returnValue;
     }
 
@@ -216,31 +226,67 @@
   return success();
 }
 
-static LogicalResult rewriteSubTensorInsertInPlace(OpBuilder &b,
-                                                   tensor::InsertSliceOp op,
-                                                   Value target) {
-  LLVM_DEBUG(llvm::dbgs() << "RewriteSubTensorInsertInPlace: "
-                          << *(op.getOperation()) << "\n");
-  if (!op.getResult().hasOneUse()) {
-    return op.emitError("not a single use operation");
+template <typename OpTy>
+static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b,
+                                                     OpTy linalgLikeOp,
+                                                     Value target) {
+  LLVM_DEBUG(llvm::dbgs() << "RewriteDestructiveUpdateInPlace: "
+                          << *linalgLikeOp.getOperation() << "\n");
+  if (!linalgLikeOp->hasOneUse()) {
+    return linalgLikeOp.emitError("not a single use operation");
   }
 
-  Operation *user = *(op.getResult().getUsers().begin());
-  if (isa<scf::YieldOp>(user)) {
-    auto dest = op.dest();
-    if (!dest.isa<BlockArgument>()) {
-      return op.emitError("dest is not a argument to the loop");
+  OpOperand &use = *(linalgLikeOp->use_begin());
+  if (isa<scf::YieldOp>(use.getOwner())) {
+    OpResult usedResult = use.get().cast<OpResult>();
+    Value dest =
+        linalgLikeOp.getOutputOperand(usedResult.getResultNumber())->get();
+    if (!dest || !dest.isa<BlockArgument>()) {
+      return linalgLikeOp.emitError("dest is not an argument to the loop");
     }
     OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(op);
+    b.setInsertionPointAfter(linalgLikeOp);
 
     // Kills the SSA use-def chain.
-    op.replaceAllUsesWith(dest);
+    usedResult.replaceAllUsesWith(dest);
+
+    b.create<IREE::Flow::DispatchTensorStoreOp>(linalgLikeOp.getLoc(),
+                                                usedResult, target);
+
+    return success();
+  }
+  return failure();
+}
+
+/// Rewrites destructive in-place updates with the update operation being
+/// tensor.insert_slice.
+template <>
+LogicalResult rewriteDestructiveUpdateInPlace<tensor::InsertSliceOp>(
+    OpBuilder &b, tensor::InsertSliceOp insertSliceOp, Value target) {
+  LLVM_DEBUG(llvm::dbgs() << "RewriteDestructiveUpdateInPlace: "
+                          << *insertSliceOp.getOperation() << "\n");
+  if (!insertSliceOp->hasOneUse()) {
+    return insertSliceOp.emitError("not a single use operation");
+  }
+
+  OpOperand &use = *(insertSliceOp->use_begin());
+  if (isa<scf::YieldOp>(use.getOwner())) {
+    OpResult usedResult = use.get().cast<OpResult>();
+    Value dest = insertSliceOp.dest();
+    if (!dest || !dest.isa<BlockArgument>()) {
+      return insertSliceOp.emitError("dest is not an argument to the loop");
+    }
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertSliceOp);
+
+    // Kills the SSA use-def chain.
+    usedResult.replaceAllUsesWith(dest);
 
     b.create<IREE::Flow::DispatchTensorStoreOp>(
-        op.getLoc(), op.source(), target, op.offsets(), op.sizes(),
-        op.strides(), op.static_offsets(), op.static_sizes(),
-        op.static_strides());
+        insertSliceOp->getLoc(), insertSliceOp.source(), target,
+        insertSliceOp.offsets(), insertSliceOp.sizes(), insertSliceOp.strides(),
+        insertSliceOp.static_offsets(), insertSliceOp.static_sizes(),
+        insertSliceOp.static_strides());
 
     return success();
   }
@@ -366,11 +412,17 @@
                           << "\n");
 
   // Try to rewrite inplace.
-  if (failed(rewriteSubTensorInsertInPlace(
-          b, cast<tensor::InsertSliceOp>(capture.rootDestructiveUpdate),
-          target))) {
-    return failure();
-  }
+  auto status =
+      TypeSwitch<Operation *, LogicalResult>(capture.rootDestructiveUpdate)
+          .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp,
+                tensor::InsertSliceOp>([&](auto op) {
+            if (failed(rewriteDestructiveUpdateInPlace(b, op, target))) {
+              return failure();
+            }
+            return success();
+          })
+          .Default([&](Operation *) { return failure(); });
+  if (failed(status)) return failure();
 
   if (scf::ForOp loopOp = dyn_cast<scf::ForOp>(outermostProducingOp))
     loopOp.walk(
@@ -415,9 +467,9 @@
                     capture.initValue = op.value();
                     Value sourceValue =
                         isADestructiveUpdatePattern(capture.initValue, capture);
-                    if (!sourceValue || !isa_and_nonnull<tensor::InsertSliceOp>(
-                                            capture.rootDestructiveUpdate))
+                    if (!sourceValue) {
                       return WalkResult::advance();
+                    }
                     if (failed(rewriteDestructiveUpdateInPlace(b, capture,
                                                                op.target()))) {
                       return WalkResult::interrupt();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 6185eac..5fca973 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -65,9 +65,14 @@
 namespace Flow {
 
 void buildFlowTransformPassPipeline(OpPassManager &passManager) {
-  // Perform initial cleanup.
-  // NOTE: There is no principled reason to be doing this here. But also ensures
-  // some consistency at the tool boundary.
+  // Simplify flow.variable accesses early on; this can help with dispatch
+  // region formation as redundant store-loads are removed.
+  passManager.addNestedPass<FuncOp>(
+      IREE::Flow::createSimplifyVariableAccessesPass());
+
+  // Perform cleanup after variable simplification, as more canonicalization
+  // patterns may now apply.
+  passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
   passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
 
   // Replaces variables with !shapex.ranked_shape types with individual
@@ -169,9 +174,11 @@
   // Reorder blocks to increase the grouping of streamable ops.
   passManager.addNestedPass<FuncOp>(
       IREE::Flow::createHoistUnstreamableOpsPass());
+
   // The hoisting pass does some reordering. Canonicalize to avoid unnecessary
   // arbitrary ordering.
   passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
+  passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
 
   // Clone constants that escape basic blocks until we have better analysis.
   passManager.addNestedPass<FuncOp>(
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 5f4403a..445190e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -131,6 +131,10 @@
 // Reorders blocks to hoist ops that cannot be put into streams.
 std::unique_ptr<OperationPass<FuncOp>> createHoistUnstreamableOpsPass();
 
+// Hoists loads and sinks stores to variables to decrease data dependency
+// regions.
+std::unique_ptr<OperationPass<FuncOp>> createSimplifyVariableAccessesPass();
+
 // TODO(benvanik): cross-function stream flows.
 
 // Inserts clones of constant values where they may be required.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index c8fa51c..4dfc814 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -57,7 +57,7 @@
 
 def FormStreams :
     Pass<"iree-flow-form-streams", "FuncOp"> {
-  let summary = "Identifies dispatches that can be grouped into streams within functions";
+  let summary = "Identifies dispatches that can be grouped into streams within functions.";
   let constructor = "mlir::iree_compiler::IREE::Flow::createFormStreamsPass()";
 }
 
@@ -75,7 +75,7 @@
 
 def InjectDispatchTracing :
     Pass<"iree-flow-inject-dispatch-tracing", "FuncOp"> {
-  let summary = "Injects dispatch region tracing";
+  let summary = "Injects dispatch region tracing.";
   let constructor = "mlir::iree_compiler::IREE::Flow::createInjectDispatchTracingPass()";
 }
 
@@ -128,6 +128,12 @@
   let constructor = "mlir::iree_compiler::IREE::Flow::createPromoteTensorLoadsPass()";
 }
 
+def SimplifyVariableAccesses :
+    Pass<"iree-flow-simplify-variable-accesses", "FuncOp"> {
+  let summary = "Hoists loads and sinks stores to variables to decrease data dependency regions.";
+  let constructor = "mlir::iree_compiler::IREE::Flow::createSimplifyVariableAccessesPass()";
+}
+
 def StripAndSplatConstantVariables :
     Pass<"iree-flow-strip-and-splat-constant-variables", "ModuleOp"> {
   let summary = "Strips constant flow.variables and replaces them with splats.";
diff --git a/iree/compiler/Dialect/Flow/Transforms/SimplifyVariableAccesses.cpp b/iree/compiler/Dialect/Flow/Transforms/SimplifyVariableAccesses.cpp
new file mode 100644
index 0000000..f190956
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/SimplifyVariableAccesses.cpp
@@ -0,0 +1,270 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <algorithm>
+#include <iterator>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+#define DEBUG_TYPE "iree-flow-simplify-variable-accesses"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Builds symbol ref set for all immutable variables in |moduleOp|.
+static DenseSet<StringRef> gatherImmutableVariables(ModuleOp moduleOp) {
+  DenseSet<StringRef> set;
+  moduleOp.walk([&](IREE::Flow::VariableOp variableOp) {
+    if (!variableOp.is_mutable()) {
+      set.insert(variableOp.sym_name());
+    }
+  });
+  return set;
+}
+
+// Hoists all loads of immutable variables in |funcOp| to the entry block.
+// |immutableVariables| is used for lookups of which variables are immutable.
+static void hoistImmutableLoads(FuncOp funcOp,
+                                DenseSet<StringRef> &immutableVariables) {
+  // Since CSE of loads isn't a thing yet we perform a basic deduping here by
+  // folding all subsequent loads into the first one found. This works only for
+  // immutable variables as otherwise we'd have to ensure stores and
+  // side-effects were properly observed.
+  DenseMap<Attribute, Operation *> loadOps;
+  auto *entryBlock = &funcOp.getBlocks().front();
+  Operation *lastEntryOp = nullptr;
+  for (auto &block : funcOp) {
+    for (auto op : llvm::make_early_inc_range(
+             block.getOps<IREE::Flow::VariableLoadOp>())) {
+      if (immutableVariables.contains(op.variable())) {
+        auto variableRef = op.variableAttr().cast<Attribute>();
+        auto it = loadOps.find(variableRef);
+        if (it == loadOps.end()) {
+          // Move to the entry block, even if the op is already there, so that
+          // all hoisted loads end up grouped together.
+          LLVM_DEBUG(llvm::dbgs()
+                     << "moving immutable variable " << op.variable()
+                     << " load to the entry block\n");
+          if (lastEntryOp) {
+            op->moveAfter(lastEntryOp);
+          } else {
+            op->moveBefore(entryBlock, entryBlock->begin());
+          }
+          loadOps[variableRef] = op;
+          lastEntryOp = op;
+        } else {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "CSE'ing immutable variable " << op.variable() << "\n");
+          op->replaceAllUsesWith(it->getSecond());
+          op->erase();
+        }
+      }
+    }
+  }
+}
+
+static bool doesOpBlockMotion(Operation *op) {
+  return isa<mlir::CallOpInterface>(op) ||
+         op->hasTrait<OpTrait::IREE::YieldPoint>();
+}
+
+static void moveOpUpInBlock(Block &block, Operation *op) {
+  while (op->getPrevNode()) {
+    if (doesOpBlockMotion(op->getPrevNode())) break;
+    op->moveBefore(op->getPrevNode());
+  }
+}
+
+static void moveOpDownInBlock(Block &block, Operation *op) {
+  while (op->getNextNode() != block.getTerminator()) {
+    if (doesOpBlockMotion(op->getNextNode())) break;
+    op->moveAfter(op->getNextNode());
+  }
+}
+
+// Optimizes the load/store ops for each given bucket.
+// Returns true if any op was removed.
+static bool optimizeBuckets(
+    Block &block, std::map<StringRef, SmallVector<Operation *>> &buckets) {
+  bool didRemoveAny = false;
+  for (auto &bucket : buckets) {
+    // First perform basic load-store forwarding and such.
+    auto &ops = bucket.second;
+    for (int i = ops.size() - 1; i >= 1; --i) {
+      auto previous = ops[i - 1];
+      auto current = ops[i];
+      if (isa<IREE::Flow::VariableStoreOp>(previous) &&
+          isa<IREE::Flow::VariableLoadOp>(current)) {
+        // RAW - forward the stored variable to the following use.
+        auto storedValue = previous->getOperand(0);
+        LLVM_DEBUG({
+          llvm::dbgs() << "RAW: replacing load with previous store value:\n";
+          current->dump();
+          llvm::dbgs() << "->\n";
+          storedValue.dump();
+        });
+        current->replaceAllUsesWith(ValueRange{storedValue});
+        ops.erase(ops.begin() + i);
+        current->erase();
+        didRemoveAny = true;
+      } else if (isa<IREE::Flow::VariableLoadOp>(previous) &&
+                 isa<IREE::Flow::VariableLoadOp>(current)) {
+        // RAR - forward the loaded variable to the following use.
+        LLVM_DEBUG({
+          llvm::dbgs() << "RAR: replacing subsequent load with op:\n";
+          current->dump();
+          llvm::dbgs() << "->\n";
+          previous->dump();
+        });
+        current->replaceAllUsesWith(previous);
+        ops.erase(ops.begin() + i);
+        current->erase();
+        didRemoveAny = true;
+      } else if (isa<IREE::Flow::VariableStoreOp>(previous) &&
+                 isa<IREE::Flow::VariableStoreOp>(current)) {
+        // WAW - remove the first store.
+        LLVM_DEBUG({
+          llvm::dbgs() << "WAW: erasing source op:\n";
+          previous->dump();
+          llvm::dbgs() << "\nand keeping subsequent op:\n";
+          current->dump();
+        });
+        ops.erase(ops.begin() + i - 1);
+        previous->erase();
+        didRemoveAny = true;
+      }
+    }
+    assert(!ops.empty());
+
+    if (auto loadOp = dyn_cast<IREE::Flow::VariableLoadOp>(ops.front())) {
+      // If the head op is a load we can move that to the top of the block.
+      LLVM_DEBUG(llvm::dbgs() << "moving mutable variable " << loadOp.variable()
+                              << " load upward\n");
+      moveOpUpInBlock(block, ops.front());
+    }
+    if (auto storeOp = dyn_cast<IREE::Flow::VariableStoreOp>(ops.back())) {
+      // If the tail op is a store we can move that to the bottom of the block.
+      LLVM_DEBUG(llvm::dbgs() << "moving mutable variable "
+                              << storeOp.variable() << " store downward\n");
+      moveOpDownInBlock(block, ops.back());
+    }
+  }
+  return didRemoveAny;
+}
+
+// Hoists loads and sinks stores to the boundary of |block| when safe.
+// |immutableVariables| is used for lookups of which variables are immutable.
+//
+// Basic algorithm (repeat until no op removals):
+//   for each op:
+//     if immutable: skip
+//     add to load/store buckets (sorted vector)
+//   for each bucket (symbol):
+//     walk ops in reverse:
+//       if (prev == store && this == load)  // RAW
+//         replace load with store source
+//       if (prev == load && this == load)  // RAR
+//         replace with first load
+//       if (prev == store && this == store) // WAW
+//         remove first store
+//     if (head == load) move load to front
+//     if (tail == store) move store to back
+//
+// Returns true if there were any removals and the block should be reprocessed.
+static bool rearrangeBlockVariableAccesses(
+    Block &block, DenseSet<StringRef> &immutableVariables) {
+  // Produce [symbol_name, [op, op, op, ...]] buckets.
+  // NOTE: we use a map here so that we are deterministically ordered. This may
+  // not be needed but the variable count is low and it's nice to not care about
+  // op order issues.
+  //
+  // Because there may be ops that we can't optimize across (calls/etc) we
+  // handle flushing buckets on demand.
+  std::map<StringRef, SmallVector<Operation *>> buckets;
+  bool didRemoveAny = false;
+  for (auto &op : block) {
+    if (auto loadOp = dyn_cast<IREE::Flow::VariableLoadOp>(op)) {
+      if (immutableVariables.contains(loadOp.variable())) continue;
+      buckets[loadOp.variable()].push_back(&op);
+    } else if (auto storeOp = dyn_cast<IREE::Flow::VariableStoreOp>(op)) {
+      buckets[storeOp.variable()].push_back(&op);
+    } else if (doesOpBlockMotion(&op)) {
+      // Split point - all accesses after this point must not assume anything
+      // about accesses before it.
+      didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny;
+      buckets.clear();
+    }
+  }
+  didRemoveAny = optimizeBuckets(block, buckets) || didRemoveAny;
+  return didRemoveAny;
+}
+
+namespace {
+
+class SimplifyVariableAccessesPass
+    : public SimplifyVariableAccessesBase<SimplifyVariableAccessesPass> {
+ public:
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    if (funcOp.empty()) return;
+
+    auto moduleOp = funcOp->getParentOfType<mlir::ModuleOp>();
+    assert(moduleOp && "func not in a module");
+
+    // Build a set of all immutable variables for fast lookup.
+    auto immutableVariables = gatherImmutableVariables(moduleOp);
+
+    // Hoist immutable variables first. These have no hazards and don't care
+    // about control flow - like `constant` - so getting them handled first
+    // avoids the need for us to do the full analysis.
+    hoistImmutableLoads(funcOp, immutableVariables);
+
+    // We can't optimize the function if there are indirect loads/stores.
+    // (The immutable load hoisting above remains safe in that case.)
+    for (auto &block : funcOp) {
+      for (auto &op : block) {
+        if (isa<IREE::Flow::VariableLoadIndirectOp>(op) ||
+            isa<IREE::Flow::VariableStoreIndirectOp>(op)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "bailing on variable access simplification: indirect "
+                        "accesses present in function\n");
+          return;
+        }
+      }
+    }
+
+    // For each block in the function hoist loads and sink stores.
+    // This does no cross-block movement, though it really should. Maybe when a
+    // real compiler engineer sees this they'll be inspired to do this properly.
+    for (auto &block : funcOp) {
+      LLVM_DEBUG(llvm::dbgs() << "==== REARRANGING BLOCK ACCESSES ====\n");
+      while (rearrangeBlockVariableAccesses(block, immutableVariables)) {
+        // NOTE: block is processed until no more ops are removed. Will always
+        // end in a fixed amount of time as ops are only removed from the block.
+      }
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createSimplifyVariableAccessesPass() {
+  return std::make_unique<SimplifyVariableAccessesPass>();
+}
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
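The per-bucket forwarding in `optimizeBuckets` can be distilled to plain data: walk one variable's access sequence backwards and apply the RAW/RAR/WAW rules until nothing changes. A standalone sketch with stand-in types (not the pass itself; the pass additionally RAUWs SSA values before erasing ops):

```cpp
#include <vector>

enum class Kind { kLoad, kStore };
struct Access { Kind kind; };

// Simplifies one variable's access sequence; returns true if anything was
// removed (callers loop until a fixed point, as the pass does per block).
bool simplifyAccesses(std::vector<Access>& ops) {
  bool changed = false;
  for (int i = static_cast<int>(ops.size()) - 1; i >= 1; --i) {
    Kind prev = ops[i - 1].kind;
    Kind cur = ops[i].kind;
    if (prev == Kind::kStore && cur == Kind::kLoad) {
      // RAW: the load observes the just-stored value; drop the load.
      ops.erase(ops.begin() + i);
      changed = true;
    } else if (prev == Kind::kLoad && cur == Kind::kLoad) {
      // RAR: the second load is redundant with the first.
      ops.erase(ops.begin() + i);
      changed = true;
    } else if (prev == Kind::kStore && cur == Kind::kStore) {
      // WAW: the first store is dead.
      ops.erase(ops.begin() + i - 1);
      changed = true;
    }
  }
  return changed;
}
```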
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index b8e2e4f..b16ee25 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -38,6 +38,7 @@
             "pad_tensor_to_tensor.mlir",
             "promote_i1_to_i8.mlir",
             "promote_tensor_loads.mlir",
+            "simplify_variable_accesses.mlir",
             "strip_and_splat_constant_variables.mlir",
             "transformation.mlir",
         ],
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 1d3a4d1..01a4294 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -35,6 +35,7 @@
     "pad_tensor_to_tensor.mlir"
     "promote_i1_to_i8.mlir"
     "promote_tensor_loads.mlir"
+    "simplify_variable_accesses.mlir"
     "strip_and_splat_constant_variables.mlir"
     "transformation.mlir"
   DATA
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 5e30bed..3a7999d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -910,7 +910,7 @@
 // CHECK-SAME:             {__internal_linalg_transform__ = "workgroup"}
 // CHECK-SAME:             ins(%[[UPDATE_TILE]], %[[INDICES_TILE]] : tensor<?x?xf32>, tensor<?x1xi32>)
 // CHECK-SAME:             outs(%[[ORIGINAL]] : tensor<?x?xf32>)
-//      CHECK:         flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG5]], offsets = [0, 0]
+//      CHECK:         flow.dispatch.tensor.store %[[RESULT_TILE]], %[[ARG5]], offsets = [], sizes = [], strides = []
 //      CHECK:   return %[[RESULT]] : tensor<?x?xf32>
 
 // -----
@@ -1007,3 +1007,33 @@
 //      CHECK:     flow.return
 //      CHECK:   }
 //      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @scatter_static(%arg0 : tensor<4xi32>, %arg1 : tensor<4x1xi32>, %arg2 : tensor<8xi32>)
+    -> tensor<8xi32>{
+  %cst = constant dense<[0, 9, 0, 10, 11, 0, 0, 12]> : tensor<8xi32>
+  %cst_0 = constant dense<[9, 10, 11, 12]> : tensor<4xi32>
+  %cst_1 = constant dense<[[1], [3], [4], [7]]> : tensor<4x1xi32>
+  %cst_2 = constant dense<0> : tensor<8xi32>
+  %0 = linalg_ext.scatter
+      ins(%arg0, %arg1 : tensor<4xi32>, tensor<4x1xi32>)
+      outs(%arg2 : tensor<8xi32>)  {
+    ^bb0(%arg3: i32, %arg4: i32):  // no predecessors
+      linalg_ext.yield %arg3 : i32
+    } -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+//      CHECK: func @scatter_static
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4xi32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x1xi32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<8xi32>
+//      CHECK:   %[[RESULT:.+]] = flow.dispatch.workgroups
+// CHECK-NEXT:     %[[ARG3:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:4xi32>
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:4x1xi32>
+// CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readwrite:8xi32>
+//      CHECK:     scf.for %[[IV:.+]] = %{{.+}} to %{{.+}} step %{{.+}} {
+//      CHECK:       %[[SCATTER_TILE:.+]] = linalg_ext.scatter
+//      CHECK:       flow.dispatch.tensor.store %[[SCATTER_TILE]], %[[ARG5]], offsets = [], sizes = [], strides = []
+// CHECK-NEXT:     }
+//      CHECK:   return %[[RESULT]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/simplify_variable_accesses.mlir b/iree/compiler/Dialect/Flow/Transforms/test/simplify_variable_accesses.mlir
new file mode 100644
index 0000000..927d193
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/simplify_variable_accesses.mlir
@@ -0,0 +1,136 @@
+// RUN: iree-opt -split-input-file -iree-flow-simplify-variable-accesses %s | IreeFileCheck %s
+
+flow.variable @varA dense<1> : tensor<2xi32> attributes {sym_visibility = "private"}
+flow.variable @varB dense<3> : tensor<2x4xi32> attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @constants()
+func @constants() {
+  // CHECK-NEXT: %[[VAR_A:.+]] = flow.variable.load @varA : tensor<2xi32>
+  // CHECK-NEXT: %[[VAR_B:.+]] = flow.variable.load @varB : tensor<2x4xi32>
+  // CHECK-NEXT: constant 10
+  %w = constant 10 : index
+  %varA = flow.variable.load @varA : tensor<2xi32>
+  // CHECK-NEXT: %[[T:.+]] = flow.dispatch @ex::@dispatch0{{.+}}(%[[VAR_A]])
+  %d0 = flow.dispatch @ex::@dispatch0[%w](%varA) : (tensor<2xi32>) -> tensor<2xi32>
+  %varB = flow.variable.load @varB : tensor<2x4xi32>
+  // CHECK-NEXT: flow.dispatch @ex::@dispatch1{{.+}}(%[[T]], %[[VAR_B]])
+  %d1 = flow.dispatch @ex::@dispatch1[%w](%d0, %varB) : (tensor<2xi32>, tensor<2x4xi32>) -> tensor<2xi32>
+  return
+}
+
+// -----
+
+flow.variable @varA 1 : i32 attributes {sym_visibility = "private"}
+flow.variable @varB 2 : i32 attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @constants_in_cfg
+func @constants_in_cfg(%start: i32, %bound: i32) -> i32 {
+  // CHECK-NEXT: %[[VAR_A:.+]] = flow.variable.load @varA : i32
+  // CHECK-NEXT: %[[VAR_B:.+]] = flow.variable.load @varB : i32
+  // CHECK-NEXT: br ^bb1
+  br ^bb1(%start : i32)
+// CHECK: ^bb1(%[[BB1_ARG:.+]]: i32):
+^bb1(%2: i32):
+  %cmp = cmpi slt, %2, %bound : i32
+  cond_br %cmp, ^bb2(%2 : i32), ^bb3(%2 : i32)
+// CHECK: ^bb2(%[[BB2_ARG:.+]]: i32):
+^bb2(%5: i32):
+  %6 = flow.variable.load @varA : i32
+  // CHECK-NEXT: = addi %[[BB2_ARG]], %[[VAR_A]] : i32
+  %7 = addi %5, %6 : i32
+  br ^bb1(%7 : i32)
+// CHECK: ^bb3(%[[BB3_ARG:.+]]: i32):
+^bb3(%8: i32):
+  %9 = flow.variable.load @varA : i32
+  // CHECK-NEXT: %[[T0:.+]] = muli %[[BB3_ARG]], %[[VAR_A]] : i32
+  %10 = muli %8, %9 : i32
+  %11 = flow.variable.load @varB : i32
+  // CHECK-NEXT: %[[T1:.+]] = subi %[[T0]], %[[VAR_B]]
+  %12 = subi %10, %11 : i32
+  // CHECK-NEXT: return %[[T1]]
+  return %12 : i32
+}
+
+// -----
+
+flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"}
+flow.variable @varB dense<3> : tensor<2x4xi32> attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @mixed_mutability
+func @mixed_mutability() {
+  // CHECK-DAG: %[[VAR_A:.+]] = flow.variable.load @varA : tensor<2xi32>
+  // CHECK-DAG: %[[VAR_B:.+]] = flow.variable.load @varB : tensor<2x4xi32>
+  // CHECK-NEXT: constant 10
+  %w = constant 10 : index
+  %varA = flow.variable.load @varA : tensor<2xi32>
+  // CHECK-NEXT: %[[T0:.+]] = flow.dispatch @ex::@dispatch0{{.+}}(%[[VAR_A]])
+  %d0 = flow.dispatch @ex::@dispatch0[%w](%varA) : (tensor<2xi32>) -> tensor<2xi32>
+  %varB = flow.variable.load @varB : tensor<2x4xi32>
+  // CHECK-NEXT: %[[T1:.+]] = flow.dispatch @ex::@dispatch1{{.+}}(%[[T0]], %[[VAR_B]])
+  %d1 = flow.dispatch @ex::@dispatch1[%w](%d0, %varB) : (tensor<2xi32>, tensor<2x4xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: flow.variable.store %[[T1]], @varA : tensor<2xi32>
+  flow.variable.store %d1, @varA : tensor<2xi32>
+  return
+}
+
+// -----
+
+flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @raw
+func @raw() {
+  // CHECK: %[[T:.+]] = flow.variable.load @varA {id = 0
+  %varA_0 = flow.variable.load @varA {id = 0} : tensor<2xi32>
+  flow.variable.store %varA_0, @varA {id = 0} : tensor<2xi32>
+  %varA_1 = flow.variable.load @varA {id = 1} : tensor<2xi32>
+  // CHECK-NEXT: flow.variable.store %[[T]], @varA {id = 1
+  flow.variable.store %varA_1, @varA {id = 1} : tensor<2xi32>
+  return
+}
+
+// -----
+
+flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @rar
+func @rar() -> (tensor<2xi32>, tensor<2xi32>) {
+  // CHECK: %[[T:.+]] = flow.variable.load @varA {id = 0
+  %varA_0 = flow.variable.load @varA {id = 0} : tensor<2xi32>
+  %varA_1 = flow.variable.load @varA {id = 1} : tensor<2xi32>
+  // CHECK-NEXT: return %[[T]], %[[T]]
+  return %varA_0, %varA_1 : tensor<2xi32>, tensor<2xi32>
+}
+
+// -----
+
+flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @waw
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2xi32>, %[[ARG1:.+]]: tensor<2xi32>)
+func @waw(%varA_0: tensor<2xi32>, %varA_1: tensor<2xi32>) {
+  flow.variable.store %varA_0, @varA : tensor<2xi32>
+  // CHECK-NEXT: flow.variable.store %[[ARG1]], @varA
+  flow.variable.store %varA_1, @varA : tensor<2xi32>
+  return
+}
+
+// -----
+
+flow.variable @varA mutable dense<1> : tensor<2xi32> attributes {sym_visibility = "private"}
+
+// CHECK-LABEL: @side_effects(
+func @side_effects() {
+  // CHECK-NEXT: %[[T0:.+]] = flow.variable.load @varA
+  %varA_0 = flow.variable.load @varA : tensor<2xi32>
+  // CHECK-NEXT: flow.variable.store %[[T0]], @varA
+  flow.variable.store %varA_0, @varA : tensor<2xi32>
+  // CHECK-NEXT: call @other_fn()
+  call @other_fn() : () -> ()
+  // CHECK-NEXT: %[[T1:.+]] = flow.variable.load @varA
+  %varA_1 = flow.variable.load @varA : tensor<2xi32>
+  // CHECK-NEXT: flow.variable.store %[[T1]], @varA
+  flow.variable.store %varA_1, @varA : tensor<2xi32>
+  return
+}
+
+func private @other_fn()
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 01c1c13..3359b98 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -41,12 +41,14 @@
         "//iree/compiler/Codegen/LLVMCPU",
         "//iree/compiler/Codegen/Utils",
         "//iree/compiler/Dialect/HAL/Target",
+        "//iree/compiler/Dialect/HAL/Target/LLVM/librt",
         "//iree/compiler/Utils",
         "//iree/schemas:dylib_executable_def_c_fbs",
         "@llvm-project//llvm:AArch64AsmParser",
         "@llvm-project//llvm:AArch64CodeGen",
         "@llvm-project//llvm:ARMAsmParser",
         "@llvm-project//llvm:ARMCodeGen",
+        "@llvm-project//llvm:BitReader",
         "@llvm-project//llvm:BitWriter",
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:RISCVAsmParser",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index f12301b..29c9514 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -32,6 +32,7 @@
     LLVMAArch64CodeGen
     LLVMARMAsmParser
     LLVMARMCodeGen
+    LLVMBitReader
     LLVMBitWriter
     LLVMCore
     LLVMRISCVAsmParser
@@ -50,6 +51,7 @@
     iree::compiler::Codegen::PassHeaders
     iree::compiler::Codegen::Utils
     iree::compiler::Dialect::HAL::Target
+    iree::compiler::Dialect::HAL::Target::LLVM::librt
     iree::compiler::Utils
     iree::schemas::dylib_executable_def_c_fbs
   PUBLIC
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index f80e47c..fdefab2 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -13,9 +13,11 @@
 #include "iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h"
 #include "iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h"
 #include "iree/compiler/Dialect/HAL/Target/LLVM/StaticLibraryGenerator.h"
+#include "iree/compiler/Dialect/HAL/Target/LLVM/librt/librt.h"
 #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
 #include "iree/compiler/Utils/FlatbufferUtils.h"
 #include "iree/schemas/dylib_executable_def_builder.h"
+#include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/Bitcode/BitcodeWriter.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
@@ -30,11 +32,10 @@
 namespace IREE {
 namespace HAL {
 
-namespace {
+static constexpr char kQueryFunctionName[] =
+    "iree_hal_executable_library_query";
 
-constexpr char kQueryFunctionName[] = "iree_hal_executable_library_query";
-
-llvm::Optional<FileLineColLoc> findFirstFileLoc(Location baseLoc) {
+static llvm::Optional<FileLineColLoc> findFirstFileLoc(Location baseLoc) {
   if (auto loc = baseLoc.dyn_cast<FusedLoc>()) {
     for (auto &childLoc : loc.getLocations()) {
       auto childResult = findFirstFileLoc(childLoc);
@@ -46,7 +47,7 @@
   return llvm::None;
 }
 
-std::string guessModuleName(mlir::ModuleOp moduleOp) {
+static std::string guessModuleName(mlir::ModuleOp moduleOp) {
   std::string moduleName =
       moduleOp.getName().hasValue() ? moduleOp.getName().getValue().str() : "";
   if (!moduleName.empty()) return moduleName;
@@ -58,8 +59,6 @@
   }
 }
 
-}  // namespace
-
 class LLVMAOTTargetBackend final : public TargetBackend {
  public:
   explicit LLVMAOTTargetBackend(LLVMTargetOptions options)
@@ -284,8 +283,10 @@
              << options_.targetTriple << "'";
     }
 
-    // Emit object files.
-    SmallVector<Artifact, 4> objectFiles;
+    SmallVector<Artifact> objectFiles;
+
+    // Emit the base object file containing the bulk of our code.
+    // This must come first such that we have the proper library linking order.
     {
       // NOTE: today we just use a single object file, however if we wanted to
       // scale code generation and linking we'd want to generate one per
@@ -306,6 +307,16 @@
       objectFiles.push_back(std::move(objectFile));
     }
 
+    // Optionally append additional object files that provide functionality
+    // which would otherwise have to be resolved at runtime (like libc/libm
+    // calls). For now we only do this for embedded uses.
+    if (options_.linkEmbedded) {
+      if (failed(buildLibraryObjects(variantOp.getLoc(), targetMachine.get(),
+                                     objectFiles, context))) {
+        return variantOp.emitError() << "failed generating library objects";
+      }
+    }
+
     // If we are keeping artifacts then let's also add the bitcode for easier
     // debugging (vs just the binary object file).
     if (options_.keepLinkerArtifacts) {
@@ -329,7 +340,6 @@
       // Copy the static object file to the specified output along with
       // generated header file.
       const std::string &libraryPath = options_.staticLibraryOutput;
-      const auto library_name = objectFiles[0].path;
       if (!outputStaticLibrary(libraryName, queryFunctionName, libraryPath,
                                objectFiles[0].path)) {
         return variantOp.emitError() << "static library generation failed";
@@ -506,6 +516,76 @@
     return IREE::HAL::ExecutableTargetAttr::get(context, "llvm", format);
   }
 
+  static void overridePlatformGlobal(llvm::Module &module, StringRef globalName,
+                                     uint32_t newValue) {
+    // NOTE: the global will not be defined if it is not used in the module.
+    auto *globalValue = module.getNamedGlobal(globalName);
+    if (!globalValue) return;
+    globalValue->setLinkage(llvm::GlobalValue::PrivateLinkage);
+    globalValue->setDSOLocal(true);
+    globalValue->setConstant(true);
+    globalValue->setInitializer(llvm::ConstantInt::get(
+        globalValue->getValueType(), APInt(32, newValue)));
+  }
+
+  // Builds an object file for the librt embedded runtime library.
+  // This is done per link operation so that we can match the precise target
+  // configuration. Since we (mostly) link once per user-level compilation
+  // this is fine today. If in the future we invoke the compiler for thousands
+  // of modules we'd want to (carefully) cache this.
+  LogicalResult buildLibraryObjects(Location loc,
+                                    llvm::TargetMachine *targetMachine,
+                                    SmallVector<Artifact> &objectFiles,
+                                    llvm::LLVMContext &context) {
+    assert(!objectFiles.empty() && "libraries must come after the base object");
+
+    // Load the generic bitcode file contents.
+    llvm::MemoryBufferRef bitcodeBufferRef(
+        llvm::StringRef(iree_compiler_librt_create()->data,
+                        iree_compiler_librt_create()->size),
+        "librt.bc");
+    auto bitcodeModuleValue = llvm::parseBitcodeFile(bitcodeBufferRef, context);
+    if (!bitcodeModuleValue) {
+      return mlir::emitError(loc)
+             << "failed to parse librt bitcode: "
+             << llvm::toString(bitcodeModuleValue.takeError());
+    }
+    auto bitcodeModule = std::move(bitcodeModuleValue.get());
+    bitcodeModule->setDataLayout(targetMachine->createDataLayout());
+    bitcodeModule->setTargetTriple(targetMachine->getTargetTriple().str());
+
+    // Inject target-specific flags.
+    // TODO(benvanik): move this entire function to another file that can do
+    // more complex logic cleanly. This is just an example.
+    overridePlatformGlobal(*bitcodeModule, "librt_platform_example_flag", 0u);
+
+    // Run the LLVM passes to optimize it for the current target.
+    if (failed(runLLVMIRPasses(options_, targetMachine, bitcodeModule.get()))) {
+      return mlir::emitError(loc)
+             << "failed to run librt LLVM-IR opt passes targeting '"
+             << options_.targetTriple << "'";
+    }
+
+    // Emit an object file we can pass to the linker.
+    std::string objectData;
+    if (failed(runEmitObjFilePasses(targetMachine, bitcodeModule.get(),
+                                    &objectData))) {
+      return mlir::emitError(loc)
+             << "failed to compile librt LLVM-IR module to an object file";
+    }
+
+    // Write the object file to disk with a similar name to the base file.
+    auto objectFile =
+        Artifact::createVariant(objectFiles.front().path, ".librt.o");
+    auto &os = objectFile.outputFile->os();
+    os << objectData;
+    os.flush();
+    os.close();
+    objectFiles.push_back(std::move(objectFile));
+
+    return success();
+  }
+
   LLVMTargetOptions options_;
 };
 
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
index f2847c2..cda219e 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
@@ -80,7 +80,7 @@
 
-    // Create the shared object name; if we only have a single input object we
-    // can just reuse that.
-    if (objectFiles.size() == 1) {
+    // Create the shared object name; if we have any input objects we can just
+    // derive it from the first one.
+    if (!objectFiles.empty()) {
       artifacts.libraryFile =
           Artifact::createVariant(objectFiles.front().path, "so");
     } else {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD
new file mode 100644
index 0000000..f9f6feb
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/embed_data:build_defs.bzl", "c_embed_data")
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_cmake_extra_content(
+    content = """
+if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}" AND NOT "${IREE_TARGET_BACKEND_WASM-LLVM-AOT}")
+  return()
+endif()
+""",
+)
+
+c_embed_data(
+    name = "librt",
+    srcs = ["bin/librt.bc"],
+    c_file_output = "librt.c",
+    flatten = True,
+    h_file_output = "librt.h",
+    identifier = "iree_compiler_librt",
+)
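
Note: `c_embed_data` compiles `bin/librt.bc` into the compiler binary as a C array with a generated `librt.h` accessor; the only contract the target backend relies on is the `iree_compiler_librt_create()` entry point with `data`/`size` fields, as used in `LLVMAOTTarget.cpp` above. A minimal consumption sketch (non-owning wrap, no copies):

```cpp
#include "iree/compiler/Dialect/HAL/Target/LLVM/librt/librt.h"  // generated

#include "llvm/Support/MemoryBuffer.h"

// Returns a non-owning view of the embedded librt bitcode, suitable for
// passing to llvm::parseBitcodeFile (mirrors the LLVMAOTTarget.cpp usage).
static llvm::MemoryBufferRef getLibrtBitcode() {
  const auto *file = iree_compiler_librt_create();
  return llvm::MemoryBufferRef(llvm::StringRef(file->data, file->size),
                               "librt.bc");
}
```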
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/librt/CMakeLists.txt
new file mode 100644
index 0000000..c6dabe0
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/CMakeLists.txt
@@ -0,0 +1,32 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# iree/compiler/Dialect/HAL/Target/LLVM/librt/BUILD                            #
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}" AND NOT "${IREE_TARGET_BACKEND_WASM-LLVM-AOT}")
+  return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_c_embed_data(
+  NAME
+    librt
+  SRCS
+    "bin/librt.bc"
+  C_FILE_OUTPUT
+    "librt.c"
+  H_FILE_OUTPUT
+    "librt.h"
+  IDENTIFIER
+    "iree_compiler_librt"
+  FLATTEN
+  PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.bc b/iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.bc
new file mode 100644
index 0000000..a484f31
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/bin/librt.bc
Binary files differ
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/build.sh b/iree/compiler/Dialect/HAL/Target/LLVM/librt/build.sh
new file mode 100644
index 0000000..c671e4d
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/build.sh
@@ -0,0 +1,45 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+set -e
+
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+OUT="${SCRIPT_DIR}/bin"
+SRC="${SCRIPT_DIR}/src"
+LL_FILE="${OUT}/librt.ll"
+BC_FILE="${OUT}/librt.bc"
+
+# Generate an LLVM IR assembly listing so we can easily read the file.
+# This is not checked in or used by the compiler.
+clang \
+    -target wasm32 \
+    -std=c17 \
+    -O2 \
+    -Xclang -disable-llvm-passes \
+    -fno-ident \
+    -fvisibility=hidden \
+    -nostdinc \
+    -g0 \
+    -S \
+    -emit-llvm \
+    -fno-verbose-asm \
+    -fdiscard-value-names \
+    -o "${LL_FILE}" \
+    -c \
+    "${SRC}/libm.c"
+
+# Clang adds a bunch of bad attributes and host-specific information that we
+# don't want (so we get at least somewhat deterministic builds).
+sed -i 's/^;.*$//' "${LL_FILE}"
+sed -i 's/^source_filename.*$//' "${LL_FILE}"
+sed -i 's/^target datalayout.*$//' "${LL_FILE}"
+sed -i 's/^target triple.*$//' "${LL_FILE}"
+sed -i 's/^\(attributes #[0-9]* = {\).*$/\1 inlinehint }/' "${LL_FILE}"
+
+# Generate a binary bitcode file embedded into the compiler binary.
+# NOTE: we do this from stdin so that the filename on the user's system is not
+# embedded in the bitcode file (making it non-deterministic).
+cat "${LL_FILE}" | llvm-as -o="${BC_FILE}"
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.c b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.c
new file mode 100644
index 0000000..bdbe0ae
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.c
@@ -0,0 +1,13 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "libm.h"
+
+// https://en.cppreference.com/w/c/numeric/math/fma
+LIBRT_EXPORT float fmaf(float x, float y, float z) {
+  // TODO(*): a real implementation :)
+  return (x * y) + z;
+}
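
Note: as the TODO says, `(x * y) + z` rounds twice and is not a fused multiply-add. A header-free improvement is sketched below, assuming `double` is IEEE binary64 on all targets; the helper name `fmaf_via_double` is invented for illustration, and rare double-rounding cases can still differ from a true fma:

```cpp
// The product of two binary32 values is exact in binary64 (24+24 significand
// bits fit in 53), so only the final add and the convert back to float round.
LIBRT_EXPORT float fmaf_via_double(float x, float y, float z) {
  return (float)((double)x * (double)y + (double)z);
}
```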
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.h b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.h
new file mode 100644
index 0000000..10dfe76
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/libm.h
@@ -0,0 +1,15 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LIBRT_SRC_LIBM_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LIBRT_SRC_LIBM_H_
+
+#include "librt.h"
+
+// https://en.cppreference.com/w/c/numeric/math/fma
+LIBRT_EXPORT float fmaf(float x, float y, float z);
+
+#endif  // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LIBRT_SRC_LIBM_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/librt.h b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/librt.h
new file mode 100644
index 0000000..0af8b21
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/librt/src/librt.h
@@ -0,0 +1,75 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+// A simplified libc/libm-alike that is designed to compile to portable LLVM IR.
+//===----------------------------------------------------------------------===//
+// This library is focused on supporting the subset of LLVM's RuntimeLibcalls
+// that we need in our embedded executable binaries. This means that things like
+// printf, malloc, etc. are excluded.
+//
+// See the full list of possible functions here:
+// third_party/llvm-project/llvm/include/llvm/IR/RuntimeLibcalls.def
+//
+// Code here must not use any system headers - as almost all pull in bits/ and
+// various other target-dependent definitions that make the resulting IR
+// non-portable. This means there is no size_t, etc. Any definitions that may
+// come from an std* file must be redefined here with care.
+//
+// Code must also not use any mutable global or thread-local state such as
+// errno/rounding modes/etc. Each of the functions in the library will be called
+// concurrently from multiple threads and from multiple source modules. There
+// must be no mutable static values anywhere.
+//
+// Avoid #ifdefs entirely: they indicate a leakage of host build configuration
+// into what is supposed to be a portable module. Anything that requires
+// target-specific conditional logic must be implemented via an extern that
+// can be substituted by the IREE compiler when producing the final
+// target-specific module.
+
+//===----------------------------------------------------------------------===//
+// Attributes and metadata
+//===----------------------------------------------------------------------===//
+
+// Tagged on functions that are part of the public API.
+#define LIBRT_EXPORT __attribute__((visibility("hidden")))
+
+//===----------------------------------------------------------------------===//
+// stdint.h
+//===----------------------------------------------------------------------===//
+// https://pubs.opengroup.org/onlinepubs/009604599/basedefs/stdint.h.html
+// NOTE: no size_t/ptrdiff_t/etc (as they are target dependent).
+
+typedef signed char int8_t;
+typedef short int16_t;
+typedef int int32_t;
+typedef long long int64_t;
+typedef unsigned char uint8_t;
+typedef unsigned short uint16_t;
+typedef unsigned int uint32_t;
+typedef unsigned long long uint64_t;
+
+#define INT8_MIN (-127 - 1)
+#define INT16_MIN (-32767 - 1)
+#define INT32_MIN (-2147483647 - 1)
+#define INT64_MIN (-9223372036854775807LL - 1)
+#define INT8_MAX 127
+#define INT16_MAX 32767
+#define INT32_MAX 2147483647
+#define INT64_MAX 9223372036854775807LL
+#define UINT8_MAX 0xffu
+#define UINT16_MAX 0xffffu
+#define UINT32_MAX 0xffffffffu
+#define UINT64_MAX 0xffffffffffffffffull
+
+//===----------------------------------------------------------------------===//
+// Target-specific queries
+//===----------------------------------------------------------------------===//
+// These are substituted with values from the compiler and must not be specified
+// here in C before we generate the IR.
+
+// Do not use; this exists only as an example. Remove once we have any other flag.
+extern int librt_platform_example_flag;
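
Note on the substitution mechanism: librt code reads a flag like this as an ordinary extern, and `overridePlatformGlobal()` in `LLVMAOTTarget.cpp` pins it to a constant in the parsed bitcode so the optimizer can fold the branch away. A hypothetical end-to-end sketch (`librt_platform_has_fast_fma` and `targetSupportsFMA` are invented for illustration):

```cpp
// librt side (C, compiled to bitcode by build.sh):
extern int librt_platform_has_fast_fma;

LIBRT_EXPORT float example_fma(float x, float y, float z) {
  // Folds to a single branch-free path once the compiler pins the global.
  if (librt_platform_has_fast_fma) return __builtin_fmaf(x, y, z);
  return (x * y) + z;
}

// Compiler side (in buildLibraryObjects, after parsing the bitcode module):
overridePlatformGlobal(*bitcodeModule, "librt_platform_has_fast_fma",
                       targetSupportsFMA ? 1u : 0u);
```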
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c7d61c8..3e47a35 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -203,11 +203,11 @@
   return {Range{zero, ub, one}};
 }
 
-Operation *ScatterOp::getTiledImplementation(
-    OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes,
-    SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets,
-    SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultSizes) {
+Operation *ScatterOp::getTiledImplementation(OpBuilder &builder,
+                                             ValueRange outputs,
+                                             ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes,
+                                             SmallVectorImpl<Value> &results) {
   assert(outputs.size() == 1 && offsets.size() == 1 && sizes.size() == 1);
   Location loc = getLoc();
   auto zeroAttr = builder.getI64IntegerAttr(0);
@@ -241,20 +241,18 @@
                                 indicesSizes, indicesStrides);
   assert(tiledIndices && "failed to get slice of indices");
 
-  resultOffsets.resize(1);
-  resultOffsets[0].resize(updateRank, zeroAttr);
-  resultSizes.resize(1);
-  resultSizes[0].resize(updateRank);
-  for (auto dim : llvm::seq<int64_t>(0, updateRank)) {
-    resultSizes[0][dim] = getDim(builder, loc, original(), dim);
-  }
   SmallVector<Type> resultTypes;
   if (getNumResults()) {
     resultTypes.push_back(getResultTypes()[0]);
   }
-  return cast<LinalgExtOp>(getOperation())
-      .clone(builder, loc, resultTypes,
-             ValueRange{tiledUpdate, tiledIndices, outputs[0]});
+  Operation *tiledScatterOp =
+      cast<LinalgExtOp>(getOperation())
+          .clone(builder, loc, resultTypes,
+                 ValueRange{tiledUpdate, tiledIndices, outputs[0]});
+  if (getNumResults()) {
+    results.push_back(tiledScatterOp->getResult(0));
+  }
+  return tiledScatterOp;
 }
 
 void ScatterOp::generateScalarUpdateLoopBody(OpBuilder &b, Location loc,
@@ -466,11 +464,11 @@
   return partitionableLoops;
 }
 
-Operation *SortOp::getTiledImplementation(
-    OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes,
-    SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets,
-    SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultSizes) {
+Operation *SortOp::getTiledImplementation(OpBuilder &builder,
+                                          ValueRange outputs,
+                                          ArrayRef<OpFoldResult> offsets,
+                                          ArrayRef<OpFoldResult> sizes,
+                                          SmallVectorImpl<Value> &results) {
   assert(outputs.size() == this->outputs().size());
   int64_t rank = getOperandRank();
   assert(offsets.size() == static_cast<size_t>(rank) &&
@@ -479,22 +477,24 @@
   SmallVector<OpFoldResult> strides(rank, oneAttr);
   Location loc = getLoc();
   SmallVector<Value> tiledOperands(outputs.size());
-  resultOffsets.resize(outputs.size());
-  resultSizes.resize(outputs.size());
   for (auto en : llvm::enumerate(outputs)) {
     tiledOperands[en.index()] =
         getSlice(builder, getLoc(), en.value(), offsets, sizes, strides);
     assert(tiledOperands[en.index()] && "failed to get slice of operand");
-    resultOffsets[en.index()].assign(offsets.begin(), offsets.end());
-    resultSizes[en.index()].assign(sizes.begin(), sizes.end());
   }
   SmallVector<Type, 4> resultTypes;
   if (getNumResults()) {
     resultTypes = llvm::to_vector<4>(
         llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); }));
   }
-  return cast<LinalgExtOp>(getOperation())
-      .clone(builder, loc, resultTypes, tiledOperands);
+  Operation *tiledSortOp = cast<LinalgExtOp>(getOperation())
+                               .clone(builder, loc, resultTypes, tiledOperands);
+  for (auto result : llvm::enumerate(tiledSortOp->getResults())) {
+    auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+        loc, result.value(), outputs[result.index()], offsets, sizes, strides);
+    results.push_back(insertSliceOp.getResult());
+  }
+  return tiledSortOp;
 }
 
 LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
index 232c6aa..337b273 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
@@ -85,6 +85,10 @@
         /*desc=*/[{
           Generates a tiled version of the operation given the tile
           size for the loops.
+
+          Returns the generated tiled operation. If the operation has
+          tensor semantics, the results of the tiled operation are to be
+          inserted into the `outputs` and returned in `results`.
         }],
         /*retType=*/"Operation *",
         /*methodName=*/"getTiledImplementation",
@@ -93,8 +97,7 @@
             "ValueRange ":$outputs,
             "ArrayRef<OpFoldResult> ":$offsets,
             "ArrayRef<OpFoldResult> ":$sizes,
-            "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultOffsets,
-            "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultSizes),
+            "SmallVectorImpl<Value> &":$results),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return nullptr;
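
Note for implementers: under the revised contract the op materializes the write-back itself rather than reporting result offsets/sizes to the tiling driver. A condensed sketch in the spirit of the `SortOp` change above (single output, operand slicing and error handling elided; `tileMyOp` is illustrative only):

```cpp
static Operation *tileMyOp(Operation *op, OpBuilder &builder,
                           ValueRange outputs, ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           SmallVectorImpl<Value> &results) {
  Location loc = op->getLoc();
  SmallVector<OpFoldResult> strides(offsets.size(),
                                    builder.getI64IntegerAttr(1));
  // Slice the output to the current tile and clone the op onto the slice.
  Value tiledOut = builder.create<tensor::ExtractSliceOp>(
      loc, outputs[0], offsets, sizes, strides);
  BlockAndValueMapping mapping;
  mapping.map(op->getOperand(0), tiledOut);
  Operation *tiledOp = builder.clone(*op, mapping);
  if (tiledOp->getNumResults()) {
    tiledOp->getResult(0).setType(tiledOut.getType());
    // With tensor semantics, insert the tiled result back into the full
    // output and report the inserted value through `results`.
    Value inserted = builder.create<tensor::InsertSliceOp>(
        loc, tiledOp->getResult(0), outputs[0], offsets, sizes, strides);
    results.push_back(inserted);
  }
  return tiledOp;
}
```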
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
index 829c306..bf6f1be 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -132,36 +132,13 @@
   // If this is the innermost loop, then generated the tiled implementation of
   // the op by invoking the TiledOpInterface methods.
   if (loopDepth == tileSizes.size()) {
-    SmallVector<SmallVector<OpFoldResult, 4>> resultOffsets;
-    SmallVector<SmallVector<OpFoldResult, 4>> resultSizes;
-    Operation *tiledOp = tilableOp.getTiledImplementation(
-        builder, outputs, offsets, tileSizes, resultOffsets, resultSizes);
-    if (!tiledOp) {
+    TiledOp ret;
+    ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets,
+                                              tileSizes, ret.results);
+    if (!ret.op) {
       return static_cast<LogicalResult>(
           tilableOp.emitOpError("failed to get tiled implementation"));
     }
-    assert(tiledOp->getNumResults() == 0 ||
-           (resultOffsets.size() == tiledOp->getNumResults()));
-    TiledOp ret;
-    ret.op = tiledOp;
-
-    // If the operation has results, then the result of the tiled operation is
-    // to be inserted into the `initValues` and returned.
-    if (tiledOp->getNumResults()) {
-      SmallVector<Value> results;
-      auto oneAttr = builder.getI64IntegerAttr(1);
-      results.reserve(tiledOp->getNumResults());
-      for (auto en : llvm::enumerate(tiledOp->getResults())) {
-        Value result = en.value();
-        ArrayRef<OpFoldResult> offsets(resultOffsets[en.index()]);
-        ArrayRef<OpFoldResult> sizes(resultSizes[en.index()]);
-        SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
-        Value insert = builder.create<tensor::InsertSliceOp>(
-            loc, result, outputs[en.index()], offsets, sizes, strides);
-        results.push_back(insert);
-      }
-      std::swap(ret.results, results);
-    }
     return ret;
   }
 
@@ -318,15 +295,13 @@
     return loopBounds;
   }
 
-  Operation *getTiledImplementation(
-      Operation *op, OpBuilder &b, ValueRange outputs,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-      SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets,
-      SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultSizes) const {
-    // Compute a subtensor of the source based on the offsets.
+  Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+                                    ValueRange outputs,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
     auto insertOp = cast<tensor::InsertSliceOp>(op);
-    auto opOffsets = insertOp.getMixedOffsets();
-    auto opSizes = insertOp.getMixedSizes();
+    // Compute a subtensor of the source based on the offsets.
     auto opStrides = insertOp.getMixedStrides();
     if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
           return isValue(valueOrAttr, 1);
@@ -334,23 +309,22 @@
       op->emitOpError("unable to tile operation with non-unit stride");
       return nullptr;
     }
-    // The operation returned is just a tensor.extract_slice of the source with
-    // the given offsets, sizes and strides. Setting the correct result offset
-    // will make the sure the tiling algorithm will insert this slice into the
-    // correct place in the destination.
-    // The result offset is just the offset passed in plus the offset specified
-    // in the op (since all strides are checked to be 1).
+    Location loc = insertOp.getLoc();
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+    auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
+        loc, insertOp.source(), offsets, sizes, strides);
+
+    // The offsets for the insert are based on the op offsets plus the offsets
+    // of the loops passed in.
+    auto opOffsets = insertOp.getMixedOffsets();
+    auto opSizes = insertOp.getMixedSizes();
     unsigned offsetIndex = 0;
     ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
     int64_t destRank = insertOp.getType().getRank();
-    resultOffsets.resize(1);
-    resultOffsets[0].resize(destRank);
-    resultSizes.resize(1);
-    resultSizes[0].resize(destRank);
-    Location loc = insertOp.getLoc();
+    SmallVector<OpFoldResult> resultOffsets(destRank);
+    SmallVector<OpFoldResult> resultSizes(destRank);
     auto zeroAttr = b.getI64IntegerAttr(0);
-    auto oneAttr = b.getI64IntegerAttr(1);
-    SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
     for (auto opOffset : llvm::enumerate(opOffsets)) {
       // Check for rank-reducing by checking that
       // 1) The corresponding opSize value is 1
@@ -358,29 +332,33 @@
       // Then the opOffset is for the rank-reduced dimension. Skip.
       unsigned opOffsetIndex = opOffset.index();
       if (isValue(opSizes[opOffsetIndex], 1) && sourceShape[offsetIndex] != 1) {
-        resultOffsets[0][opOffsetIndex] = zeroAttr;
-        resultSizes[0][opOffsetIndex] = oneAttr;
+        resultOffsets[opOffsetIndex] = zeroAttr;
+        resultSizes[opOffsetIndex] = oneAttr;
         continue;
       }
       OpFoldResult opOffsetVal = opOffset.value();
       OpFoldResult offset = offsets[offsetIndex];
       if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
-        resultOffsets[0][opOffsetIndex] = b.getI64IntegerAttr(
+        resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
             *getConstantValue(opOffsetVal) + *getConstantValue(offset));
       } else {
         AffineMap map = AffineMap::get(
             1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
-        resultOffsets[0][opOffsetIndex] =
+        resultOffsets[opOffsetIndex] =
             b.create<AffineApplyOp>(loc, map,
                                     ValueRange{getValue(b, loc, offset),
                                                getValue(b, loc, opOffsetVal)})
                 .getResult();
       }
-      resultSizes[0][opOffsetIndex] = sizes[offsetIndex];
+      resultSizes[opOffsetIndex] = sizes[offsetIndex];
       offsetIndex++;
     }
-    return b.create<tensor::ExtractSliceOp>(loc, insertOp.source(), offsets,
-                                            sizes, strides);
+    SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
+    auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+        loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
+        resultStrides);
+    results.push_back(tiledInsertOp.result());
+    return extractSliceOp;
   }
 };
 }  // namespace
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 5818c01..41a1cd0 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -30,15 +30,11 @@
 //  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
 //       CHECK:     %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
 //  CHECK-SAME:         [%[[USED_TILESIZE]], 1]
-//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]]
-//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[ORIGINAL]], %[[C1]]
 //       CHECK:     %[[SCATTER_TILE:.+]] = linalg_ext.scatter
 //  CHECK-SAME:         __internal_linalg_transform__ = "tiling_output"
 //  CHECK-SAME:         ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
 //  CHECK-SAME:         outs(%[[INIT]]
-//       CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
-//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
-//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:     scf.yield %[[SCATTER_TILE]]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -114,15 +110,11 @@
 //  CHECK-SAME:         [%[[USED_TILESIZE]], %[[D1]]]
 //       CHECK:     %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
 //  CHECK-SAME:         [%[[USED_TILESIZE]], 1]
-//   CHECK-DAG:     %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]]
-//   CHECK-DAG:     %[[SLICE_D1:.+]] = tensor.dim %[[ORIGINAL]], %[[C1]]
 //       CHECK:     %[[SCATTER_TILE:.+]] = linalg_ext.scatter
 //  CHECK-SAME:         __internal_linalg_transform__ = "distribute_output"
 //  CHECK-SAME:         ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
 //  CHECK-SAME:         outs(%[[INIT]]
-//       CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
-//  CHECK-SAME:         [%[[SLICE_D0]], %[[SLICE_D1]]]
-//       CHECK:     scf.yield %[[YIELD]]
+//       CHECK:     scf.yield %[[SCATTER_TILE]]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -435,9 +427,9 @@
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
 //      CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
 //      CHECK:     %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+//      CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]]
 //      CHECK:       %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]]
-//      CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]]
 // CHECK-SAME:         into %{{.+}}[%[[OFFSET0]], %[[OFFSET1]]]
 //      CHECK:       scf.yield %[[UPDATE]]
@@ -466,9 +458,9 @@
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
 //      CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
 //      CHECK:     %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+//      CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]]
 //      CHECK:       %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]]
-//      CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]]
 // CHECK-SAME:         into %{{.+}}[%[[OFFSET0]], 0, %[[OFFSET1]]]
 //      CHECK:       scf.yield %[[UPDATE]]
diff --git a/iree/compiler/InputConversion/MHLO/BUILD b/iree/compiler/InputConversion/MHLO/BUILD
index c1dc464..aaf1475 100644
--- a/iree/compiler/InputConversion/MHLO/BUILD
+++ b/iree/compiler/InputConversion/MHLO/BUILD
@@ -46,10 +46,10 @@
     name = "MHLO",
     srcs = [
         "BroadcastingToLinalgPatterns.cpp",
-        "ConvertAndDistributeMHLOToLinalgExt.cpp",
         "ConvertComplexToReal.cpp",
         "ConvertMHLOToFlow.cpp",
         "ConvertMHLOToFlow.h",
+        "ConvertMHLOToLinalgExt.cpp",
         "LegalizeInputTypes.cpp",
         "MHLOToLinalgOnTensors.cpp",
         "MHLOToMHLOPreprocessing.cpp",
diff --git a/iree/compiler/InputConversion/MHLO/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
index 36d2a4c..66276fc 100644
--- a/iree/compiler/InputConversion/MHLO/CMakeLists.txt
+++ b/iree/compiler/InputConversion/MHLO/CMakeLists.txt
@@ -41,10 +41,10 @@
     "Passes.h"
   SRCS
     "BroadcastingToLinalgPatterns.cpp"
-    "ConvertAndDistributeMHLOToLinalgExt.cpp"
     "ConvertComplexToReal.cpp"
     "ConvertMHLOToFlow.cpp"
     "ConvertMHLOToFlow.h"
+    "ConvertMHLOToLinalgExt.cpp"
     "LegalizeInputTypes.cpp"
     "MHLOToLinalgOnTensors.cpp"
     "MHLOToMHLOPreprocessing.cpp"
diff --git a/iree/compiler/InputConversion/MHLO/ConvertAndDistributeMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
similarity index 68%
rename from iree/compiler/InputConversion/MHLO/ConvertAndDistributeMHLOToLinalgExt.cpp
rename to iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index b24c76d..571b854 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertAndDistributeMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -47,51 +47,6 @@
 namespace {
 
 //===----------------------------------------------------------------------===//
-// Base classes.
-//===----------------------------------------------------------------------===//
-
-template <typename Derived, typename OpTy>
-struct ConvertToLinalgExtPattern : public OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      OpTy op, ArrayRef<Value> args,
-      ConversionPatternRewriter &rewriter) const final {
-    Value one = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
-    SmallVector<Value> workload(3, one);
-
-    // Gather the dynamic dimensions for all operands.
-    SmallVector<Value> operandDynamicDims;
-    for (Value arg : args) {
-      if (auto rt = arg.getType().dyn_cast<RankedTensorType>()) {
-        for (unsigned i = 0; i < rt.getRank(); ++i) {
-          if (!rt.isDynamicDim(i)) continue;
-          auto dim = rewriter.createOrFold<tensor::DimOp>(op.getLoc(), arg, i);
-          operandDynamicDims.push_back(dim);
-        }
-      }
-    }
-
-    auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
-        op.getLoc(), workload, op->getResultTypes(),
-        /*result_dims=*/ValueRange{},
-        /*operands=*/args,
-        /*operand_dims=*/operandDynamicDims,
-        /*tied_operands=*/Derived::getTiedResultOperandIndices(args));
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(&dispatchOp.getRegion().front());
-      if (failed(Derived::lowerMHLOOp(dispatchOp, op, args, rewriter))) {
-        return failure();
-      }
-      rewriter.create<IREE::Flow::ReturnOp>(op.getLoc());
-    }
-    rewriter.replaceOp(op, dispatchOp.getResults());
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
 // Region operations lowering.
 //===----------------------------------------------------------------------===//
 
@@ -131,32 +86,15 @@
 // SortOp
 //===----------------------------------------------------------------------===//
 
-struct SortOpConversion
-    : public ConvertToLinalgExtPattern<SortOpConversion, mhlo::SortOp> {
-  using ConvertToLinalgExtPattern<SortOpConversion,
-                                  mhlo::SortOp>::ConvertToLinalgExtPattern;
+struct SortOpConversion : public OpConversionPattern<mhlo::SortOp> {
+  using OpConversionPattern<mhlo::SortOp>::OpConversionPattern;
 
-  static SmallVector<int64_t> getTiedResultOperandIndices(
-      ArrayRef<Value> args) {
-    return llvm::to_vector<4>(llvm::seq<int64_t>(0, args.size()));
-  }
-
-  static LogicalResult lowerMHLOOp(IREE::Flow::DispatchWorkgroupsOp dispatchOp,
-                                   mhlo::SortOp op, ArrayRef<Value> args,
-                                   ConversionPatternRewriter &rewriter) {
-    auto blockArgs = dispatchOp.getClosureBodyRegion().getArguments();
-    SmallVector<Value> initValues;
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    for (auto it : llvm::zip(args, blockArgs)) {
-      auto argTy = std::get<0>(it).getType().cast<RankedTensorType>();
-      auto blockArg = std::get<1>(it);
-      initValues.push_back(
-          b.create<IREE::Flow::DispatchTensorLoadOp>(argTy, blockArg));
-    }
-
-    auto sortOp = b.create<linalg_ext::SortOp>(op.getResultTypes(),
-                                               /*inputs=*/ValueRange{},
-                                               initValues, op.dimensionAttr());
+  LogicalResult matchAndRewrite(
+      mhlo::SortOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const final {
+    auto sortOp = rewriter.create<linalg_ext::SortOp>(
+        op.getLoc(), op.getResultTypes(),
+        /*inputs=*/ValueRange{}, args, op.dimensionAttr());
     rewriter.inlineRegionBefore(op.comparator(), sortOp.region(),
                                 sortOp.region().begin());
     Region &region = sortOp.region();
@@ -169,12 +107,7 @@
     }
     rewriter.applySignatureConversion(&region, signature_converter);
 
-    for (auto it : llvm::zip(sortOp.getResults(), blockArgs)) {
-      auto value = std::get<0>(it);
-      auto target = std::get<1>(it);
-      b.create<IREE::Flow::DispatchTensorStoreOp>(value, target);
-    }
-
+    rewriter.replaceOp(op, sortOp->getResults());
     return success();
   }
 };
@@ -183,10 +116,8 @@
 // ScatterOp
 //===----------------------------------------------------------------------===//
 
-struct ScatterOpConversion
-    : public ConvertToLinalgExtPattern<ScatterOpConversion, mhlo::ScatterOp> {
-  using ConvertToLinalgExtPattern<ScatterOpConversion,
-                                  mhlo::ScatterOp>::ConvertToLinalgExtPattern;
+struct ScatterOpConversion : public OpConversionPattern<mhlo::ScatterOp> {
+  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;
 
   /// Returns true if the `dimensionNumbers` from the mhlo.scatter op follows a
   /// canonical form:
@@ -275,28 +206,23 @@
     return success();
   }
 
-  static LogicalResult lowerMHLOOp(IREE::Flow::DispatchWorkgroupsOp dispatchOp,
-                                   mhlo::ScatterOp op, ArrayRef<Value> args,
-                                   ConversionPatternRewriter &rewriter) {
+  LogicalResult matchAndRewrite(
+      mhlo::ScatterOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const final {
     if (!hasCanonicalDimensionNumbers(op)) return failure();
 
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     mhlo::ScatterOpAdaptor adaptor(args);
 
-    auto blockArgs = dispatchOp.getClosureBodyRegion().getArguments();
-    Value original = b.create<IREE::Flow::DispatchTensorLoadOp>(
-        adaptor.operand().getType().cast<RankedTensorType>(), blockArgs[0]);
-    Value indices = b.create<IREE::Flow::DispatchTensorLoadOp>(
-        adaptor.scatter_indices().getType().cast<RankedTensorType>(),
-        blockArgs[1]);
-    Value updates = b.create<IREE::Flow::DispatchTensorLoadOp>(
-        adaptor.updates().getType().cast<RankedTensorType>(), blockArgs[2]);
+    Value original = adaptor.operand();
+    Value indices = adaptor.scatter_indices();
+    Value updates = adaptor.updates();
 
     if (failed(collapseBatchDimsIfNeeded(indices, updates, b))) {
       return failure();
     }
-    auto scatterOp = b.create<linalg_ext::ScatterOp>(
-        op->getResultTypes(), ValueRange{updates, indices},
+    auto scatterOp = rewriter.create<linalg_ext::ScatterOp>(
+        op.getLoc(), op->getResultTypes(), ValueRange{updates, indices},
         ValueRange{original});
 
     rewriter.inlineRegionBefore(op.update_computation(), scatterOp.region(),
@@ -311,8 +237,7 @@
     signatureConverter.addInputs(0, argType);
     rewriter.applySignatureConversion(&region, signatureConverter);
 
-    b.create<IREE::Flow::DispatchTensorStoreOp>(scatterOp.getResult(0),
-                                                blockArgs[0]);
+    rewriter.replaceOp(op, scatterOp->getResults());
     return success();
   }
 };
@@ -321,9 +246,8 @@
 // Pass
 //===----------------------------------------------------------------------===//
 
-struct ConvertAndDistributeMHLOToLinalgExtPass
-    : public ConvertAndDistributeMHLOToLinalgExtBase<
-          ConvertAndDistributeMHLOToLinalgExtPass> {
+struct ConvertMHLOToLinalgExtPass
+    : public ConvertMHLOToLinalgExtBase<ConvertMHLOToLinalgExtPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg_ext::LinalgExtDialect, linalg::LinalgDialect,
                     IREE::Flow::FlowDialect, StandardOpsDialect,
@@ -354,9 +278,8 @@
 };
 }  // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
-createConvertAndDistributeMHLOToLinalgExtPass() {
-  return std::make_unique<ConvertAndDistributeMHLOToLinalgExtPass>();
+std::unique_ptr<OperationPass<FuncOp>> createConvertMHLOToLinalgExtPass() {
+  return std::make_unique<ConvertMHLOToLinalgExtPass>();
 }
 
 }  // namespace iree_compiler
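
Note: these patterns run under the standard MLIR dialect-conversion driver; the pass body is untouched by this diff, but for orientation a generic registration sketch (standard boilerplate for this LLVM version, not copied from the pass) looks roughly like:

```cpp
void ConvertMHLOToLinalgExtPass::runOnOperation() {
  MLIRContext *context = &getContext();
  RewritePatternSet patterns(context);
  patterns.insert<SortOpConversion, ScatterOpConversion>(context);

  ConversionTarget target(*context);
  // mhlo.sort and mhlo.scatter must be rewritten; all other ops may remain.
  target.addIllegalOp<mhlo::SortOp, mhlo::ScatterOp>();
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns)))) {
    signalPassFailure();
  }
}
```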
diff --git a/iree/compiler/InputConversion/MHLO/Passes.cpp b/iree/compiler/InputConversion/MHLO/Passes.cpp
index ebf3147..2b589cb 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.cpp
+++ b/iree/compiler/InputConversion/MHLO/Passes.cpp
@@ -66,7 +66,7 @@
 
   // Convert to Linalg. After this point, MHLO will be eliminated.
   passManager.addNestedPass<FuncOp>(
-      mlir::iree_compiler::createConvertAndDistributeMHLOToLinalgExtPass());
+      mlir::iree_compiler::createConvertMHLOToLinalgExtPass());
   passManager.addNestedPass<FuncOp>(
       mlir::iree_compiler::createMHLOToLinalgOnTensorsPass());
 
diff --git a/iree/compiler/InputConversion/MHLO/Passes.h b/iree/compiler/InputConversion/MHLO/Passes.h
index 5761952..8bc4222 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.h
+++ b/iree/compiler/InputConversion/MHLO/Passes.h
@@ -34,9 +34,8 @@
 /// Creates XLA-HLO to Linalg on tensors transformation pass.
 std::unique_ptr<OperationPass<FuncOp>> createMHLOToLinalgOnTensorsPass();
 
-/// Creates XLA-HLO to LinalgExt and Flow transformation pass.
-std::unique_ptr<OperationPass<FuncOp>>
-createConvertAndDistributeMHLOToLinalgExtPass();
+/// Creates XLA-HLO to LinalgExt conversion pass.
+std::unique_ptr<OperationPass<FuncOp>> createConvertMHLOToLinalgExtPass();
 
 /// Creates XLA-HLO preprocessing transformation pass. In this pass we should
 /// have all mhlo -> mhlo transformations that are shared between all
diff --git a/iree/compiler/InputConversion/MHLO/Passes.td b/iree/compiler/InputConversion/MHLO/Passes.td
index a08fa08..b0082b2 100644
--- a/iree/compiler/InputConversion/MHLO/Passes.td
+++ b/iree/compiler/InputConversion/MHLO/Passes.td
@@ -15,12 +15,12 @@
   let constructor = "mlir::iree_compiler::createMHLOToLinalgOnTensorsPass()";
 }
 
-def ConvertAndDistributeMHLOToLinalgExt
+def ConvertMHLOToLinalgExt
     : Pass<"iree-mhlo-to-linalg-ext", "FuncOp"> {
-  let summary =
-      "Convert from XLA-HLO ops to LinalgExt ops and distribute to Flow ops";
+  let summary = "Convert from XLA-HLO ops to LinalgExt ops";
   let constructor =
-      "mlir::iree_compiler::createConvertAndDistributeMHLOToLinalgExtPass()";
+      "mlir::iree_compiler::createConvertMHLOToLinalgExtPass()";
 }
 
 def LegalizeInputTypes :
diff --git a/iree/compiler/InputConversion/MHLO/test/BUILD b/iree/compiler/InputConversion/MHLO/test/BUILD
index c3f4485..cef5ec5 100644
--- a/iree/compiler/InputConversion/MHLO/test/BUILD
+++ b/iree/compiler/InputConversion/MHLO/test/BUILD
@@ -20,7 +20,7 @@
     srcs = enforce_glob(
         [
             "broadcasting.mlir",
-            "convert_and_distribute_mhlo_to_linalg_ext.mlir",
+            "convert_mhlo_to_linalg_ext.mlir",
             "convert_complex_to_real.mlir",
             "dynamic_shape.mlir",
             "fft.mlir",
diff --git a/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
index 9f3a453..d7405b8 100644
--- a/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
+++ b/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
@@ -15,8 +15,8 @@
     lit
   SRCS
     "broadcasting.mlir"
-    "convert_and_distribute_mhlo_to_linalg_ext.mlir"
     "convert_complex_to_real.mlir"
+    "convert_mhlo_to_linalg_ext.mlir"
     "dynamic_shape.mlir"
     "fft.mlir"
     "legalize_input_types.mlir"
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
deleted file mode 100644
index 8b80c34..0000000
--- a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
+++ /dev/null
@@ -1,342 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-mhlo-to-linalg-ext %s | IreeFileCheck %s
-
-func @sort_1d(%arg0: tensor<128xi32>) -> (tensor<128xi32>) {
-  %0 = "mhlo.sort"(%arg0) ( {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):  // no predecessors
-    %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%1) : (tensor<i1>) -> ()
-  }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>) -> (tensor<128xi32>)
-  return %0 : tensor<128xi32>
-}
-// CHECK-LABEL: func @sort_1d
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]]) : (tensor<128xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:128xi32>
-// CHECK:           %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]]
-// CHECK:           %[[SORT:.+]] = linalg_ext.sort
-// CHECK-SAME:        dimension(0)
-// CHECK-SAME:        outs(%[[IN]] : tensor<128xi32>)
-// CHECK:          ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32)
-// CHECK:            %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
-// CHECK:            linalg_ext.yield %[[CMP]]
-// CHECK:          flow.dispatch.tensor.store %[[SORT]], %[[ARG1]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) {
-  %0 = "mhlo.sort"(%arg0) ( {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):  // no predecessors
-    %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%1) : (tensor<i1>) -> ()
-  }) {dimension = 0 : i64, is_stable = false} : (tensor<16x32xi32>) -> (tensor<16x32xi32>)
-  return %0 : tensor<16x32xi32>
-}
-// CHECK-LABEL: func @sort_2d
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]]) : (tensor<16x32xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:16x32xi32>
-// CHECK:           %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]]
-// CHECK:           %[[SORT:.+]] = linalg_ext.sort
-// CHECK-SAME:        dimension(0)
-// CHECK-SAME:        outs(%[[IN]] : tensor<16x32xi32>)
-// CHECK:          ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32)
-// CHECK:            %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
-// CHECK:            linalg_ext.yield %[[CMP]]
-// CHECK:          flow.dispatch.tensor.store %[[SORT]], %[[ARG1]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @topk(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> (tensor<128xi32>) {
-  %0:2 = "mhlo.sort"(%arg0, %arg1) ( {
-  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>):  // no predecessors
-    %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%1) : (tensor<i1>) -> ()
-  }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>, tensor<128xi32>) -> (tensor<128xi32>, tensor<128xi32>)
-  return %0#0 : tensor<128xi32>
-}
-// CHECK-LABEL: func @topk
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]]:2 = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]])
-// CHECK-SAME:    : (tensor<128xi32>, tensor<128xi32>) -> (%[[ARG0]], %[[ARG1]])
-// CHECK:           %[[ARG2:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:128xi32>
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:128xi32>
-// CHECK:           %[[IN1:.+]] = flow.dispatch.tensor.load %[[ARG2]]
-// CHECK:           %[[IN2:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[SORT:.+]]:2 = linalg_ext.sort
-// CHECK-SAME:        dimension(0)
-// CHECK-SAME:        outs(%[[IN1]], %[[IN2]] : tensor<128xi32>, tensor<128xi32>)
-// CHECK:          ^bb0(%[[ARG4:.+]]: i32, %[[ARG5:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32)
-// CHECK:            %[[CMP:.+]] = cmpi sgt, %[[ARG4]], %[[ARG5]]
-// CHECK:            linalg_ext.yield %[[CMP]]
-// CHECK:          flow.dispatch.tensor.store %[[SORT]]#0, %[[ARG2]]
-// CHECK:          flow.dispatch.tensor.store %[[SORT]]#1, %[[ARG3]]
-// CHECK:        return %[[RES]]#0
-
-// -----
-
-func @scatter_update_scalar_1D(%arg0: tensor<8xi32>, %arg1: tensor<4x1xi32>,
-    %arg2: tensor<4xi32>) -> tensor<8xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = {
-      index_vector_dim = 1 : i64,
-      inserted_window_dims = dense<0> : tensor<1xi64>,
-      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
-      update_window_dims = dense<> : tensor<0xi64>
-    },
-    unique_indices = false
-  } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32>
-  return %0 : tensor<8xi32>
-}
-// CHECK-LABEL: func @scatter_update_scalar_1D
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME:    : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:8xi32>
-// CHECK-SAME:      %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:4x1xi32>
-// CHECK-SAME:      %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:4xi32>
-// CHECK:           %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK:           %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK:           %[[SCATTER:.+]] = linalg_ext.scatter
-// CHECK-SAME:        ins(%[[UPDATES]], %[[INDICES]] : tensor<4xi32>, tensor<4x1xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<8xi32>)
-// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
-// CEECK:             linalg.yield %[[V1]]
-// CHECK:           flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @scatter_update_scalar_2D(%arg0: tensor<4x3xi32>, %arg1: tensor<3x2xi32>,
-    %arg2: tensor<3xi32>) -> tensor<4x3xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {indices_are_sorted = false,
-      scatter_dimension_numbers = {
-        index_vector_dim = 1 : i64,
-        inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
-        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
-        update_window_dims = dense<> : tensor<0xi64>
-      },
-      unique_indices = false
-  } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32>
-  return %0 : tensor<4x3xi32>
-}
-// CHECK-LABEL: func @scatter_update_scalar_2D
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME:    : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:4x3xi32>
-// CHECK-SAME:      %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3x2xi32>
-// CHECK-SAME:      %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3xi32>
-// CHECK:           %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK:           %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK:           %[[SCATTER:.+]] = linalg_ext.scatter
-// CHECK-SAME:        ins(%[[UPDATES]], %[[INDICES]] : tensor<3xi32>, tensor<3x2xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<4x3xi32>)
-// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
-// CEECK:             linalg.yield %[[V1]]
-// CHECK:           flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @scatter_update_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
-    %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = {
-      index_vector_dim = 1 : i64,
-      inserted_window_dims = dense<0> : tensor<1xi64>,
-      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
-      update_window_dims = dense<1> : tensor<1xi64>
-    },
-    unique_indices = false
-  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
-  return %0 : tensor<6x3xi32>
-}
-// CHECK-LABEL: func @scatter_update_slice_2D
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME:    : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:6x3xi32>
-// CHECK-SAME:      %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x1xi32>
-// CHECK-SAME:      %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x3xi32>
-// CHECK:           %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK:           %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK:           %[[SCATTER:.+]] = linalg_ext.scatter
-// CHECK-SAME:        ins(%[[UPDATES]], %[[INDICES]] : tensor<2x3xi32>, tensor<2x1xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<6x3xi32>)
-// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
-// CEECK:             linalg.yield %[[V1]]
-// CHECK:           flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
-    %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
-    "mhlo.return"(%1) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = {
-      index_vector_dim = 1 : i64,
-      inserted_window_dims = dense<0> : tensor<1xi64>,
-      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
-      update_window_dims = dense<1> : tensor<1xi64>
-    },
-    unique_indices = false
-  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
-  return %0 : tensor<6x3xi32>
-}
-// CHECK-LABEL: func @scatter_add_slice_2D
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME:    : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:6x3xi32>
-// CHECK-SAME:      %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x1xi32>
-// CHECK-SAME:      %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:2x3xi32>
-// CHECK:           %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK:           %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK:           %[[SCATTER:.+]] = linalg_ext.scatter
-// CHECK-SAME:        ins(%[[UPDATES]], %[[INDICES]] : tensor<2x3xi32>, tensor<2x1xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<6x3xi32>)
-// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
-//
-//                   The order is reverse.
-// CHECK:             %[[V3:.+]] = addi %[[V2]], %[[V1]]
-// CEECK:             linalg.yield %[[V3]]
-// CHECK:           flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @scatter_update_batch_scalar_1D(%arg0: tensor<8xi32>,
-    %arg1: tensor<3x4x1xi32>, %arg2: tensor<3x4xi32>) -> tensor<8xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = {
-      index_vector_dim = 2 : i64,
-      inserted_window_dims = dense<0> : tensor<i64>,
-      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
-      update_window_dims = dense<> : tensor<0xi64>
-    },
-    unique_indices = false
-  } : (tensor<8xi32>, tensor<3x4x1xi32>, tensor<3x4xi32>) -> tensor<8xi32>
-  return %0 : tensor<8xi32>
-}
-// CHECK-LABEL: func @scatter_update_batch_scalar_1D
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME:    : (tensor<8xi32>, tensor<3x4x1xi32>, tensor<3x4xi32>) -> %[[ARG0]]
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:8xi32>
-// CHECK-SAME:      %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3x4x1xi32>
-// CHECK-SAME:      %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:3x4xi32>
-// CHECK:           %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK:           %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK:           %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape
-// CHECK-SAME:        %[[INDICES]] {{\[}}[0, 1], [2]] : tensor<3x4x1xi32> into tensor<12x1xi32>
-// CHECK:           %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape
-// CHECK-SAME:        %[[UPDATES]] {{\[}}[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
-// CHECK:           %[[SCATTER:.+]] = linalg_ext.scatter
-// CHECK-SAME:        ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<12xi32>, tensor<12x1xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<8xi32>)
-// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
-// CEECK:             linalg.yield %[[V1]]
-// CHECK:           flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]]
-// CHECK:        return %[[RES]]
-
-// -----
-
-func @scatter_update_batch_slice_3D_dynamic(%arg0: tensor<1x24x512xi32>,
-    %arg1: tensor<?x3x2xi32>, %arg2: tensor<?x3x512xi32>) -> tensor<1x24x512xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {indices_are_sorted = false,
-      scatter_dimension_numbers = {
-        index_vector_dim = 2 : i64,
-        inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
-        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
-        update_window_dims = dense<2> : tensor<1xi64>
-      },
-      unique_indices = false
-  } : (tensor<1x24x512xi32>, tensor<?x3x2xi32>, tensor<?x3x512xi32>) -> tensor<1x24x512xi32>
-  return %0 : tensor<1x24x512xi32>
-}
-// CHECK-LABEL: func @scatter_update_batch_slice_3D_dynamic
-// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C1:.+]] = constant 1 : index
-// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
-// CHECK-DAG:     %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x3x2xi32>
-// CHECK-DAG:     %[[C0:.+]] = constant 0 : index
-// CHECK-DAG:     %[[DIM2:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x3x512xi32>
-// CHECK:         %[[RES:.+]] = flow.dispatch.workgroups
-// CHECK-SAME:      [%[[C1]], %[[C1]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[ARG2]])
-// CHECK-SAME:    : (tensor<1x24x512xi32>, tensor<?x3x2xi32>{%[[DIM1]]}, tensor<?x3x512xi32>{%[[DIM2]]}) -> %[[ARG0]]
-// CHECK:           %[[ARG3:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readwrite:1x24x512xi32>
-// CHECK-SAME:      %[[ARG4:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:?x3x2xi32>
-// CHECK-SAME:      %[[ARG5:[a-zA-Z0-9]+]]: !flow.dispatch.tensor<readonly:?x3x512xi32>
-// CHECK:           %[[ORIGINAL:.+]] = flow.dispatch.tensor.load %[[ARG3]]
-// CHECK:           %[[INDICES:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK:           %[[UPDATES:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK:           %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape
-// CHECK-SAME:        %[[INDICES]] {{\[}}[0, 1], [2]] : tensor<?x3x2xi32> into tensor<?x2xi32>
-// CHECK:           %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape
-// CHECK-SAME:        %[[UPDATES]] {{\[}}[0, 1], [2]] : tensor<?x3x512xi32> into tensor<?x512xi32>
-// CHECK:           %[[SCATTER:.+]] = linalg_ext.scatter
-// CHECK-SAME:        ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<?x512xi32>, tensor<?x2xi32>)
-// CHECK-SAME:       outs(%[[ORIGINAL]] : tensor<1x24x512xi32>)
-// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
-// CEECK:             linalg.yield %[[V1]]
-// CHECK:           flow.dispatch.tensor.store %[[SCATTER]], %[[ARG3]]
-// CHECK:        return %[[RES]]
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
new file mode 100644
index 0000000..e25865b
--- /dev/null
+++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -0,0 +1,250 @@
+// RUN: iree-opt -split-input-file -iree-mhlo-to-linalg-ext %s | IreeFileCheck %s
+
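+// Tests the -iree-mhlo-to-linalg-ext pass, which lowers mhlo.sort and
+// mhlo.scatter ops into linalg_ext.sort and linalg_ext.scatter ops.
+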
+func @sort_1d(%arg0: tensor<128xi32>) -> (tensor<128xi32>) {
+  %0 = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):  // no predecessors
+    %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%1) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>) -> (tensor<128xi32>)
+  return %0 : tensor<128xi32>
+}
+// CHECK-LABEL: func @sort_1d
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[SORT:.+]] = linalg_ext.sort
+// CHECK-SAME:      dimension(0)
+// CHECK-SAME:      outs(%[[ARG0]] : tensor<128xi32>)
+// CHECK:           ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+// CHECK:             %[[CMP:.+]] = cmpi sgt, %[[ARG1]], %[[ARG2]]
+// CHECK:             linalg_ext.yield %[[CMP]]
+// CHECK:         return %[[SORT]]
+
+// -----
+
+func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) {
+  %0 = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):  // no predecessors
+    %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%1) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<16x32xi32>) -> (tensor<16x32xi32>)
+  return %0 : tensor<16x32xi32>
+}
+// CHECK-LABEL: func @sort_2d
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[SORT:.+]] = linalg_ext.sort
+// CHECK-SAME:      dimension(0)
+// CHECK-SAME:      outs(%[[ARG0]] : tensor<16x32xi32>)
+// CHECK:           ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+// CHECK:             %[[CMP:.+]] = cmpi sgt, %[[ARG1]], %[[ARG2]]
+// CHECK:             linalg_ext.yield %[[CMP]]
+// CHECK:         return %[[SORT]]
+
+// -----
+
+func @topk(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> (tensor<128xi32>) {
+  %0:2 = "mhlo.sort"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>):  // no predecessors
+    %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%1) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>, tensor<128xi32>) -> (tensor<128xi32>, tensor<128xi32>)
+  return %0#0 : tensor<128xi32>
+}
+// CHECK-LABEL: func @topk
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[SORT:.+]]:2 = linalg_ext.sort
+// CHECK-SAME:      dimension(0)
+// CHECK-SAME:      outs(%[[ARG0]], %[[ARG1]] : tensor<128xi32>, tensor<128xi32>)
+// CHECK:           ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32)
+// CHECK:             %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
+// CHECK:             linalg_ext.yield %[[CMP]]
+// CHECK:        return %[[SORT]]#0
+
+// -----
+
+func @scatter_update_scalar_1D(%arg0: tensor<8xi32>, %arg1: tensor<4x1xi32>,
+    %arg2: tensor<4xi32>) -> tensor<8xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = false,
+    scatter_dimension_numbers = {
+      index_vector_dim = 1 : i64,
+      inserted_window_dims = dense<0> : tensor<1xi64>,
+      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+      update_window_dims = dense<> : tensor<0xi64>
+    },
+    unique_indices = false
+  } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+// CHECK-LABEL: func @scatter_update_scalar_1D
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[SCATTER:.+]] = linalg_ext.scatter
+// CHECK-SAME:      ins(%[[ARG2]], %[[ARG1]] : tensor<4xi32>, tensor<4x1xi32>)
+// CHECK-SAME:      outs(%[[ARG0]] : tensor<8xi32>)
+// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
+// CHECK:             linalg_ext.yield %[[V1]]
+// CHECK:         return %[[SCATTER]]
+
+// -----
+
+func @scatter_update_scalar_2D(%arg0: tensor<4x3xi32>, %arg1: tensor<3x2xi32>,
+    %arg2: tensor<3xi32>) -> tensor<4x3xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {indices_are_sorted = false,
+      scatter_dimension_numbers = {
+        index_vector_dim = 1 : i64,
+        inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
+        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
+        update_window_dims = dense<> : tensor<0xi64>
+      },
+      unique_indices = false
+  } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32>
+  return %0 : tensor<4x3xi32>
+}
+// CHECK-LABEL: func @scatter_update_scalar_2D
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[SCATTER:.+]] = linalg_ext.scatter
+// CHECK-SAME:      ins(%[[ARG2]], %[[ARG1]] : tensor<3xi32>, tensor<3x2xi32>)
+// CHECK-SAME:      outs(%[[ARG0]] : tensor<4x3xi32>)
+// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
+// CHECK:             linalg_ext.yield %[[V1]]
+// CHECK:         return %[[SCATTER]]
+
+// -----
+
+func @scatter_update_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
+    %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = false,
+    scatter_dimension_numbers = {
+      index_vector_dim = 1 : i64,
+      inserted_window_dims = dense<0> : tensor<1xi64>,
+      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+      update_window_dims = dense<1> : tensor<1xi64>
+    },
+    unique_indices = false
+  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
+  return %0 : tensor<6x3xi32>
+}
+// CHECK-LABEL: func @scatter_update_slice_2D
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[SCATTER:.+]] = linalg_ext.scatter
+// CHECK-SAME:      ins(%[[ARG2]], %[[ARG1]] : tensor<2x3xi32>, tensor<2x1xi32>)
+// CHECK-SAME:      outs(%[[ARG0]] : tensor<6x3xi32>)
+// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
+// CHECK:             linalg_ext.yield %[[V1]]
+// CHECK:         return %[[SCATTER]]
+
+// -----
+
+func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
+    %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
+    "mhlo.return"(%1) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = false,
+    scatter_dimension_numbers = {
+      index_vector_dim = 1 : i64,
+      inserted_window_dims = dense<0> : tensor<1xi64>,
+      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+      update_window_dims = dense<1> : tensor<1xi64>
+    },
+    unique_indices = false
+  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
+  return %0 : tensor<6x3xi32>
+}
+// CHECK-LABEL: func @scatter_add_slice_2D
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[SCATTER:.+]] = linalg_ext.scatter
+// CHECK-SAME:      ins(%[[ARG2]], %[[ARG1]] : tensor<2x3xi32>, tensor<2x1xi32>)
+// CHECK-SAME:      outs(%[[ARG0]] : tensor<6x3xi32>)
+// CHECK:           ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
+//
+//                   Note: the block argument order is reversed relative to the
+//                   mhlo region, so the addi operands are swapped.
+// CHECK:              %[[V3:.+]] = addi %[[V2]], %[[V1]]
+// CHECK:              linalg_ext.yield %[[V3]]
+// CHECK:         return %[[SCATTER]]
+
+// -----
+
+func @scatter_update_batch_scalar_1D(%arg0: tensor<8xi32>,
+    %arg1: tensor<3x4x1xi32>, %arg2: tensor<3x4xi32>) -> tensor<8xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = false,
+    scatter_dimension_numbers = {
+      index_vector_dim = 2 : i64,
+      inserted_window_dims = dense<0> : tensor<1xi64>,
+      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+      update_window_dims = dense<> : tensor<0xi64>
+    },
+    unique_indices = false
+  } : (tensor<8xi32>, tensor<3x4x1xi32>, tensor<3x4xi32>) -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+// CHECK-LABEL: func @scatter_update_batch_scalar_1D
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape
+// CHECK-SAME:        %[[ARG1]] {{\[}}[0, 1], [2]] : tensor<3x4x1xi32> into tensor<12x1xi32>
+// CHECK:         %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape
+// CHECK-SAME:        %[[ARG2]] {{\[}}[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
+// CHECK:         %[[SCATTER:.+]] = linalg_ext.scatter
+// CHECK-SAME:       ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<12xi32>, tensor<12x1xi32>)
+// CHECK-SAME:       outs(%[[ARG0]] : tensor<8xi32>)
+// CHECK:            ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
+// CHECK:              linalg_ext.yield %[[V1]]
+// CHECK:         return %[[SCATTER]]
+
+// -----
+
+func @scatter_update_batch_slice_3D_dynamic(%arg0: tensor<1x24x512xi32>,
+    %arg1: tensor<?x3x2xi32>, %arg2: tensor<?x3x512xi32>) -> tensor<1x24x512xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {indices_are_sorted = false,
+      scatter_dimension_numbers = {
+        index_vector_dim = 2 : i64,
+        inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
+        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
+        update_window_dims = dense<2> : tensor<1xi64>
+      },
+      unique_indices = false
+  } : (tensor<1x24x512xi32>, tensor<?x3x2xi32>, tensor<?x3x512xi32>) -> tensor<1x24x512xi32>
+  return %0 : tensor<1x24x512xi32>
+}
+// CHECK-LABEL: func @scatter_update_batch_slice_3D_dynamic
+// CHECK:         %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[COLLAPSED_INDICES:.+]] = linalg.tensor_collapse_shape
+// CHECK-SAME:        %[[ARG1]] {{\[}}[0, 1], [2]] : tensor<?x3x2xi32> into tensor<?x2xi32>
+// CHECK:         %[[COLLAPSED_UPDATES:.+]] = linalg.tensor_collapse_shape
+// CHECK-SAME:        %[[ARG2]] {{\[}}[0, 1], [2]] : tensor<?x3x512xi32> into tensor<?x512xi32>
+// CHECK:         %[[SCATTER:.+]] = linalg_ext.scatter
+// CHECK-SAME:        ins(%[[COLLAPSED_UPDATES]], %[[COLLAPSED_INDICES]] : tensor<?x512xi32>, tensor<?x2xi32>)
+// CHECK-SAME:        outs(%[[ARG0]] : tensor<1x24x512xi32>)
+// CHECK:             ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32):  // no predecessors
+// CHECK:               linalg_ext.yield %[[V1]]
+// CHECK:         return %[[SCATTER]]
diff --git a/iree/hal/allocator_heap.c b/iree/hal/allocator_heap.c
index 7951398..21b7f3e 100644
--- a/iree/hal/allocator_heap.c
+++ b/iree/hal/allocator_heap.c
@@ -28,7 +28,8 @@
   IREE_TRACE_ZONE_BEGIN(z0);
 
   iree_hal_heap_allocator_t* allocator = NULL;
-  iree_host_size_t total_size = sizeof(*allocator) + identifier.size;
+  iree_host_size_t total_size =
+      iree_sizeof_struct(*allocator) + identifier.size;
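+  // NOTE: iree_sizeof_struct rounds the struct size up to its natural
+  // alignment so the identifier string stored immediately after it starts at
+  // an aligned offset.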
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&allocator);
   if (iree_status_is_ok(status)) {
@@ -37,7 +38,7 @@
     allocator->host_allocator = host_allocator;
     iree_string_view_append_to_buffer(
         identifier, &allocator->identifier,
-        (char*)allocator + total_size - identifier.size);
+        (char*)allocator + iree_sizeof_struct(*allocator));
     *out_allocator = (iree_hal_allocator_t*)allocator;
   }
 
diff --git a/iree/hal/buffer_heap.c b/iree/hal/buffer_heap.c
index 7621667..9d74cc1 100644
--- a/iree/hal/buffer_heap.c
+++ b/iree/hal/buffer_heap.c
@@ -32,8 +32,10 @@
   IREE_ASSERT_ARGUMENT(out_buffer);
   IREE_TRACE_ZONE_BEGIN(z0);
 
+  // NOTE: we want the buffer data to always be 16-byte aligned.
   iree_hal_heap_buffer_t* buffer = NULL;
-  iree_host_size_t header_size = iree_host_align(sizeof(*buffer), 16);
+  iree_host_size_t header_size =
+      iree_host_align(iree_sizeof_struct(*buffer), 16);
   iree_host_size_t total_size = header_size + allocation_size;
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&buffer);
diff --git a/iree/hal/cuda/BUILD b/iree/hal/cuda/BUILD
index 78685b7..6ebe4c9 100644
--- a/iree/hal/cuda/BUILD
+++ b/iree/hal/cuda/BUILD
@@ -48,6 +48,8 @@
         "nop_executable_cache.h",
         "status_util.c",
         "status_util.h",
+        "stream_command_buffer.c",
+        "stream_command_buffer.h",
     ],
     hdrs = [
         "api.h",
@@ -59,9 +61,11 @@
         "//iree/base:core_headers",
         "//iree/base:tracing",
         "//iree/base/internal",
+        "//iree/base/internal:arena",
         "//iree/base/internal:flatcc",
         "//iree/base/internal:synchronization",
         "//iree/hal",
+        "//iree/hal/utils:deferred_command_buffer",
         "//iree/schemas:cuda_executable_def_c_fbs",
     ],
 )
diff --git a/iree/hal/cuda/CMakeLists.txt b/iree/hal/cuda/CMakeLists.txt
index 997a146..829c381 100644
--- a/iree/hal/cuda/CMakeLists.txt
+++ b/iree/hal/cuda/CMakeLists.txt
@@ -45,15 +45,19 @@
     "nop_executable_cache.h"
     "status_util.c"
     "status_util.h"
+    "stream_command_buffer.c"
+    "stream_command_buffer.h"
   DEPS
     ::dynamic_symbols
     iree::base
     iree::base::core_headers
     iree::base::internal
+    iree::base::internal::arena
     iree::base::internal::flatcc
     iree::base::internal::synchronization
     iree::base::tracing
     iree::hal
+    iree::hal::utils::deferred_command_buffer
     iree::schemas::cuda_executable_def_c_fbs
   PUBLIC
 )
diff --git a/iree/hal/cuda/api.h b/iree/hal/cuda/api.h
index 0a3cd1d..0a7ec62 100644
--- a/iree/hal/cuda/api.h
+++ b/iree/hal/cuda/api.h
@@ -16,6 +16,27 @@
 extern "C" {
 #endif  // __cplusplus
 
+// Parameters configuring an iree_hal_cuda_device_t.
+// Must be initialized with iree_hal_cuda_device_params_initialize prior to use.
+typedef struct iree_hal_cuda_device_params_t {
+  // Number of queues exposed on the device.
+  // Each queue acts as a separate synchronization scope where all work executes
+  // concurrently unless prohibited by semaphores.
+  iree_host_size_t queue_count;
+
+  // Total size of each block in the device shared block pool.
+  // Larger sizes lower overhead and keep transient allocations off the heap,
+  // at the cost of increased memory consumption.
+  iree_host_size_t arena_block_size;
+
+  // Whether to use the deferred command buffer instead of the default graph
+  // command buffer.
+  bool use_deferred_submission;
+} iree_hal_cuda_device_params_t;
+
+// Initializes |out_params| to default values.
+void iree_hal_cuda_device_params_initialize(
+    iree_hal_cuda_device_params_t* out_params);
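+//
+// Usage sketch (all names are declared in this header):
+//   iree_hal_cuda_device_params_t params;
+//   iree_hal_cuda_device_params_initialize(&params);
+//   params.use_deferred_submission = true;
+//   // ...then pass &params as |default_params| to iree_hal_cuda_driver_create.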
+
 //===----------------------------------------------------------------------===//
 // iree_hal_cuda_driver_t
 //===----------------------------------------------------------------------===//
@@ -35,6 +56,7 @@
 // |out_driver| must be released by the caller (see |iree_hal_driver_release|).
 IREE_API_EXPORT iree_status_t iree_hal_cuda_driver_create(
     iree_string_view_t identifier,
+    const iree_hal_cuda_device_params_t* default_params,
     const iree_hal_cuda_driver_options_t* options,
     iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
 
diff --git a/iree/hal/cuda/cuda_buffer.c b/iree/hal/cuda/cuda_buffer.c
index a87eb24..63334a5 100644
--- a/iree/hal/cuda/cuda_buffer.c
+++ b/iree/hal/cuda/cuda_buffer.c
@@ -59,7 +59,7 @@
   }
 
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
 static void iree_hal_cuda_buffer_destroy(iree_hal_buffer_t* base_buffer) {
@@ -95,7 +95,7 @@
   // would only work if the entire buffer was discarded.
 #ifndef NDEBUG
   if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
-    memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
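+    // |data_ptr| already points at |local_byte_offset| within the mapping.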
+    memset(data_ptr, 0xCD, local_byte_length);
   }
 #endif  // !NDEBUG
   *out_data_ptr = data_ptr;
diff --git a/iree/hal/cuda/cuda_device.c b/iree/hal/cuda/cuda_device.c
index d78d21f..0239b96 100644
--- a/iree/hal/cuda/cuda_device.c
+++ b/iree/hal/cuda/cuda_device.c
@@ -10,6 +10,7 @@
 #include <stdint.h>
 #include <string.h>
 
+#include "iree/base/internal/arena.h"
 #include "iree/base/tracing.h"
 #include "iree/hal/cuda/context_wrapper.h"
 #include "iree/hal/cuda/cuda_allocator.h"
@@ -21,6 +22,8 @@
 #include "iree/hal/cuda/graph_command_buffer.h"
 #include "iree/hal/cuda/nop_executable_cache.h"
 #include "iree/hal/cuda/status_util.h"
+#include "iree/hal/cuda/stream_command_buffer.h"
+#include "iree/hal/utils/deferred_command_buffer.h"
 
 //===----------------------------------------------------------------------===//
 // iree_hal_cuda_device_t
@@ -30,6 +33,10 @@
   iree_hal_resource_t resource;
   iree_string_view_t identifier;
 
+  // Block pool used for command buffers with a larger block size (as command
+  // buffers can contain inlined data uploads).
+  iree_arena_block_pool_t block_pool;
+
   // Optional driver that owns the CUDA symbols. We retain it for our lifetime
   // to ensure the symbols remain valid.
   iree_hal_driver_t* driver;
@@ -41,6 +48,8 @@
   iree_hal_cuda_context_wrapper_t context_wrapper;
   iree_hal_allocator_t* device_allocator;
 
+  // Whether to use the deferred command buffer instead of the default graph
+  // command buffer.
+  bool use_deferred_submission;
 } iree_hal_cuda_device_t;
 
 extern const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
@@ -51,6 +60,26 @@
   return (iree_hal_cuda_device_t*)base_value;
 }
 
+void iree_hal_cuda_device_params_initialize(
+    iree_hal_cuda_device_params_t* out_params) {
+  out_params->arena_block_size = 32 * 1024;
+  out_params->queue_count = 8;
+  out_params->use_deferred_submission = false;
+}
+
+static iree_status_t iree_hal_cuda_device_check_params(
+    const iree_hal_cuda_device_params_t* params) {
+  if (params->arena_block_size < 4096) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "arena block size too small (< 4096 bytes)");
+  }
+  if (params->queue_count == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "at least one queue is required");
+  }
+  return iree_ok_status();
+}
+
 static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
   iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device);
@@ -61,6 +90,7 @@
   CUDA_IGNORE_ERROR(device->context_wrapper.syms,
                     cuStreamDestroy(device->stream));
 
+  iree_arena_block_pool_deinitialize(&device->block_pool);
   // Finally, destroy the device.
   iree_hal_driver_release(device->driver);
 
@@ -71,25 +101,28 @@
 
 static iree_status_t iree_hal_cuda_device_create_internal(
     iree_hal_driver_t* driver, iree_string_view_t identifier,
-    CUdevice cu_device, CUstream stream, CUcontext context,
-    iree_hal_cuda_dynamic_symbols_t* syms, iree_allocator_t host_allocator,
-    iree_hal_device_t** out_device) {
+    const iree_hal_cuda_device_params_t* params, CUdevice cu_device,
+    CUstream stream, CUcontext context, iree_hal_cuda_dynamic_symbols_t* syms,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
   iree_hal_cuda_device_t* device = NULL;
-  iree_host_size_t total_size = sizeof(*device) + identifier.size;
+  iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
   IREE_RETURN_IF_ERROR(
       iree_allocator_malloc(host_allocator, total_size, (void**)&device));
   memset(device, 0, total_size);
   iree_hal_resource_initialize(&iree_hal_cuda_device_vtable, &device->resource);
   device->driver = driver;
   iree_hal_driver_retain(device->driver);
-  uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device);
-  buffer_ptr += iree_string_view_append_to_buffer(
-      identifier, &device->identifier, (char*)buffer_ptr);
+  iree_string_view_append_to_buffer(
+      identifier, &device->identifier,
+      (char*)device + iree_sizeof_struct(*device));
   device->device = cu_device;
   device->stream = stream;
   device->context_wrapper.cu_context = context;
   device->context_wrapper.host_allocator = host_allocator;
+  iree_arena_block_pool_initialize(params->arena_block_size, host_allocator,
+                                   &device->block_pool);
   device->context_wrapper.syms = syms;
+  device->use_deferred_submission = params->use_deferred_submission;
   iree_status_t status = iree_hal_cuda_allocator_create(
       &device->context_wrapper, &device->device_allocator);
   if (iree_status_is_ok(status)) {
@@ -100,13 +133,15 @@
   return status;
 }
 
-iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver,
-                                          iree_string_view_t identifier,
-                                          iree_hal_cuda_dynamic_symbols_t* syms,
-                                          CUdevice device,
-                                          iree_allocator_t host_allocator,
-                                          iree_hal_device_t** out_device) {
+iree_status_t iree_hal_cuda_device_create(
+    iree_hal_driver_t* driver, iree_string_view_t identifier,
+    const iree_hal_cuda_device_params_t* params,
+    iree_hal_cuda_dynamic_symbols_t* syms, CUdevice device,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+  IREE_ASSERT_ARGUMENT(params);
   IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+                                    iree_hal_cuda_device_check_params(params));
   CUcontext context;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, CU_RESULT_TO_STATUS(syms, cuCtxCreate(&context, 0, device)));
@@ -115,8 +150,8 @@
       syms, cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
 
   if (iree_status_is_ok(status)) {
-    status = iree_hal_cuda_device_create_internal(driver, identifier, device,
-                                                  stream, context, syms,
+    status = iree_hal_cuda_device_create_internal(driver, identifier, params,
+                                                  device, stream, context, syms,
                                                   host_allocator, out_device);
   }
   if (!iree_status_is_ok(status)) {
@@ -174,7 +209,12 @@
     iree_hal_queue_affinity_t queue_affinity,
     iree_hal_command_buffer_t** out_command_buffer) {
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
-  return iree_hal_cuda_graph_command_buffer_allocate(
+  if (device->use_deferred_submission) {
+    return iree_hal_deferred_command_buffer_create(
+        mode, command_categories, &device->block_pool,
+        iree_hal_device_host_allocator(base_device), out_command_buffer);
+  }
+  return iree_hal_cuda_graph_command_buffer_create(
       &device->context_wrapper, mode, command_categories, queue_affinity,
       out_command_buffer);
 }
@@ -240,13 +280,30 @@
     iree_hal_queue_affinity_t queue_affinity, iree_host_size_t batch_count,
     const iree_hal_submission_batch_t* batches) {
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
-  for (int i = 0; i < batch_count; i++) {
-    for (int j = 0; j < batches[i].command_buffer_count; j++) {
-      CUgraphExec exec = iree_hal_cuda_graph_command_buffer_exec(
-          batches[i].command_buffers[j]);
-      CUDA_RETURN_IF_ERROR(device->context_wrapper.syms,
-                           cuGraphLaunch(exec, device->stream),
-                           "cuGraphLaunch");
+  if (device->use_deferred_submission) {
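+    // Record-and-replay: apply each deferred command buffer onto a transient
+    // stream command buffer that issues directly against the device stream.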
+    iree_hal_command_buffer_t* stream_command_buffer = NULL;
+    iree_status_t status = iree_hal_cuda_stream_command_buffer_create(
+        &device->context_wrapper,
+        IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, command_categories,
+        device->stream, &stream_command_buffer);
+    for (int i = 0; i < batch_count && iree_status_is_ok(status); i++) {
+      for (int j = 0;
+           j < batches[i].command_buffer_count && iree_status_is_ok(status);
+           j++) {
+        status = iree_hal_deferred_command_buffer_apply(
+            batches[i].command_buffers[j], stream_command_buffer);
+      }
+    }
+    iree_hal_command_buffer_release(stream_command_buffer);
+    IREE_RETURN_IF_ERROR(status);
+  } else {
+    for (int i = 0; i < batch_count; i++) {
+      for (int j = 0; j < batches[i].command_buffer_count; j++) {
+        CUgraphExec exec = iree_hal_cuda_graph_command_buffer_exec(
+            batches[i].command_buffers[j]);
+        CUDA_RETURN_IF_ERROR(device->context_wrapper.syms,
+                             cuGraphLaunch(exec, device->stream),
+                             "cuGraphLaunch");
+      }
     }
   }
   // TODO(thomasraoux): Conservatively synchronize after every submit until we
diff --git a/iree/hal/cuda/cuda_device.h b/iree/hal/cuda/cuda_device.h
index b6b47af..d7b5790 100644
--- a/iree/hal/cuda/cuda_device.h
+++ b/iree/hal/cuda/cuda_device.h
@@ -17,12 +17,11 @@
 #endif  // __cplusplus
 
 // Creates a device that owns and manages its own CUcontext.
-iree_status_t iree_hal_cuda_device_create(iree_hal_driver_t* driver,
-                                          iree_string_view_t identifier,
-                                          iree_hal_cuda_dynamic_symbols_t* syms,
-                                          CUdevice device,
-                                          iree_allocator_t host_allocator,
-                                          iree_hal_device_t** out_device);
+iree_status_t iree_hal_cuda_device_create(
+    iree_hal_driver_t* driver, iree_string_view_t identifier,
+    const iree_hal_cuda_device_params_t* params,
+    iree_hal_cuda_dynamic_symbols_t* syms, CUdevice device,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/iree/hal/cuda/cuda_driver.c b/iree/hal/cuda/cuda_driver.c
index 6e5617d..8903775 100644
--- a/iree/hal/cuda/cuda_driver.c
+++ b/iree/hal/cuda/cuda_driver.c
@@ -22,6 +22,7 @@
   // We allow overriding so that multiple CUDA versions can be exposed in the
   // same process.
   iree_string_view_t identifier;
+  iree_hal_cuda_device_params_t default_params;
   int default_device_index;
   // CUDA symbols.
   iree_hal_cuda_dynamic_symbols_t syms;
@@ -46,18 +47,23 @@
 
 static iree_status_t iree_hal_cuda_driver_create_internal(
     iree_string_view_t identifier,
+    const iree_hal_cuda_device_params_t* default_params,
     const iree_hal_cuda_driver_options_t* options,
     iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
   iree_hal_cuda_driver_t* driver = NULL;
-  iree_host_size_t total_size = sizeof(*driver) + identifier.size;
+  iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size;
   IREE_RETURN_IF_ERROR(
       iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
+
   iree_hal_resource_initialize(&iree_hal_cuda_driver_vtable, &driver->resource);
   driver->host_allocator = host_allocator;
   iree_string_view_append_to_buffer(
       identifier, &driver->identifier,
-      (char*)driver + total_size - identifier.size);
+      (char*)driver + iree_sizeof_struct(*driver));
+  memcpy(&driver->default_params, default_params,
+         sizeof(driver->default_params));
   driver->default_device_index = options->default_device_index;
+
   iree_status_t status =
       iree_hal_cuda_dynamic_symbols_initialize(host_allocator, &driver->syms);
   if (iree_status_is_ok(status)) {
@@ -81,14 +87,16 @@
 
 IREE_API_EXPORT iree_status_t iree_hal_cuda_driver_create(
     iree_string_view_t identifier,
+    const iree_hal_cuda_device_params_t* default_params,
     const iree_hal_cuda_driver_options_t* options,
     iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(default_params);
   IREE_ASSERT_ARGUMENT(options);
   IREE_ASSERT_ARGUMENT(out_driver);
   IREE_TRACE_ZONE_BEGIN(z0);
 
   iree_status_t status = iree_hal_cuda_driver_create_internal(
-      identifier, options, host_allocator, out_driver);
+      identifier, default_params, options, host_allocator, out_driver);
 
   IREE_TRACE_ZONE_END(z0);
   return status;
@@ -193,9 +201,9 @@
   iree_string_view_t device_name = iree_make_cstring_view("cuda");
 
   // Attempt to create the device.
-  iree_status_t status =
-      iree_hal_cuda_device_create(base_driver, device_name, &driver->syms,
-                                  device, host_allocator, out_device);
+  iree_status_t status = iree_hal_cuda_device_create(
+      base_driver, device_name, &driver->default_params, &driver->syms, device,
+      host_allocator, out_device);
 
   IREE_TRACE_ZONE_END(z0);
   return status;
diff --git a/iree/hal/cuda/dynamic_symbol_tables.h b/iree/hal/cuda/dynamic_symbol_tables.h
index f131e14..6222219 100644
--- a/iree/hal/cuda/dynamic_symbol_tables.h
+++ b/iree/hal/cuda/dynamic_symbol_tables.h
@@ -39,3 +39,8 @@
 CU_PFN_DECL(cuStreamDestroy, CUstream)
 CU_PFN_DECL(cuStreamSynchronize, CUstream)
 CU_PFN_DECL(cuStreamWaitEvent, CUstream, CUevent, unsigned int)
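+// Asynchronous operations used by the CUDA stream command buffer.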
+CU_PFN_DECL(cuMemsetD32Async, unsigned long long, int, size_t, CUstream)
+CU_PFN_DECL(cuMemcpyAsync, CUdeviceptr, CUdeviceptr, size_t, CUstream)
+CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int,
+            unsigned int, unsigned int, unsigned int, unsigned int,
+            unsigned int, CUstream, void **, void **)
diff --git a/iree/hal/cuda/graph_command_buffer.c b/iree/hal/cuda/graph_command_buffer.c
index 90126e8..f5667fa 100644
--- a/iree/hal/cuda/graph_command_buffer.c
+++ b/iree/hal/cuda/graph_command_buffer.c
@@ -47,7 +47,7 @@
   return (iree_hal_cuda_graph_command_buffer_t*)base_value;
 }
 
-iree_status_t iree_hal_cuda_graph_command_buffer_allocate(
+iree_status_t iree_hal_cuda_graph_command_buffer_create(
     iree_hal_cuda_context_wrapper_t* context,
     iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
@@ -136,7 +136,7 @@
 
 static iree_status_t iree_hal_cuda_graph_command_buffer_begin(
     iree_hal_command_buffer_t* base_command_buffer) {
-  // nothing to do.
+  // Nothing to do.
   return iree_ok_status();
 }
 
diff --git a/iree/hal/cuda/graph_command_buffer.h b/iree/hal/cuda/graph_command_buffer.h
index b38b251..c50ccf9 100644
--- a/iree/hal/cuda/graph_command_buffer.h
+++ b/iree/hal/cuda/graph_command_buffer.h
@@ -18,7 +18,7 @@
 #endif  // __cplusplus
 
 // Creates a cuda graph.
-iree_status_t iree_hal_cuda_graph_command_buffer_allocate(
+iree_status_t iree_hal_cuda_graph_command_buffer_create(
     iree_hal_cuda_context_wrapper_t* context,
     iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
diff --git a/iree/hal/cuda/registration/driver_module.c b/iree/hal/cuda/registration/driver_module.c
index f8fb736..0a51ed9 100644
--- a/iree/hal/cuda/registration/driver_module.c
+++ b/iree/hal/cuda/registration/driver_module.c
@@ -41,6 +41,9 @@
                             driver_id);
   }
   IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_cuda_device_params_t default_params;
+  iree_hal_cuda_device_params_initialize(&default_params);
+  // TODO(jinchen62): expose default_params.use_deferred_submission via a flag.
   // When we expose more than one driver (different cuda versions, etc) we
   // can name them here:
   iree_string_view_t identifier = iree_make_cstring_view("cuda");
@@ -48,7 +51,7 @@
   iree_hal_cuda_driver_options_t driver_options;
   iree_hal_cuda_driver_options_initialize(&driver_options);
   iree_status_t status = iree_hal_cuda_driver_create(
-      identifier, &driver_options, allocator, out_driver);
+      identifier, &default_params, &driver_options, allocator, out_driver);
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
diff --git a/iree/hal/cuda/stream_command_buffer.c b/iree/hal/cuda/stream_command_buffer.c
new file mode 100644
index 0000000..9ab153d
--- /dev/null
+++ b/iree/hal/cuda/stream_command_buffer.c
@@ -0,0 +1,353 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/cuda/stream_command_buffer.h"
+
+#include <assert.h>
+#include <stdlib.h>
+
+#include "iree/base/tracing.h"
+#include "iree/hal/cuda/cuda_buffer.h"
+#include "iree/hal/cuda/cuda_event.h"
+#include "iree/hal/cuda/executable_layout.h"
+#include "iree/hal/cuda/native_executable.h"
+#include "iree/hal/cuda/status_util.h"
+
+#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
+
+// Records commands directly against the CUDA stream on the calling thread,
+// without additional threading indirection.
+
+typedef struct {
+  iree_hal_resource_t resource;
+  iree_hal_cuda_context_wrapper_t* context;
+  iree_hal_command_buffer_mode_t mode;
+  iree_hal_command_category_t allowed_categories;
+  CUstream stream;
+  // Keep track of the current set of kernel arguments.
+  void* current_descriptor[IREE_HAL_CUDA_MAX_BINDING_COUNT];
+  CUdeviceptr* device_ptrs[IREE_HAL_CUDA_MAX_BINDING_COUNT];
+} iree_hal_cuda_stream_command_buffer_t;
+
+extern const iree_hal_command_buffer_vtable_t
+    iree_hal_cuda_stream_command_buffer_vtable;
+
+static iree_hal_cuda_stream_command_buffer_t*
+iree_hal_cuda_stream_command_buffer_cast(
+    iree_hal_command_buffer_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_stream_command_buffer_vtable);
+  return (iree_hal_cuda_stream_command_buffer_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda_stream_command_buffer_create(
+    iree_hal_cuda_context_wrapper_t* context,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories, CUstream stream,
+    iree_hal_command_buffer_t** out_command_buffer) {
+  IREE_ASSERT_ARGUMENT(context);
+  IREE_ASSERT_ARGUMENT(out_command_buffer);
+  *out_command_buffer = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_cuda_stream_command_buffer_t* command_buffer = NULL;
+  // The descriptor and device pointer arrays are inline members of the struct
+  // so sizeof(*command_buffer) already covers them.
+  size_t total_size = sizeof(*command_buffer);
+  iree_status_t status = iree_allocator_malloc(
+      context->host_allocator, total_size, (void**)&command_buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_cuda_stream_command_buffer_vtable,
+                                 &command_buffer->resource);
+    command_buffer->context = context;
+    command_buffer->mode = mode;
+    command_buffer->allowed_categories = command_categories;
+    command_buffer->stream = stream;
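+    // Each kernel argument points at its inline storage slot; dispatch passes
+    // |current_descriptor| directly to cuLaunchKernel as the kernelParams.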
+    for (size_t i = 0; i < IREE_HAL_CUDA_MAX_BINDING_COUNT; i++) {
+      command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i];
+    }
+  }
+
+  *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer;
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_cuda_stream_command_buffer_destroy(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_allocator_free(command_buffer->context->host_allocator, command_buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_hal_command_buffer_mode_t iree_hal_cuda_stream_command_buffer_mode(
+    const iree_hal_command_buffer_t* base_command_buffer) {
+  const iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      (const iree_hal_cuda_stream_command_buffer_t*)(base_command_buffer);
+  return command_buffer->mode;
+}
+
+static iree_hal_command_category_t
+iree_hal_cuda_stream_command_buffer_allowed_categories(
+    const iree_hal_command_buffer_t* base_command_buffer) {
+  const iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      (const iree_hal_cuda_stream_command_buffer_t*)(base_command_buffer);
+  return command_buffer->allowed_categories;
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_begin(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_end(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_execution_barrier(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_hal_execution_barrier_flags_t flags,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  // TODO(jinchen62): implement CUDA barrier
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_signal_event(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask) {
+  // TODO(jinchen62): implement CUDA barrier
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_reset_event(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask) {
+  // TODO(jinchen62): implement CUDA barrier
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_wait_events(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_host_size_t event_count, const iree_hal_event_t** events,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  // TODO(jinchen62): implement CUDA barrier
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_discard_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
+  // Nothing to do.
+  return iree_ok_status();
+}
+
+// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value.
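+// For example, a 1-byte pattern 0xAB splats to 0xABABABAB and a 2-byte
+// pattern 0xABCD splats to 0xABCDABCD.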
+static uint32_t iree_hal_cuda_splat_pattern(const void* pattern,
+                                            size_t pattern_length) {
+  switch (pattern_length) {
+    case 1: {
+      uint32_t pattern_value = *(const uint8_t*)(pattern);
+      return (pattern_value << 24) | (pattern_value << 16) |
+             (pattern_value << 8) | pattern_value;
+    }
+    case 2: {
+      uint32_t pattern_value = *(const uint16_t*)(pattern);
+      return (pattern_value << 16) | pattern_value;
+    }
+    case 4: {
+      uint32_t pattern_value = *(const uint32_t*)(pattern);
+      return pattern_value;
+    }
+    default:
+      return 0;  // Already verified that this should not be possible.
+  }
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_fill_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, const void* pattern,
+    iree_host_size_t pattern_length) {
+  iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+
+  CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
+      iree_hal_buffer_allocated_buffer(target_buffer));
+  target_offset += iree_hal_buffer_byte_offset(target_buffer);
+  uint32_t dword_pattern = iree_hal_cuda_splat_pattern(pattern, pattern_length);
+  CUdeviceptr dst = target_device_buffer + target_offset;
+  // NOTE: cuMemsetD32Async takes a count of 32-bit elements rather than a
+  // byte length; this assumes |length| is a multiple of 4 bytes.
+  size_t num_elements = length / sizeof(uint32_t);
+  CUDA_RETURN_IF_ERROR(
+      command_buffer->context->syms,
+      cuMemsetD32Async(dst, dword_pattern, num_elements,
+                       command_buffer->stream),
+      "cuMemsetD32Async");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_update_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
+    iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
+    iree_device_size_t target_offset, iree_device_size_t length) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "need cuda implementation of update buffer");
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_copy_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length) {
+  iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+
+  CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
+      iree_hal_buffer_allocated_buffer(target_buffer));
+  target_offset += iree_hal_buffer_byte_offset(target_buffer);
+  CUdeviceptr source_device_buffer = iree_hal_cuda_buffer_device_pointer(
+      iree_hal_buffer_allocated_buffer(source_buffer));
+  source_offset += iree_hal_buffer_byte_offset(source_buffer);
+  CUDA_RETURN_IF_ERROR(command_buffer->context->syms,
+                       cuMemcpyAsync(target_device_buffer, source_device_buffer,
+                                     length, command_buffer->stream),
+                       "cuMemcpyAsync");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_push_constants(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
+    const void* values, iree_host_size_t values_length) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "need cuda implementation of push constants");
+}
+
+// Ties together a binding index and its position in the |bindings| array.
+typedef struct {
+  uint32_t index;
+  uint32_t binding;
+} iree_hal_cuda_binding_mapping_t;
+
+// Helper to sort bindings by their binding index.
+static int compare_binding_index(const void* a, const void* b) {
+  const iree_hal_cuda_binding_mapping_t buffer_a =
+      *(const iree_hal_cuda_binding_mapping_t*)a;
+  const iree_hal_cuda_binding_mapping_t buffer_b =
+      *(const iree_hal_cuda_binding_mapping_t*)b;
+  return buffer_a.binding < buffer_b.binding ? -1 : 1;
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_push_descriptor_set(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_binding_t* bindings) {
+  iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+  iree_host_size_t base_binding =
+      iree_hal_cuda_base_binding_index(executable_layout, set);
+  // Convention with the compiler side: bindings map to kernel arguments.
+  // The bindings are compacted into a dense argument list ordered by binding
+  // index: sort them by binding index and then map each array index to its
+  // argument index.
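+  // For example, binding indices {5, 2, 7} occupy kernel argument slots
+  // base_binding + {1, 0, 2} respectively.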
+  assert(binding_count <= IREE_HAL_CUDA_MAX_BINDING_COUNT &&
+         "binding count larger than the max expected");
+  iree_hal_cuda_binding_mapping_t binding_used[IREE_HAL_CUDA_MAX_BINDING_COUNT];
+  for (iree_host_size_t i = 0; i < binding_count; i++) {
+    iree_hal_cuda_binding_mapping_t buffer = {i, bindings[i].binding};
+    binding_used[i] = buffer;
+  }
+  qsort(binding_used, binding_count, sizeof(iree_hal_cuda_binding_mapping_t),
+        compare_binding_index);
+  for (iree_host_size_t i = 0; i < binding_count; i++) {
+    iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
+    CUdeviceptr device_ptr =
+        iree_hal_cuda_buffer_device_pointer(
+            iree_hal_buffer_allocated_buffer(binding.buffer)) +
+        iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
+    *((CUdeviceptr*)command_buffer->current_descriptor[i + base_binding]) =
+        device_ptr;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_bind_descriptor_set(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_hal_descriptor_set_t* descriptor_set,
+    iree_host_size_t dynamic_offset_count,
+    const iree_device_size_t* dynamic_offsets) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "need cuda implementation of bind descriptor set");
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+  iree_hal_cuda_stream_command_buffer_t* command_buffer =
+      iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+
+  int32_t block_size_x, block_size_y, block_size_z;
+  IREE_RETURN_IF_ERROR(iree_hal_cuda_native_executable_block_size(
+      executable, entry_point, &block_size_x, &block_size_y, &block_size_z));
+  CUfunction func =
+      iree_hal_cuda_native_executable_for_entry_point(executable, entry_point);
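+  // |current_descriptor| was populated by push_descriptor_set and is passed
+  // directly to the kernel as its parameter array.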
+  CUDA_RETURN_IF_ERROR(
+      command_buffer->context->syms,
+      cuLaunchKernel(func, workgroup_x, workgroup_y, workgroup_z, block_size_x,
+                     block_size_y, block_size_z, 0, command_buffer->stream,
+                     command_buffer->current_descriptor, NULL),
+      "cuLaunchKernel");
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_t* workgroups_buffer,
+    iree_device_size_t workgroups_offset) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "need cuda implementation of dispatch indirect");
+}
+
+const iree_hal_command_buffer_vtable_t
+    iree_hal_cuda_stream_command_buffer_vtable = {
+        .destroy = iree_hal_cuda_stream_command_buffer_destroy,
+        .mode = iree_hal_cuda_stream_command_buffer_mode,
+        .allowed_categories =
+            iree_hal_cuda_stream_command_buffer_allowed_categories,
+        .begin = iree_hal_cuda_stream_command_buffer_begin,
+        .end = iree_hal_cuda_stream_command_buffer_end,
+        .execution_barrier =
+            iree_hal_cuda_stream_command_buffer_execution_barrier,
+        .signal_event = iree_hal_cuda_stream_command_buffer_signal_event,
+        .reset_event = iree_hal_cuda_stream_command_buffer_reset_event,
+        .wait_events = iree_hal_cuda_stream_command_buffer_wait_events,
+        .discard_buffer = iree_hal_cuda_stream_command_buffer_discard_buffer,
+        .fill_buffer = iree_hal_cuda_stream_command_buffer_fill_buffer,
+        .update_buffer = iree_hal_cuda_stream_command_buffer_update_buffer,
+        .copy_buffer = iree_hal_cuda_stream_command_buffer_copy_buffer,
+        .push_constants = iree_hal_cuda_stream_command_buffer_push_constants,
+        .push_descriptor_set =
+            iree_hal_cuda_stream_command_buffer_push_descriptor_set,
+        .bind_descriptor_set =
+            iree_hal_cuda_stream_command_buffer_bind_descriptor_set,
+        .dispatch = iree_hal_cuda_stream_command_buffer_dispatch,
+        .dispatch_indirect =
+            iree_hal_cuda_stream_command_buffer_dispatch_indirect,
+};
diff --git a/iree/hal/cuda/stream_command_buffer.h b/iree/hal/cuda/stream_command_buffer.h
new file mode 100644
index 0000000..b4b901a
--- /dev/null
+++ b/iree/hal/cuda/stream_command_buffer.h
@@ -0,0 +1,35 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_CUDA_STREAM_COMMAND_BUFFER_H_
+#define IREE_HAL_CUDA_STREAM_COMMAND_BUFFER_H_
+
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/context_wrapper.h"
+#include "iree/hal/cuda/cuda_headers.h"
+#include "iree/hal/cuda/dynamic_symbols.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a CUDA stream command buffer that immediately issues commands
+// against the given |stream|. Access to |stream| must be synchronized by the
+// caller.
+//
+// Used for replaying commands in special situations (such as deferred
+// submission) and never returned to users from device_create_command_buffer.
+iree_status_t iree_hal_cuda_stream_command_buffer_create(
+    iree_hal_cuda_context_wrapper_t *context,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories, CUstream stream,
+    iree_hal_command_buffer_t **out_command_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_CUDA_STREAM_COMMAND_BUFFER_H_
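A minimal usage sketch for this header, assuming a `context` and `stream` that the CUDA HAL device set up elsewhere (the mode and category constants are the generic HAL ones):

```c
// Sketch only: wraps an existing CUstream in a stream command buffer.
iree_status_t record_on_stream(iree_hal_cuda_context_wrapper_t* context,
                               CUstream stream) {
  iree_hal_command_buffer_t* command_buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_cuda_stream_command_buffer_create(
      context, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
      IREE_HAL_COMMAND_CATEGORY_ANY, stream, &command_buffer));
  // Commands recorded here execute immediately against |stream|.
  iree_hal_command_buffer_release(command_buffer);
  return iree_ok_status();
}
```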
diff --git a/iree/hal/driver_registry.c b/iree/hal/driver_registry.c
index 972bff9..8594023 100644
--- a/iree/hal/driver_registry.c
+++ b/iree/hal/driver_registry.c
@@ -188,8 +188,8 @@
 
   // Allocate the required memory for both the driver infos and the string
   // storage in a single block.
-  iree_host_size_t total_driver_infos_size = iree_host_align(
-      total_driver_info_count * sizeof(iree_hal_driver_info_t), 8);
+  iree_host_size_t total_driver_infos_size =
+      total_driver_info_count * sizeof(iree_hal_driver_info_t);
   if (iree_status_is_ok(status)) {
     status = iree_allocator_malloc(allocator,
                                    total_driver_infos_size + total_storage_size,
diff --git a/iree/hal/local/elf/arch/x86_32.c b/iree/hal/local/elf/arch/x86_32.c
index 6a7c103..05d08d7 100644
--- a/iree/hal/local/elf/arch/x86_32.c
+++ b/iree/hal/local/elf/arch/x86_32.c
@@ -120,10 +120,10 @@
 // Cross-ABI function calls
 //==============================================================================
 
-// System V 386 ABI (used in IREE):
+// System V i386 ABI (used in IREE):
 // https://uclibc.org/docs/psABI-i386.pdf
 // Arguments:
-//
+//   (reverse order on the stack; last arg furthest from stack pointer)
 //
 // Results:
 //   EAX
@@ -136,7 +136,7 @@
 
 #if defined(IREE_PLATFORM_WINDOWS)
 
-#error "TODO"
+#error "TODO(#6554): need cdecl -> sysv ABI shims in x86_32_msvc.asm"
 
 #else
 
diff --git a/iree/hal/local/loaders/embedded_library_loader.c b/iree/hal/local/loaders/embedded_library_loader.c
index ddaafe2..74d48fe 100644
--- a/iree/hal/local/loaders/embedded_library_loader.c
+++ b/iree/hal/local/loaders/embedded_library_loader.c
@@ -35,6 +35,8 @@
     const iree_hal_executable_library_header_t** header;
     const iree_hal_executable_library_v0_t* v0;
   } library;
+
+  iree_hal_local_executable_layout_t* layouts[];
 } iree_hal_elf_executable_t;
 
 extern const iree_hal_local_executable_vtable_t iree_hal_elf_executable_vtable;
@@ -140,16 +142,13 @@
   iree_hal_elf_executable_t* executable = NULL;
   iree_host_size_t total_size =
       sizeof(*executable) +
-      executable_layout_count * sizeof(iree_hal_local_executable_layout_t);
+      executable_layout_count * sizeof(*executable->layouts);
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&executable);
   if (iree_status_is_ok(status)) {
-    iree_hal_local_executable_layout_t** executable_layouts_ptr =
-        (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) +
-                                               sizeof(*executable));
     iree_hal_local_executable_initialize(
         &iree_hal_elf_executable_vtable, executable_layout_count,
-        executable_layouts, executable_layouts_ptr, host_allocator,
+        executable_layouts, &executable->layouts[0], host_allocator,
         &executable->base);
   }
   if (iree_status_is_ok(status)) {
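This loader change (and the identical ones in legacy_library_loader.c and static_library_loader.c below) swaps hand-rolled pointer arithmetic for a C99 flexible array member. A self-contained sketch of the pattern with a hypothetical `thing_t` type:

```c
#include <stdlib.h>

// Hypothetical type, not IREE's: the trailing array lives inside the same
// allocation, so sizeof(*t) + count * sizeof(t->items[0]) sizes it exactly
// and &t->items[0] replaces the old (uint8_t*)t + sizeof(*t) arithmetic.
typedef struct thing_t {
  size_t count;
  void* items[];  // C99 flexible array member: storage follows the struct
} thing_t;

static thing_t* thing_create(size_t count) {
  thing_t* t = (thing_t*)calloc(1, sizeof(*t) + count * sizeof(t->items[0]));
  if (t) t->count = count;
  return t;
}
```

Beyond brevity, the compiler now knows `items` aliases the trailing storage, so the casts disappear and the arithmetic can no longer drift out of sync with the struct layout.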
diff --git a/iree/hal/local/loaders/legacy_library_loader.c b/iree/hal/local/loaders/legacy_library_loader.c
index f370b94..fa9fa1f 100644
--- a/iree/hal/local/loaders/legacy_library_loader.c
+++ b/iree/hal/local/loaders/legacy_library_loader.c
@@ -81,6 +81,8 @@
     const iree_hal_executable_library_header_t** header;
     const iree_hal_executable_library_v0_t* v0;
   } library;
+
+  iree_hal_local_executable_layout_t* layouts[];
 } iree_hal_legacy_executable_t;
 
 extern const iree_hal_local_executable_vtable_t
@@ -224,16 +226,13 @@
   iree_hal_legacy_executable_t* executable = NULL;
   iree_host_size_t total_size =
       sizeof(*executable) +
-      executable_layout_count * sizeof(iree_hal_local_executable_layout_t);
+      executable_layout_count * sizeof(*executable->layouts);
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&executable);
   if (iree_status_is_ok(status)) {
-    iree_hal_local_executable_layout_t** executable_layouts_ptr =
-        (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) +
-                                               sizeof(*executable));
     iree_hal_local_executable_initialize(
         &iree_hal_legacy_executable_vtable, executable_layout_count,
-        executable_layouts, executable_layouts_ptr, host_allocator,
+        executable_layouts, &executable->layouts[0], host_allocator,
         &executable->base);
     executable->def = executable_def;
   }
diff --git a/iree/hal/local/loaders/static_library_loader.c b/iree/hal/local/loaders/static_library_loader.c
index a7ff63b..1e7819c 100644
--- a/iree/hal/local/loaders/static_library_loader.c
+++ b/iree/hal/local/loaders/static_library_loader.c
@@ -30,6 +30,8 @@
     const iree_hal_executable_library_header_t** header;
     const iree_hal_executable_library_v0_t* v0;
   } library;
+
+  iree_hal_local_executable_layout_t* layouts[];
 } iree_hal_static_executable_t;
 
 static const iree_hal_local_executable_vtable_t
@@ -50,16 +52,13 @@
   iree_hal_static_executable_t* executable = NULL;
   iree_host_size_t total_size =
       sizeof(*executable) +
-      executable_layout_count * sizeof(iree_hal_local_executable_layout_t);
+      executable_layout_count * sizeof(*executable->layouts);
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&executable);
   if (iree_status_is_ok(status)) {
-    iree_hal_local_executable_layout_t** executable_layouts_ptr =
-        (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) +
-                                               sizeof(*executable));
     iree_hal_local_executable_initialize(
         &iree_hal_static_executable_vtable, executable_layout_count,
-        executable_layouts, executable_layouts_ptr, host_allocator,
+        executable_layouts, &executable->layouts[0], host_allocator,
         &executable->base);
     executable->library.header = library_header;
     executable->identifier = iree_make_cstring_view((*library_header)->name);
diff --git a/iree/hal/local/sync_device.c b/iree/hal/local/sync_device.c
index b6083ab..924da83 100644
--- a/iree/hal/local/sync_device.c
+++ b/iree/hal/local/sync_device.c
@@ -23,13 +23,13 @@
   iree_hal_resource_t resource;
   iree_string_view_t identifier;
 
-  iree_host_size_t loader_count;
-  iree_hal_executable_loader_t** loaders;
-
   iree_allocator_t host_allocator;
   iree_hal_allocator_t* device_allocator;
 
   iree_hal_sync_semaphore_state_t semaphore_state;
+
+  iree_host_size_t loader_count;
+  iree_hal_executable_loader_t* loaders[];
 } iree_hal_sync_device_t;
 
 static const iree_hal_device_vtable_t iree_hal_sync_device_vtable;
@@ -64,8 +64,9 @@
                                     iree_hal_sync_device_check_params(params));
 
   iree_hal_sync_device_t* device = NULL;
-  iree_host_size_t total_size = sizeof(*device) + identifier.size +
-                                loader_count * sizeof(*device->loaders);
+  iree_host_size_t struct_size =
+      sizeof(*device) + loader_count * sizeof(*device->loaders);
+  iree_host_size_t total_size = struct_size + identifier.size;
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&device);
   if (iree_status_is_ok(status)) {
@@ -73,13 +74,10 @@
     iree_hal_resource_initialize(&iree_hal_sync_device_vtable,
                                  &device->resource);
     iree_string_view_append_to_buffer(identifier, &device->identifier,
-                                      (char*)device + sizeof(*device));
+                                      (char*)device + struct_size);
     device->host_allocator = host_allocator;
 
     device->loader_count = loader_count;
-    device->loaders =
-        (iree_hal_executable_loader_t**)((uint8_t*)device->identifier.data +
-                                         identifier.size);
     for (iree_host_size_t i = 0; i < device->loader_count; ++i) {
       device->loaders[i] = loaders[i];
       iree_hal_executable_loader_retain(device->loaders[i]);
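One subtlety this sync_device hunk highlights: once a struct ends in a flexible array member, `sizeof(*device)` no longer covers the trailing elements, so any additional trailing data (the identifier string here) has to start at `struct_size`, not `sizeof(*device)`. A sketch with hypothetical names:

```c
#include <stdlib.h>
#include <string.h>

// Hypothetical device_t: a flexible array member plus a name string packed
// into one allocation. The string must start at struct_size because the
// FAM elements occupy the bytes between sizeof(*d) and struct_size.
typedef struct device_t {
  size_t loader_count;
  void* loaders[];  // loader_count pointers follow the struct
} device_t;

static device_t* device_create(size_t loader_count, const char* name) {
  size_t struct_size = sizeof(device_t) + loader_count * sizeof(void*);
  size_t total_size = struct_size + strlen(name) + 1;
  device_t* d = (device_t*)calloc(1, total_size);
  if (!d) return NULL;
  d->loader_count = loader_count;
  memcpy((char*)d + struct_size, name, strlen(name) + 1);  // after the FAM
  return d;
}
```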
diff --git a/iree/hal/local/task_device.c b/iree/hal/local/task_device.c
index 158b9e4..236c83e 100644
--- a/iree/hal/local/task_device.c
+++ b/iree/hal/local/task_device.c
@@ -92,19 +92,18 @@
                                     iree_hal_task_device_check_params(params));
 
   iree_hal_task_device_t* device = NULL;
-  iree_host_size_t total_size =
-      sizeof(*device) + params->queue_count * sizeof(*device->queues) +
-      identifier.size + loader_count * sizeof(*device->loaders);
+  iree_host_size_t struct_size = sizeof(*device) +
+                                 params->queue_count * sizeof(*device->queues) +
+                                 loader_count * sizeof(*device->loaders);
+  iree_host_size_t total_size = struct_size + identifier.size;
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&device);
   if (iree_status_is_ok(status)) {
     memset(device, 0, total_size);
     iree_hal_resource_initialize(&iree_hal_task_device_vtable,
                                  &device->resource);
-    iree_string_view_append_to_buffer(
-        identifier, &device->identifier,
-        (char*)device + sizeof(*device) +
-            params->queue_count * sizeof(*device->queues));
+    iree_string_view_append_to_buffer(identifier, &device->identifier,
+                                      (char*)device + struct_size);
     device->host_allocator = host_allocator;
     iree_arena_block_pool_initialize(4096, host_allocator,
                                      &device->small_block_pool);
@@ -117,8 +116,9 @@
 
     device->loader_count = loader_count;
     device->loaders =
-        (iree_hal_executable_loader_t**)((uint8_t*)device->identifier.data +
-                                         identifier.size);
+        (iree_hal_executable_loader_t**)((uint8_t*)device + sizeof(*device) +
+                                         params->queue_count *
+                                             sizeof(*device->queues));
     for (iree_host_size_t i = 0; i < device->loader_count; ++i) {
       device->loaders[i] = loaders[i];
       iree_hal_executable_loader_retain(device->loaders[i]);
diff --git a/iree/hal/local/task_driver.c b/iree/hal/local/task_driver.c
index e3f10ed..7711520 100644
--- a/iree/hal/local/task_driver.c
+++ b/iree/hal/local/task_driver.c
@@ -47,9 +47,9 @@
   IREE_TRACE_ZONE_BEGIN(z0);
 
   iree_hal_task_driver_t* driver = NULL;
-  iree_host_size_t total_size = sizeof(*driver) +
-                                loader_count * sizeof(*driver->loaders) +
-                                identifier.size;
+  iree_host_size_t struct_size =
+      sizeof(*driver) + loader_count * sizeof(*driver->loaders);
+  iree_host_size_t total_size = struct_size + identifier.size;
   iree_status_t status =
       iree_allocator_malloc(host_allocator, total_size, (void**)&driver);
   if (iree_status_is_ok(status)) {
@@ -57,9 +57,8 @@
                                  &driver->resource);
     driver->host_allocator = host_allocator;
 
-    iree_string_view_append_to_buffer(
-        identifier, &driver->identifier,
-        (char*)driver + total_size - identifier.size);
+    iree_string_view_append_to_buffer(identifier, &driver->identifier,
+                                      (char*)driver + struct_size);
     memcpy(&driver->default_params, default_params,
            sizeof(driver->default_params));
 
diff --git a/iree/hal/vulkan/native_descriptor_set_layout.cc b/iree/hal/vulkan/native_descriptor_set_layout.cc
index 1d3c4a3..dff1661 100644
--- a/iree/hal/vulkan/native_descriptor_set_layout.cc
+++ b/iree/hal/vulkan/native_descriptor_set_layout.cc
@@ -56,6 +56,7 @@
 
   VkDescriptorSetLayoutBinding* native_bindings = NULL;
   if (binding_count > 0) {
+    // TODO(benvanik): avoid this allocation if possible (inline_array).
     IREE_RETURN_IF_ERROR(iree_allocator_malloc(
         logical_device->host_allocator(),
         binding_count * sizeof(VkDescriptorSetLayoutBinding),
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc
index 9030775..ffdb8ff 100644
--- a/iree/hal/vulkan/vma_buffer.cc
+++ b/iree/hal/vulkan/vma_buffer.cc
@@ -123,7 +123,7 @@
   // would only work if the entire buffer was discarded.
 #ifndef NDEBUG
   if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
-    memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
+    memset(*out_data_ptr, 0xCD, local_byte_length);
   }
 #endif  // !NDEBUG
 
diff --git a/iree/samples/custom_modules/module.cc b/iree/samples/custom_modules/module.cc
index 0477ef5..1759d82 100644
--- a/iree/samples/custom_modules/module.cc
+++ b/iree/samples/custom_modules/module.cc
@@ -56,7 +56,7 @@
   // Note that we allocate the message and the string value together.
   iree_custom_message_t* message = NULL;
   IREE_RETURN_IF_ERROR(iree_allocator_malloc(
-      allocator, sizeof(iree_custom_message_t) + value.size, (void**)&message));
+      allocator, sizeof(*message) + value.size, (void**)&message));
   message->ref_object.counter = IREE_ATOMIC_VAR_INIT(1);
   message->allocator = allocator;
   message->value.data = ((const char*)message) + sizeof(iree_custom_message_t);
@@ -71,8 +71,8 @@
                                        iree_custom_message_t** out_message) {
   IREE_ASSERT_ARGUMENT(out_message);
   iree_custom_message_t* message = NULL;
-  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
-      allocator, sizeof(iree_custom_message_t), (void**)&message));
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(allocator, sizeof(*message), (void**)&message));
   message->ref_object.counter = IREE_ATOMIC_VAR_INIT(1);
   message->allocator = allocator;
   message->value = value;  // Unowned.
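These hunks also standardize on the `sizeof(*ptr)` allocation idiom. In isolation:

```c
#include <stdlib.h>

typedef struct widget_t { int id; } widget_t;

// sizeof(*w) tracks the pointer's type automatically: renaming widget_t or
// retyping w cannot silently desynchronize the allocation size from the
// object it backs, unlike a hard-coded sizeof(widget_t).
widget_t* widget_create(void) {
  widget_t* w = (widget_t*)malloc(sizeof(*w));
  if (w) w->id = 0;
  return w;
}
```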
diff --git a/iree/samples/simple_embedding/BUILD b/iree/samples/simple_embedding/BUILD
index 94c29c2..d2fd399 100644
--- a/iree/samples/simple_embedding/BUILD
+++ b/iree/samples/simple_embedding/BUILD
@@ -265,6 +265,53 @@
 iree_cmake_extra_content(
     content = """
 endif()
+
+if(${IREE_HAL_DRIVER_CUDA} AND (${IREE_TARGET_BACKEND_CUDA} OR DEFINED IREE_HOST_BINARY_ROOT))
+""",
+    inline = True,
+)
+
+cc_binary(
+    name = "simple_embedding_cuda",
+    srcs = [
+        "device_cuda.c",
+        "simple_embedding.c",
+    ],
+    deps = [
+        ":simple_embedding_test_bytecode_module_cuda_c",
+        "//iree/base",
+        "//iree/hal",
+        "//iree/hal/cuda/registration",
+        "//iree/modules/hal",
+        "//iree/vm",
+        "//iree/vm:bytecode_module",
+    ],
+)
+
+iree_bytecode_module(
+    name = "simple_embedding_test_bytecode_module_cuda",
+    src = "simple_embedding_test.mlir",
+    c_identifier = "iree_samples_simple_embedding_test_module_cuda",
+    flags = [
+        "-iree-input-type=mhlo",
+        "-iree-mlir-to-vm-bytecode-module",
+        "-iree-hal-target-backends=cuda",
+        "-iree-llvm-debug-symbols=false",
+    ],
+)
+
+# TODO: simple_embedding_cuda_test is failing in the CI; re-enable once fixed.
+# run_binary_test(
+#     name = "simple_embedding_cuda_test",
+#     tags = [
+#         "driver=cuda",
+#     ],
+#     test_binary = ":simple_embedding_cuda",
+# )
+
+iree_cmake_extra_content(
+    content = """
+endif()
 """,
     inline = True,
 )
diff --git a/iree/samples/simple_embedding/CMakeLists.txt b/iree/samples/simple_embedding/CMakeLists.txt
index 4623869..aaee677 100644
--- a/iree/samples/simple_embedding/CMakeLists.txt
+++ b/iree/samples/simple_embedding/CMakeLists.txt
@@ -245,4 +245,48 @@
 
 endif()
 
+if(${IREE_HAL_DRIVER_CUDA} AND (${IREE_TARGET_BACKEND_CUDA} OR DEFINED IREE_HOST_BINARY_ROOT))
+
+iree_cc_binary(
+  NAME
+    simple_embedding_cuda
+  SRCS
+    "device_cuda.c"
+    "simple_embedding.c"
+  DEPS
+    ::simple_embedding_test_bytecode_module_cuda_c
+    iree::base
+    iree::hal
+    iree::hal::cuda::registration
+    iree::modules::hal
+    iree::vm
+    iree::vm::bytecode_module
+)
+
+iree_bytecode_module(
+  NAME
+    simple_embedding_test_bytecode_module_cuda
+  SRC
+    "simple_embedding_test.mlir"
+  C_IDENTIFIER
+    "iree_samples_simple_embedding_test_module_cuda"
+  FLAGS
+    "-iree-input-type=mhlo"
+    "-iree-mlir-to-vm-bytecode-module"
+    "-iree-hal-target-backends=cuda"
+    "-iree-llvm-debug-symbols=false"
+  PUBLIC
+)
+
+iree_run_binary_test(
+  NAME
+    "simple_embedding_cuda_test"
+  LABELS
+    "driver=cuda"
+  TEST_BINARY
+    ::simple_embedding_cuda
+)
+
+endif()
+
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/samples/simple_embedding/device_cuda.c b/iree/samples/simple_embedding/device_cuda.c
new file mode 100644
index 0000000..e8cf2f4
--- /dev/null
+++ b/iree/samples/simple_embedding/device_cuda.c
@@ -0,0 +1,39 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// An example of setting up the CUDA driver.
+
+#include <stddef.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/cuda/registration/driver_module.h"
+
+// Compiled module embedded here to avoid file IO:
+#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_cuda_c.h"
+
+iree_status_t create_sample_device(iree_hal_device_t** device) {
+  // Only register the CUDA HAL driver.
+  IREE_RETURN_IF_ERROR(
+      iree_hal_cuda_driver_module_register(iree_hal_driver_registry_default()));
+  // Create the HAL driver from its registered name.
+  iree_hal_driver_t* driver = NULL;
+  iree_string_view_t identifier = iree_make_cstring_view("cuda");
+  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
+      iree_hal_driver_registry_default(), identifier, iree_allocator_system(),
+      &driver));
+  IREE_RETURN_IF_ERROR(iree_hal_driver_create_default_device(
+      driver, iree_allocator_system(), device));
+  iree_hal_driver_release(driver);
+  return iree_ok_status();
+}
+
+const iree_const_byte_span_t load_bytecode_module_data() {
+  const struct iree_file_toc_t* module_file_toc =
+      iree_samples_simple_embedding_test_module_cuda_create();
+  return iree_make_const_byte_span(module_file_toc->data,
+                                   module_file_toc->size);
+}
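For context, a sketch of how the sample's shared driver presumably consumes the two hooks defined above (the real simple_embedding.c also sets up the VM and invokes the entry function; `run_sample` is a hypothetical name):

```c
static iree_status_t run_sample(void) {
  iree_hal_device_t* device = NULL;
  IREE_RETURN_IF_ERROR(create_sample_device(&device));
  iree_const_byte_span_t module_data = load_bytecode_module_data();
  // ... set up a VM instance/context, load module_data as a bytecode module,
  // and invoke the entry function (elided; see simple_embedding.c) ...
  iree_hal_device_release(device);
  return iree_ok_status();
}
```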
diff --git a/iree/tools/run_lit.ps1 b/iree/tools/run_lit.ps1
index 7f19601..e88dcb4 100644
--- a/iree/tools/run_lit.ps1
+++ b/iree/tools/run_lit.ps1
@@ -63,6 +63,7 @@
   }
   $test_line = $test_line.Substring("// RUN: ".Length)
   $test_line = $test_line -replace "%s", $test_file
+  $test_line = $test_line -replace "`"", "\`""
   Write-Host -ForegroundColor Blue "Running test command:"
   Write-Host -ForegroundColor Yellow "$test_line"
   & $bashExe -c $test_line | Out-Default
diff --git a/iree/vm/buffer.c b/iree/vm/buffer.c
index b6b2117..d433a9f 100644
--- a/iree/vm/buffer.c
+++ b/iree/vm/buffer.c
@@ -95,8 +95,7 @@
 
   // The actual buffer payload is prefixed with the buffer type so we need only
   // a single allocation.
-  iree_host_size_t prefix_size =
-      iree_host_align(sizeof(iree_vm_buffer_t), iree_max_align_t);
+  iree_host_size_t prefix_size = iree_sizeof_struct(**out_buffer);
   iree_host_size_t total_size = prefix_size + length;
 
   // Allocate combined [prefix | buffer] memory.
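Judging by the deleted line, `iree_sizeof_struct` encapsulates the same round-up-to-max-alignment computation the open-coded `iree_host_align` call performed. A generic version of that operation, for reference (an assumption about the macro based on the deleted code, not its actual definition):

```c
#include <stddef.h>

// Round value up to a power-of-two alignment. Sizing a prefix this way
// leaves the payload that follows it aligned for any scalar type.
static inline size_t align_up(size_t value, size_t alignment) {
  return (value + alignment - 1) & ~(alignment - 1);
}
```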
diff --git a/iree/vm/bytecode_dispatch_util.h b/iree/vm/bytecode_dispatch_util.h
index 2e692be..596968e 100644
--- a/iree/vm/bytecode_dispatch_util.h
+++ b/iree/vm/bytecode_dispatch_util.h
@@ -153,34 +153,12 @@
 
 // Bytecode data access macros for reading values of a given type from a byte
 // offset within the current function.
-#if defined(IREE_ENDIANNESS_LITTLE)
-#define OP_I8(i) bytecode_data[pc + (i)]
-#define OP_I16(i) *((uint16_t*)&bytecode_data[pc + (i)])
-#define OP_I32(i) *((uint32_t*)&bytecode_data[pc + (i)])
-#define OP_I64(i) *((uint64_t*)&bytecode_data[pc + (i)])
-#define OP_F32(i) *((float*)&bytecode_data[pc + (i)])
-#define OP_F64(i) *((double*)&bytecode_data[pc + (i)])
-#else
-#define OP_I8(i) bytecode_data[pc + (i)]
-#define OP_I16(i)                           \
-  ((uint16_t)bytecode_data[pc + 0 + (i)]) | \
-      ((uint16_t)bytecode_data[pc + 1 + (i)] << 8)
-#define OP_I32(i)                                     \
-  ((uint32_t)bytecode_data[pc + 0 + (i)]) |           \
-      ((uint32_t)bytecode_data[pc + 1 + (i)] << 8) |  \
-      ((uint32_t)bytecode_data[pc + 2 + (i)] << 16) | \
-      ((uint32_t)bytecode_data[pc + 3 + (i)] << 24)
-#define OP_I64(i)                                     \
-  ((uint64_t)bytecode_data[pc + 0 + (i)]) |           \
-      ((uint64_t)bytecode_data[pc + 1 + (i)] << 8) |  \
-      ((uint64_t)bytecode_data[pc + 2 + (i)] << 16) | \
-      ((uint64_t)bytecode_data[pc + 3 + (i)] << 24) | \
-      ((uint64_t)bytecode_data[pc + 4 + (i)] << 32) | \
-      ((uint64_t)bytecode_data[pc + 5 + (i)] << 40) | \
-      ((uint64_t)bytecode_data[pc + 6 + (i)] << 48) | \
-      ((uint64_t)bytecode_data[pc + 7 + (i)] << 56)
-#error "TODO: OP_F32 and OP_F64 for big endian systems"
-#endif  // IREE_ENDIANNESS_LITTLE
+#define OP_I8(i) iree_unaligned_load_le((uint8_t*)&bytecode_data[pc + (i)])
+#define OP_I16(i) iree_unaligned_load_le((uint16_t*)&bytecode_data[pc + (i)])
+#define OP_I32(i) iree_unaligned_load_le((uint32_t*)&bytecode_data[pc + (i)])
+#define OP_I64(i) iree_unaligned_load_le((uint64_t*)&bytecode_data[pc + (i)])
+#define OP_F32(i) iree_unaligned_load_le((float*)&bytecode_data[pc + (i)])
+#define OP_F64(i) iree_unaligned_load_le((double*)&bytecode_data[pc + (i)])
 
 //===----------------------------------------------------------------------===//
 // Utilities matching the tablegen op encoding scheme
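The unified macros lean on `iree_unaligned_load_le` to handle both alignment and byte order, which also supplies the `OP_F32`/`OP_F64` cases the old big-endian branch left as an `#error`. A portable way to express such a helper, sketched here with a hypothetical `load_le_u32` (not IREE's definition):

```c
#include <stdint.h>
#include <string.h>

// memcpy avoids the undefined behavior of dereferencing a misaligned pointer
// yet still compiles to a single load on common targets; big-endian hosts
// additionally byte-swap to recover the little-endian encoded value.
static inline uint32_t load_le_u32(const uint8_t* p) {
  uint32_t value;
  memcpy(&value, p, sizeof(value));
#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  value = __builtin_bswap32(value);
#endif
  return value;
}
```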
diff --git a/iree/vm/bytecode_module.c b/iree/vm/bytecode_module.c
index 0feedd8..e88972d 100644
--- a/iree/vm/bytecode_module.c
+++ b/iree/vm/bytecode_module.c
@@ -761,9 +761,8 @@
 
   iree_vm_bytecode_module_t* module = NULL;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(
-              allocator, sizeof(iree_vm_bytecode_module_t) + type_table_size,
-              (void**)&module));
+      z0, iree_allocator_malloc(allocator, sizeof(*module) + type_table_size,
+                                (void**)&module));
   module->allocator = allocator;
 
   iree_vm_FunctionDescriptor_vec_t function_descriptors =
@@ -782,8 +781,6 @@
   module->def = module_def;
 
   module->type_count = iree_vm_TypeDef_vec_len(type_defs);
-  module->type_table = (iree_vm_type_def_t*)((uint8_t*)module +
-                                             sizeof(iree_vm_bytecode_module_t));
   iree_status_t resolve_status =
       iree_vm_bytecode_module_resolve_types(type_defs, module->type_table);
   if (!iree_status_is_ok(resolve_status)) {
diff --git a/iree/vm/bytecode_module_impl.h b/iree/vm/bytecode_module_impl.h
index d602194..a63cbf8 100644
--- a/iree/vm/bytecode_module_impl.h
+++ b/iree/vm/bytecode_module_impl.h
@@ -67,7 +67,7 @@
 
   // Type table mapping module type IDs to registered VM types.
   iree_host_size_t type_count;
-  iree_vm_type_def_t* type_table;
+  iree_vm_type_def_t type_table[];
 } iree_vm_bytecode_module_t;
 
 // A resolved and split import in the module state table.
diff --git a/iree/vm/context.c b/iree/vm/context.c
index fa56039..2652fbf 100644
--- a/iree/vm/context.c
+++ b/iree/vm/context.c
@@ -321,12 +321,12 @@
       // TODO(benvanik): tune list growth for module count >> 4.
       new_capacity = context->list.capacity * 2;
     }
-    iree_vm_module_t** new_module_list;
+    iree_vm_module_t** new_module_list = NULL;
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
         z0, iree_allocator_malloc(context->allocator,
                                   sizeof(iree_vm_module_t*) * new_capacity,
                                   (void**)&new_module_list));
-    iree_vm_module_state_t** new_module_state_list;
+    iree_vm_module_state_t** new_module_state_list = NULL;
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
         z0,
         iree_allocator_malloc(context->allocator,
diff --git a/iree/vm/instance.c b/iree/vm/instance.c
index 34fee0c..9d1f4c6 100644
--- a/iree/vm/instance.c
+++ b/iree/vm/instance.c
@@ -27,8 +27,8 @@
 
   iree_vm_instance_t* instance = NULL;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_allocator_malloc(allocator, sizeof(iree_vm_instance_t),
-                                (void**)&instance));
+      z0,
+      iree_allocator_malloc(allocator, sizeof(*instance), (void**)&instance));
   instance->allocator = allocator;
   iree_atomic_ref_count_init(&instance->ref_count);
 
diff --git a/iree/vm/list.c b/iree/vm/list.c
index e6e784d..d668969 100644
--- a/iree/vm/list.c
+++ b/iree/vm/list.c
@@ -166,7 +166,7 @@
     iree_allocator_t allocator, iree_vm_list_t** out_list) {
   iree_vm_list_t* list = NULL;
   IREE_RETURN_IF_ERROR(
-      iree_allocator_malloc(allocator, sizeof(iree_vm_list_t), (void**)&list));
+      iree_allocator_malloc(allocator, sizeof(*list), (void**)&list));
   memset(list, 0, sizeof(*list));
   iree_atomic_ref_count_init(&list->ref_object.counter);
   list->allocator = allocator;
diff --git a/iree/vm/native_module.c b/iree/vm/native_module.c
index 1926ab2..7e90813 100644
--- a/iree/vm/native_module.c
+++ b/iree/vm/native_module.c
@@ -360,8 +360,8 @@
   // to expose this via a query_size function so that we could adjust the size
   // of our storage independent of the definition of the user module.
   iree_vm_native_module_t* module = NULL;
-  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
-      allocator, sizeof(iree_vm_native_module_t), (void**)&module));
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(allocator, sizeof(*module), (void**)&module));
 
   iree_status_t status = iree_vm_native_module_initialize(
       interface, module_descriptor, allocator, (iree_vm_module_t*)module);
diff --git a/scripts/git/check_submodule_init.py b/scripts/git/check_submodule_init.py
new file mode 100644
index 0000000..0523c9d
--- /dev/null
+++ b/scripts/git/check_submodule_init.py
@@ -0,0 +1,33 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import subprocess
+import sys
+
+
+def run():
+  # No-op if we're not in a git repository (or git is unavailable).
+  try:
+    subprocess.check_call(['git', 'rev-parse', '--is-inside-work-tree'],
+                          stdout=subprocess.DEVNULL,
+                          stderr=subprocess.DEVNULL)
+  except (subprocess.CalledProcessError, OSError):
+    return
+
+  output = os.popen("git submodule status")
+  submodules = output.readlines()
+
+  for submodule in submodules:
+    if submodule.strip().startswith("-"):
+      print(
+          "The git submodule '%s' is not initialized. Please run `git submodule update --init`"
+          % (submodule.split()[1]))
+      sys.exit(1)
+
+
+if __name__ == "__main__":
+  run()