Merge pull request #3467 from google/benvanik-vm-fold-arithmetic

diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index b61cdfe..48180ae 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -5,7 +5,7 @@
 a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
 f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-93377888ae89560ba6d3976e2762d3d4724c4dfd third_party/llvm-project
+9b3c2a72e4cb3b0ae27f87064c11f728452b2af9 third_party/llvm-project
 17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
 d2cdb70e038370b5e28f353fe98ccd70af1cbc25 third_party/mlir-emitc
 d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
@@ -14,7 +14,7 @@
 685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
 57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-2e56481abcef1dd1625fba465a5d02ee6b347842 third_party/tensorflow
+090b691fbf7b7823c41345004d12eddaa6c86118 third_party/tensorflow
 a9a09ab0940408898fccfdcfe2bb8dc19b50f13c third_party/tracy
 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
 909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
index 31a839d..4aeb2b8 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
@@ -3918,6 +3918,7 @@
         "include/mlir/Dialect/Linalg/EDSC/Builders.h",
         "include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h",
         "include/mlir/Dialect/Linalg/Passes.h",
+        "include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h",
         "include/mlir/Dialect/Linalg/Transforms/Hoisting.h",
         "include/mlir/Dialect/Linalg/Transforms/Transforms.h",
         "include/mlir/Dialect/Linalg/Utils/Utils.h",
@@ -3945,6 +3946,7 @@
         ":Transforms",
         ":TransformsPassIncGen",
         ":VectorOps",
+        ":VectorToSCF",
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:Support",
     ],
@@ -4033,7 +4035,6 @@
         ":EDSC",
         ":IR",
         ":LLVMDialect",
-        ":LinalgTransforms",
         ":Pass",
         ":SCFDialect",
         ":StandardOps",
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
index d88190c..c467880 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/test/BUILD.bazel
@@ -17,6 +17,21 @@
     includes = ["."],
 )
 
+filegroup(
+    name = "TestOpTdFiles",
+    srcs = [
+        "lib/Dialect/Test/TestOps.td",
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
+        "@llvm-project//mlir:include/mlir/IR/RegionKindInterface.td",
+        "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+    ],
+)
+
 gentbl(
     name = "TestOpsIncGen",
     strip_include_prefix = "lib/Dialect/Test",
@@ -57,14 +72,7 @@
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "lib/Dialect/Test/TestOps.td",
     td_srcs = [
-        "@llvm-project//mlir:OpBaseTdFiles",
-        "@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
-        "@llvm-project//mlir:include/mlir/IR/RegionKindInterface.td",
-        "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+        ":TestOpTdFiles",
     ],
     test = True,
 )
@@ -90,11 +98,34 @@
     test = True,
 )
 
+gentbl(
+    name = "TestTypeDefsIncGen",
+    strip_include_prefix = "lib/Dialect/Test",
+    tbl_outs = [
+        (
+            "-gen-typedef-decls",
+            "lib/Dialect/Test/TestTypeDefs.h.inc",
+        ),
+        (
+            "-gen-typedef-defs",
+            "lib/Dialect/Test/TestTypeDefs.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "lib/Dialect/Test/TestTypeDefs.td",
+    td_srcs = [
+        ":TestOpTdFiles",
+    ],
+    test = True,
+)
+
 cc_library(
     name = "TestDialect",
     srcs = [
         "lib/Dialect/Test/TestDialect.cpp",
         "lib/Dialect/Test/TestPatterns.cpp",
+        "lib/Dialect/Test/TestTraits.cpp",
+        "lib/Dialect/Test/TestTypes.cpp",
     ],
     hdrs = [
         "lib/Dialect/Test/TestDialect.h",
@@ -106,6 +137,7 @@
     deps = [
         ":TestInterfacesIncGen",
         ":TestOpsIncGen",
+        ":TestTypeDefsIncGen",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:DerivedAttributeOpInterface",
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 00ea2c2..dcd2341 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -313,8 +313,10 @@
       yield call
 
   @staticmethod
-  def compare_traces(ref_trace: "Trace", tar_trace: "Trace") -> bool:
+  def compare_traces(ref_trace: "Trace",
+                     tar_trace: "Trace") -> Tuple[bool, Sequence[str]]:
     traces_match = True
+    error_messages = []
 
     # Check that all method invocations match.
     ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
@@ -330,13 +332,17 @@
       logging.info("Comparing calls to '%s'", ref_call.method)
       rtol, atol = ref_call.get_tolerances()
 
-      inputs_match = Trace._check_same(ref_call.inputs, tar_call.inputs, rtol,
-                                       atol)
+      inputs_match, error_message = Trace._check_same(ref_call.inputs,
+                                                      tar_call.inputs, rtol,
+                                                      atol)
       if not inputs_match:
+        error_messages.append(error_message)
         logging.error("Inputs did not match.")
-      outputs_match = Trace._check_same(ref_call.outputs, tar_call.outputs,
-                                        rtol, atol)
+      outputs_match, error_message = Trace._check_same(ref_call.outputs,
+                                                       tar_call.outputs, rtol,
+                                                       atol)
       if not outputs_match:
+        error_messages.append(error_message)
         logging.error("Outputs did not match.")
       calls_match = inputs_match and outputs_match
 
@@ -349,83 +355,96 @@
         logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
 
       traces_match = traces_match and calls_match
-    return traces_match
+    return traces_match, error_messages
 
   @staticmethod
-  def _check_same(ref: Any, tar: Any, rtol: float, atol: float) -> bool:
+  def _check_same(ref: Any, tar: Any, rtol: float,
+                  atol: float) -> Tuple[bool, Union[str, None]]:
     """Checks that ref and tar have identical datastructures and values."""
     # Check for matching types.
     if not isinstance(tar, type(ref)):
-      logging.error(
-          "Expected ref and tar to have the same type but got '%s' and '%s'",
-          type(ref), type(tar))
-      return False
+      error = ("Expected ref and tar to have the same type but got "
+               f"'{type(ref)}' and '{type(tar)}'")
+      logging.error(error)
+      return False, error
 
     if ref is None:
       # Nothing to compare (e.g. the called method had no outputs).
-      return True
+      return True, None
 
     # Recursive check for dicts.
     if isinstance(ref, dict):
       if ref.keys() != tar.keys():
-        logging.error(
-            "Expected ref and tar to have the same keys, but got '%s' and '%s'",
-            ref.keys(), tar.keys())
-        return False
+        error = ("Expected ref and tar to have the same keys, but got "
+                 f"'{ref.keys()}' and '{tar.keys()}'")
+        logging.error(error)
+        return False, error
       # Check that all of the dictionaries' values are the same.
       for key in ref:
-        if not Trace._check_same(ref[key], tar[key], rtol, atol):
-          return False
+        same, error = Trace._check_same(ref[key], tar[key], rtol, atol)
+        if not same:
+          return same, error
 
     # Recursive check for iterables.
     elif isinstance(ref, list) or isinstance(ref, tuple):
       if len(ref) != len(tar):
-        logging.error(
-            "Expected ref and tar to have the same length, but got %s and %s",
-            len(ref), len(tar))
-        return False
+        error = ("Expected ref and tar to have the same length, but got "
+                 f"{len(ref)} and {len(tar)}")
+        logging.error(error)
+        return False, error
       # Check that all of the iterables' values are the same.
       for i in range(len(ref)):
-        if not Trace._check_same(ref[i], tar[i], rtol, atol):
-          return False
+        same, error = Trace._check_same(ref[i], tar[i], rtol, atol)
+        if not same:
+          return same, error
 
     # Base check for numpy arrays.
     elif isinstance(ref, np.ndarray):
       if ref.dtype != tar.dtype:
-        logging.error(
-            "Expected ref and tar to have the same dtype, but got %s  and %s",
-            ref.dtype, tar.dtype)
-        return False
+        error = ("Expected ref and tar to have the same dtype, but got "
+                 f"'{ref.dtype}' and '{tar.dtype}'")
+        logging.error(error)
+        return False, error
       if ref.size == tar.size == 0:
-        return True
+        return True, None
 
       if np.issubdtype(ref.dtype, np.floating):
         same = np.allclose(ref, tar, rtol=rtol, atol=atol)
         abs_diff = np.max(np.abs(ref - tar))
         rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
+        diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
+                       f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
         if not same:
-          logging.error(
-              "Floating point difference between ref and tar was too large. "
-              "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
-              abs_diff, atol, rel_diff, rtol)
+          error = ("Floating point difference between ref and tar was too "
+                   f"large. {diff_string}")
+          logging.error(error)
         else:
+          error = None
           logging.info(
               "Floating point difference between ref and tar was within "
-              "tolerance. "
-              "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
-              abs_diff, atol, rel_diff, rtol)
-        return same
+              "tolerance. %s", diff_string)
+        return same, error
+      elif np.issubdtype(ref.dtype, np.integer):
+        same = np.array_equal(ref, tar)
+        if not same:
+          abs_diff = np.max(np.abs(ref - tar))
+          error = ("Expected array equality between ref and tar, but got "
+                   f"a max elementwise difference of {abs_diff}")
+          logging.error(error)
+        else:
+          error = None
+        return same, error
       else:
-        return np.array_equal(ref, tar)
+        return np.array_equal(ref, tar), None
 
     # Base check for native number types.
     elif isinstance(ref, (int, float)):
-      return ref == tar
+      return ref == tar, None
 
     # If outputs end up here then an extra branch for that type should be added.
     else:
       raise TypeError(f"Encountered results with unexpected type {type(ref)}")
-    return True
+    return True, None
 
   def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
     """Saves a human-readable string representation of this trace to disk.
@@ -718,12 +737,14 @@
 
     # Compare each target trace of trace_function with the reference trace.
     failed_backend_indices = []
+    error_messages = []
     for i, tar_trace in enumerate(tar_traces):
       logging.info("Comparing the reference backend '%s' with '%s'",
                    ref_trace.backend_id, tar_trace.backend_id)
-      traces_match = Trace.compare_traces(ref_trace, tar_trace)
+      traces_match, errors = Trace.compare_traces(ref_trace, tar_trace)
       if not traces_match:
         failed_backend_indices.append(i)
+        error_messages.extend(errors)
 
     # Save the results to disk before validating.
     ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
@@ -740,10 +761,11 @@
       failed_backends = [
           tar_traces[i].backend_id for i in failed_backend_indices
       ]
+      error_list = ''.join([f'\n  - {message}' for message in error_messages])
       self.fail(
-          "Comparision between the reference backend and the following targets "
-          f"failed: {failed_backends}. The errors above show the inputs and "
-          "outputs of the non-matching calls.")
+          "Comparison between the reference backend and the following targets "
+          f"failed: {failed_backends}. Errors: {error_list}\n"
+          "See the logs above for more details about the non-matching calls.")
 
   @classmethod
   def tearDownClass(cls) -> None:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index e1fb6e3..efd09ff 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -104,7 +104,7 @@
         ],
     }
     # yapf: enable
-    same = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
+    same, _ = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
     self.assertEqual(tar_same, same)
 
   def test_trace_inputs_and_outputs(self):
@@ -171,7 +171,9 @@
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
-    self.assertFalse(tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace))
+    same, error_messages = tf_test_utils.Trace.compare_traces(
+        tf_trace, vmla_trace)
+    self.assertFalse(same)
 
   def test_trace_serialize_and_load(self):
 
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
index e5c1a3e..d5c3961 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/conversion/convert_tf_to_tf_strings.cc
@@ -156,7 +156,7 @@
 
 void populateTFToTFStringsPatterns(MLIRContext *ctx,
                                    OwningRewritePatternList &patterns) {
-  populateWithGenerated(ctx, &patterns);
+  populateWithGenerated(ctx, patterns);
   patterns.insert<StringFormatOpLowering>(ctx);
 }
 
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
index 472b890..17d2230 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/convert_tf_to_tf_tensorlist.cc
@@ -83,7 +83,7 @@
   // It only knows how to handle blindly convert one type to another type.
 
   OwningRewritePatternList patterns;
-  populateWithGenerated(&getContext(), &patterns);
+  populateWithGenerated(&getContext(), patterns);
   patterns.insert<ConvertTfTensorlistConcatV2>(&getContext());
 
   ConversionTarget target(getContext());
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 29e0acc..7753475 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -200,7 +200,7 @@
     datasets = ["cifar10"],
     failing_configurations = [
         {
-            # Failing all but tf and vmla:
+            # Failing on llvm and vulkan:
             "models": [
                 "NASNetLarge",
                 "NASNetMobile",
@@ -210,7 +210,6 @@
             ],
             "datasets": ["cifar10"],
             "backends": [
-                "tflite",
                 "iree_llvmjit",
                 "iree_vulkan",
             ],
@@ -284,10 +283,19 @@
     datasets = ["imagenet"],
     failing_configurations = [
         {
-            # Failing all but tf and vmla:
+            # Failing vulkan:
             "models": [
                 "InceptionResNetV2",
                 "InceptionV3",
+            ],
+            "datasets": ["imagenet"],
+            "backends": [
+                "iree_vulkan",
+            ],
+        },
+        {
+            # Failing llvm and vulkan:
+            "models": [
                 "NASNetLarge",
                 "NASNetMobile",
                 "ResNet50V2",
@@ -297,7 +305,6 @@
             ],
             "datasets": ["imagenet"],
             "backends": [
-                "tflite",
                 "iree_llvmjit",
                 "iree_vulkan",
             ],
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 49a3949..a77ff95 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Test keras Model training."""
 
+from absl import app
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
@@ -77,6 +78,11 @@
 
 class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
 
+  def __init__(self, *args, **kwargs):
+    super(ModelTrainTest, self).__init__(*args, **kwargs)
+    self._modules = tf_test_utils.compile_tf_module(
+        ModelTrain.CreateModule, exported_names=["train_step"])
+
   def generate_regression_data(self, size=8):
     x = np.arange(size) - size // 2
     y = 1.0 * x**3 + 1.0 * x**2 + 1.0 * x + np.random.randn(size) * size
@@ -105,9 +111,13 @@
     self.compare_backends(train_step, self._modules)
 
 
-if __name__ == "__main__":
+def main(argv):
+  del argv  # Unused
   if hasattr(tf, "enable_v2_behavior"):
     tf.enable_v2_behavior()
-  tf_test_utils.compile_tf_module(ModelTrain.CreateModule,
-                                  exported_names=["train_step"])
+
   tf.test.main()
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 0e7a263..f533915 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -150,10 +150,10 @@
     self._modules = tf_test_utils.compile_tf_module(VisionModule,
                                                     exported_names=['predict'])
 
-  def test_application(self):
+  def test_predict(self):
 
     def predict(module):
-      module.predict(tf_utils.uniform(get_input_shape()))
+      module.predict(tf_utils.uniform(get_input_shape()), atol=1e-5, rtol=1e-5)
 
     self.compare_backends(predict, self._modules)
 
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index 42ce76d..fe87fb7 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -72,56 +72,36 @@
             ],
         },
         {
-            # Failing all but tf and vmla:
+            # Failing llvmjit and vulkan:
             "models": [
-                "inception_resnet_v2",
-                # tflite: RuntimeError: tensorflow/lite/kernels/reshape.cc:66 num_input_elements != num_output_elements (38400 != -1571481807)Node number 333 (RESHAPE) failed to prepare.
-                # llvmjit: *** Received signal 6 ***
-                # vulkan: Floating point difference was too large. Max abs diff: 1.3961234, atol: 5e-05, max relative diff: 0.27304956, rtol: 1e-06
+                "nasnet_mobile",
+                "nasnet_large",
+                "pnasnet_large",
+                "resnet_v2_50",
                 "resnet_v2_101",
-                # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
-                # llvmjit: Floating point difference was too large. Max abs diff: 11.668068, atol: 5e-05, max relative diff: 0.93950737, rtol: 1e-06
-                # vulkan: Floating point difference was too large. Max abs diff: 11.668067, atol: 5e-05, max relative diff: 0.93950737, rtol: 1e-06
                 "resnet_v2_152",
-                # tflite: RuntimeError: tensorflow/lite/core/subgraph.cc BytesRequired number of elements overflowed.
-                # llvmjit: Floating point difference was too large. Max abs diff: 7.080696, atol: 5e-05, max relative diff: 0.97750616, rtol: 1e-06
-                # vulkan: Floating point difference was too large. Max abs diff: 7.08069, atol: 5e-05, max relative diff: 0.97750485, rtol: 1e-06
             ],
             "backends": [
-                "tflite",
                 "iree_llvmjit",
                 "iree_vulkan",
             ],
         },
         {
-            # Failing llvmjit and vulkan:
+            # Failing vulkan:
             "models": [
+                # [ERROR]: cannot separate Linalg/Parallel ops into multiple kernels
+                "inception_v1",
                 "inception_v2",
-                # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
-                # vulkan: Floating point difference was too large. Max abs diff: 1.0769763, atol: 5e-05, max relative diff: 0.19576924, rtol: 1e-06
                 "inception_v3",
-                # llvmjit: double free or corruption (!prev); *** Received signal 6 ***
-                # vulkan: Floating point difference was too large. Max abs diff: 2.5201874, atol: 5e-05, max relative diff: 0.53700095, rtol: 1e-06
-                "nasnet_mobile",
-                # llvmjit: corrupted size vs. prev_size; *** Received signal 6 ***
-                # vulkan: *** Received signal 11 ***
-                "nasnet_large",
-                # llvmjit: *** Received signal 6 ***
-                # vulkan: *** Received signal 11 ***
-                "pnasnet_large",
-                # llvmjit: Floating point difference was too large. Max abs diff: 1.0411791, atol: 5e-05, max relative diff: 0.20533353, rtol: 1e-06
-                # vulkan: *** Received signal 11 ***
-                "resnet_v2_50",
-                # llvmjit: Floating point difference was too large. Max abs diff: 5.8187943, atol: 5e-05, max relative diff: 0.7946711, rtol: 1e-06
-                # vulkan: Floating point difference was too large. Max abs diff: 5.8187933, atol: 5e-05, max relative diff: 0.79467094, rtol: 1e-06
+                "inception_resnet_v2",
             ],
             "backends": [
-                "iree_llvmjit",
                 "iree_vulkan",
             ],
         },
     ],
     models = [
+        "amoebanet_a_n18_f448",
         "inception_resnet_v2",
         "inception_v1",
         "inception_v2",
diff --git a/iree/base/atomics.h b/iree/base/atomics.h
index a2d0e71..b97d803 100644
--- a/iree/base/atomics.h
+++ b/iree/base/atomics.h
@@ -79,9 +79,9 @@
 typedef struct {
   int64_t __val;
 } iree_atomic_int64_t;
-typedef __declspec(align(16)) struct {
-  uint64_t __val[2];
-} iree_atomic_int128_t;
+// typedef __declspec(align(16)) struct {
+//   uint64_t __val[2];
+// } iree_atomic_int128_t;
 
 #define iree_atomic_load_int32(object, order) \
   InterlockedExchangeAdd((volatile LONG*)object, 0)
@@ -153,7 +153,8 @@
 
 typedef _Atomic int32_t iree_atomic_int32_t;
 typedef _Atomic int64_t iree_atomic_int64_t;
-typedef _Atomic __int128 iree_atomic_int128_t;
+// TODO(#3453): check for __int128 support before using
+// typedef _Atomic __int128 iree_atomic_int128_t;
 
 #define iree_atomic_load_auto(object, order) \
   __c11_atomic_load((object), (order))
@@ -192,7 +193,7 @@
 
 typedef int32_t iree_atomic_int32_t;
 typedef int64_t iree_atomic_int64_t;
-typedef __int128 iree_atomic_int128_t;
+// typedef __int128 iree_atomic_int128_t;
 
 #ifdef __cplusplus
 // Equiv to C++ auto keyword in C++ mode.
diff --git a/iree/base/ref_ptr_test.cc b/iree/base/ref_ptr_test.cc
index 5087791..efe0bd6 100644
--- a/iree/base/ref_ptr_test.cc
+++ b/iree/base/ref_ptr_test.cc
@@ -218,6 +218,8 @@
   struct MyBaseType : public RefObject<MyBaseType> {
     int x = 5;
     using RefObject<MyBaseType>::counter_;  // Expose for testing.
+
+    virtual ~MyBaseType() = default;
   };
   struct MyTypeA : public MyBaseType {
     int a = 6;
diff --git a/iree/base/threading_pthreads.c b/iree/base/threading_pthreads.c
index a5f99b6..5482414 100644
--- a/iree/base/threading_pthreads.c
+++ b/iree/base/threading_pthreads.c
@@ -241,11 +241,15 @@
     iree_thread_t* thread, iree_thread_priority_class_t priority_class) {
   IREE_TRACE_ZONE_BEGIN(z0);
 
+#if defined(IREE_PLATFORM_ANDROID)
+  // TODO(benvanik): Some sort of solution on Android, if possible (see above)
+#else
   int policy = 0;
   struct sched_param param;
   pthread_getschedparam(thread->handle, &policy, &param);
   param = iree_thread_sched_param_for_priority_class(policy, priority_class);
   pthread_setschedparam(thread->handle, policy, &param);
+#endif  // IREE_PLATFORM_ANDROID
 
   IREE_TRACE_ZONE_END(z0);
 }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 0b20c43..5421cc9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -605,8 +605,9 @@
     // If the `linalgOp` writes to workgroup memory insert barrier after the
     // op.
     if (llvm::any_of(linalgOp.getOperands(), [](Value output) {
-          return output.getType().cast<MemRefType>().getMemorySpace() ==
-                 getWorkgroupMemorySpace();
+          MemRefType outputType = output.getType().dyn_cast<MemRefType>();
+          return outputType &&
+                 outputType.getMemorySpace() == getWorkgroupMemorySpace();
         })) {
       rewriter.create<spirv::ControlBarrierOp>(
           linalgOp.getLoc(), spirv::Scope::Workgroup, spirv::Scope::Workgroup,
@@ -751,6 +752,7 @@
                   MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
                   MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
                   MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::FillOp>,
                   MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
                   MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
                   MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index fdae78b..0021c77 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -63,23 +63,22 @@
     entryPointFn.emitError("unable to find num workgroups fn ") << attr;
     return nullptr;
   }
-  if (!numWorkgroupsFn.empty()) {
-    entryPointFn.emitError("num workgroups fn expected to be empty");
-    return nullptr;
-  }
   return numWorkgroupsFn;
 }
 
 /// Computes the bounds of the parallel loops partitioned across workgroups.
 static Optional<SmallVector<Value, 2>> getParallelLoopRange(
-    PatternRewriter &rewriter, Location loc, linalg::LinalgOp linalgOp) {
-  FuncOp numWorkgroupsFn =
-      getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
-  if (!numWorkgroupsFn) return {};
+    PatternRewriter &rewriter, FuncOp numWorkgroupsFn, Location loc,
+    linalg::LinalgOp linalgOp) {
+  if (!numWorkgroupsFn.empty()) {
+    numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
+    return {};
+  }
   LLVM_DEBUG({
     llvm::dbgs() << "Found num workgroups function : "
                  << numWorkgroupsFn.getName();
   });
+
   rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
   llvm::SetVector<Operation *> slice;
   getBackwardSlice(linalgOp, &slice);
@@ -127,10 +126,14 @@
                                                  linalg::LinalgOp linalgOp,
                                                  FuncOp entryPointFn,
                                                  ArrayRef<int64_t> tileSizes) {
+  FuncOp numWorkgroupsFn =
+      getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
+  if (!numWorkgroupsFn) return failure();
+
   Location loc = linalgOp.getLoc();
   OpBuilder::InsertionGuard gaurd(rewriter);
   Optional<SmallVector<Value, 2>> parallelLoopRange =
-      getParallelLoopRange(rewriter, loc, linalgOp);
+      getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
   if (!parallelLoopRange) return failure();
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   SmallVector<Value, 3> returnValues(3, one);
@@ -148,10 +151,23 @@
 LogicalResult createNumWorkgroupsFromLinearizedResultShape(
     PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
     int64_t workgroupSizeX) {
+  FuncOp numWorkgroupsFn =
+      getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
+  if (!numWorkgroupsFn) return failure();
+  if (!numWorkgroupsFn.empty()) {
+    // TODO(ravishankarm): We can end up with multiple linalg operations
+    // (typically linalg.generic operations) that have the same workload in a
+    // dispatch region. In that case, the first linalg.generic creates the body
+    // of the num workgroups function. For now, just return if the body is not
+    // empty, assuming that it is correct for all the ops in the dispatch
+    // region. This needs to be enforced somehow.
+    return success();
+  }
+
   Location loc = linalgOp.getLoc();
   OpBuilder::InsertionGuard gaurd(rewriter);
   Optional<SmallVector<Value, 2>> parallelLoopRange =
-      getParallelLoopRange(rewriter, loc, linalgOp);
+      getParallelLoopRange(rewriter, numWorkgroupsFn, loc, linalgOp);
   if (!parallelLoopRange) return failure();
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   SmallVector<Value, 3> returnValues(3, one);
@@ -168,6 +184,10 @@
 // Launch config calculation.
 //===----------------------------------------------------------------------===//
 
+/// Name of the StrAttr that can be used to get the key to access the tile size
+/// information.
+static const char kLaunchInfoKey[] = "launch_info_key";
+
 /// Given `nprocs` try to distribute it evenly across 2 logical x and y.
 static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
   int64_t nprocs_x = std::max<int64_t>(
@@ -387,12 +407,27 @@
 
 #undef DEFINE_POOLINGOP_CONFIG
 
-LogicalResult LaunchConfig::init(const SPIRVCodegenOptions &options,
-                                 ArrayRef<linalg::LinalgOp> linalgOps) {
+Optional<StringRef> LaunchConfig::getKey(Operation *op) const {
+  StringAttr attr = op->getAttrOfType<StringAttr>(kLaunchInfoKey);
+  if (!attr) return {};
+  return attr.getValue();
+}
+
+LogicalResult LaunchConfig::init(MLIRContext *context,
+                                 const SPIRVCodegenOptions &options,
+                                 ArrayRef<Operation *> linalgOps) {
+  unsigned numTiledOps = 0;
+  auto setKey = [&](Operation *op) -> std::string {
+    std::string key = llvm::formatv("__op_num_{0}__", numTiledOps++).str();
+    op->setAttr(Identifier::get(kLaunchInfoKey, context),
+                StringAttr::get(key, context));
+    return key;
+  };
+
   if (!options.workgroupSize.empty()) {
-    for (linalg::LinalgOp op : linalgOps)
-      tileSizes[op.getOperation()->getName().getStringRef()].emplace_back(
-          options.tileSizes.begin(), options.tileSizes.end());
+    for (Operation *linalgOp : linalgOps)
+      tileSizes[setKey(linalgOp)].emplace_back(options.tileSizes.begin(),
+                                               options.tileSizes.end());
     workgroupSize = {1, 1, 1};
     for (unsigned i = 0,
                   e = std::min<unsigned>(3, options.workgroupSize.size());
@@ -406,17 +441,18 @@
   spirv::ResourceLimitsAttr resourceLimits =
       spirv::lookupTargetEnv(*linalgOps.begin()).getResourceLimits();
 
-  for (linalg::LinalgOp op : linalgOps) {
-    StringRef key = op.getOperation()->getName().getStringRef();
-    if (tileSizes.count(key)) {
-      return op.emitError("unexpected multiple ")
-             << key << " operations within dispatch region";
-    }
+  Optional<linalg::LinalgOp> rootOperation = {};
 
-    TileSizesListType &tileSizesInfo = tileSizes[key];
-
+  for (Operation *op : linalgOps) {
+    linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
 #define DISPATCH(opName)                                                      \
-  if (auto lOp = dyn_cast<opName>(op.getOperation())) {                       \
+  if (auto lOp = dyn_cast<opName>(linalgOp.getOperation())) {                 \
+    if (rootOperation) {                                                      \
+      return lOp.emitError(                                                   \
+          "unhandled multiple root operations in dispatch region");           \
+    }                                                                         \
+    rootOperation = cast<linalg::LinalgOp>(lOp.getOperation());               \
+    TileSizesListType &tileSizesInfo = tileSizes[setKey(*rootOperation)];     \
     if (failed(getOpLaunchConfig(lOp, options, resourceLimits, tileSizesInfo, \
                                  workgroupSize, numSubgroups))) {             \
       return failure();                                                       \
@@ -439,5 +475,12 @@
   return success();
 }
 
+void LaunchConfig::finalize(FuncOp funcOp) {
+  funcOp.walk([&](linalg::LinalgOp linalgOp) {
+    linalgOp.removeAttr(Identifier::get(kLaunchInfoKey, funcOp.getContext()));
+    ;
+  });
+}
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index d690a0a..919464e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -26,6 +26,7 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 
@@ -39,7 +40,7 @@
 
 namespace linalg {
 class LinalgOp;
-}
+}  // namespace linalg
 namespace iree_compiler {
 struct SPIRVCodegenOptions;
 }
@@ -94,17 +95,26 @@
   /// - tile sizes for each level,
   /// - the workgroup size, and
   /// - number of subgroups to use.
-  LogicalResult init(const SPIRVCodegenOptions &options,
-                     ArrayRef<linalg::LinalgOp> linalgOps);
+  LogicalResult init(MLIRContext *context, const SPIRVCodegenOptions &options,
+                     ArrayRef<Operation *> linalgOps);
+
+  /// Remove attributes added to operations for retrieving tile size
+  /// information.
+  void finalize(FuncOp funcOp);
 
   /// Gets the tile size computed for an operation at all levels.
   TileSizesListType getTileSizes(Operation *op) const {
-    return tileSizes.lookup(op->getName().getStringRef());
+    auto key = getKey(op);
+    if (!key) return {};
+    // Unknown keys yield an empty list rather than dereferencing end().
+    return tileSizes.lookup(*key);
   }
 
   /// Gets the tile size computed for an operation for an level.
   ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
-    auto it = tileSizes.find(op->getName().getStringRef());
+    auto key = getKey(op);
+    if (!key) return {};
+    auto it = tileSizes.find(*key);
     if (it == tileSizes.end() || level >= it->second.size()) return {};
     return it->second[level];
   }
@@ -115,14 +125,19 @@
   /// Returns the number of subgroups to use.
   ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
 
- protected:
-  /// Current tile size configuration per operation.
+  /// Returns true if tile sizes have been computed for the operation. If tile
+  /// sizes aren't set, it implies the operation is not to be tiled.
+  bool hasTileSizes(Operation *op, size_t level = 0) const {
+    return !getTileSizes(op, level).empty();
+  }
 
-  // TODO: For now just use the operation name for the mapping. The tile sizes
-  // will be selected only for operations like matmul, conv, pool, etc. and
-  // assume that there is only one such operation per dispatch
-  // region. Eventually this might need to be relaxed, and some name-marker
-  // based mechanism might be needed.
+ protected:
+  /// Current tile size configuration per operation. The key used here to
+  /// retrieve the tile size information per operation is the value of a StrAttr
+  /// added to operations during `init`. When tiled this attribute is copied
+  /// over to the tiled operation, thereby the same key can be used to retrieve
+  /// the tile sizes for the next level of tiling. The `finalize` method removes
+  /// these attributes.
   llvm::StringMap<TileSizesListType> tileSizes;
 
   /// Workgroup size to use.
@@ -130,6 +145,11 @@
 
   /// Number of subgroups that are logically distributed along x, y & z.
   std::array<int64_t, 3> numSubgroups;
+
+ private:
+  /// Retrieves the key to use to get the `tileSizes` for a given
+  /// `operation`. Returns llvm::None on failure.
+  Optional<StringRef> getKey(Operation *op) const;
 };
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 5c5abbd..c2ec755 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -27,6 +27,7 @@
 #include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -123,16 +124,19 @@
 
 namespace {
 /// Pattern for tiling operations. Updates the workgroup size in the surrounding
-/// function operation if tiling succeeds.
+/// function operation if tiling succeeds, and generates the function that
+/// computes the number of workgroups for the launch.
 template <typename LinalgOpTy>
 struct TileToWorkgroupsPattern : public linalg::LinalgBaseTilingPattern {
   using Base = linalg::LinalgBaseTilingPattern;
   TileToWorkgroupsPattern(MLIRContext *context,
+                          const linalg::LinalgDependenceGraph &dependenceGraph,
                           linalg::LinalgTilingOptions options,
                           linalg::LinalgMarker marker,
                           const LaunchConfig &launchConfig,
                           PatternBenefit benefit = 1)
       : Base(LinalgOpTy::getOperationName(), context, options, marker, benefit),
+        dependenceGraph(dependenceGraph),
         launchConfig(launchConfig) {}
 
   LogicalResult matchAndRewrite(Operation *op,
@@ -151,18 +155,59 @@
              launchConfig.getTileSizes(op, 0))))) {
       return failure();
     }
-    rewriter.eraseOp(op);
+    setMarker(op, getDeleteMarker());
     return success();
   }
 
+  const linalg::LinalgDependenceGraph &dependenceGraph;
+  const LaunchConfig &launchConfig;
+};
+
+/// Pattern for tile + fuse of operations. Updates the workgroup size in the
+/// surrounding function operation if tiling succeeds, and generates the
+/// function that computes the number of workgroups for the launch.
+template <typename LinalgOpTy>
+struct TileAndFuseToWorkgroupsPattern
+    : public linalg::LinalgTileAndFusePattern<LinalgOpTy> {
+  using Base = linalg::LinalgTileAndFusePattern<LinalgOpTy>;
+  TileAndFuseToWorkgroupsPattern(
+      MLIRContext *context,
+      const linalg::LinalgDependenceGraph &dependenceGraph,
+      linalg::LinalgTilingOptions tilingOptions, linalg::LinalgMarker marker,
+      const LaunchConfig &launchConfig, PatternBenefit benefit = 1)
+      : Base(context, dependenceGraph, tilingOptions,
+             linalg::LinalgFusionOptions().setIndicesToFuse({2}), marker,
+             marker,
+             linalg::LinalgMarker(ArrayRef<Identifier>(),
+                                  Identifier::get(getDeleteMarker(), context)),
+             benefit),
+        dependenceGraph(dependenceGraph),
+        launchConfig(launchConfig) {}
+
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        PatternRewriter &rewriter) const {
+    FuncOp funcOp = op->getParentOfType<FuncOp>();
+    linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
+    if (!funcOp || !dependenceGraph.hasDependentOperations(linalgOp) ||
+        failed(Base::matchAndRewrite(op, rewriter)) ||
+        failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
+        (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+         failed(createNumWorkgroupsFromResultShape(
+             rewriter, linalgOp, funcOp, launchConfig.getTileSizes(op, 0))))) {
+      return failure();
+    }
+    return success();
+  }
+
+  const linalg::LinalgDependenceGraph &dependenceGraph;
   const LaunchConfig &launchConfig;
 };
 }  // namespace
 
 /// Populate patterns for first-level tiling.
 static void populateTilingToWorkgroupPatterns(
-    MLIRContext *context, const LaunchConfig &launchConfig,
-    OwningRewritePatternList &patterns) {
+    MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
+    const LaunchConfig &launchConfig, OwningRewritePatternList &patterns) {
   // Function to compute first level tiling values.
   std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
       getOuterTileSizeFn =
@@ -178,13 +223,19 @@
     }
     return tileSizesVal;
   };
-  patterns.insert<TileToWorkgroupsPattern<linalg::ConvOp>,
-                  TileToWorkgroupsPattern<linalg::MatmulOp>,
+  patterns.insert<TileAndFuseToWorkgroupsPattern<linalg::BatchMatmulOp>,
+                  TileAndFuseToWorkgroupsPattern<linalg::ConvOp>,
+                  TileAndFuseToWorkgroupsPattern<linalg::MatmulOp>,
+                  TileAndFuseToWorkgroupsPattern<linalg::PoolingMaxOp>,
+                  TileAndFuseToWorkgroupsPattern<linalg::PoolingMinOp>,
+                  TileAndFuseToWorkgroupsPattern<linalg::PoolingSumOp>,
                   TileToWorkgroupsPattern<linalg::BatchMatmulOp>,
+                  TileToWorkgroupsPattern<linalg::ConvOp>,
+                  TileToWorkgroupsPattern<linalg::MatmulOp>,
                   TileToWorkgroupsPattern<linalg::PoolingMaxOp>,
                   TileToWorkgroupsPattern<linalg::PoolingMinOp>,
                   TileToWorkgroupsPattern<linalg::PoolingSumOp>>(
-      context,
+      context, dependenceGraph,
       linalg::LinalgTilingOptions()
           .setDistributionOptions(workgroupDistributionOptions)
           .setTileSizeComputationFunction(getOuterTileSizeFn)
@@ -380,9 +431,9 @@
     if (linalgOps.empty()) continue;
 
     LaunchConfig launchConfig;
-    SmallVector<linalg::LinalgOp, 4> linalgOpsVec(linalgOps.begin(),
-                                                  linalgOps.end());
-    if (failed(launchConfig.init(options, linalgOpsVec))) {
+    SmallVector<Operation *, 4> linalgOpsVec(linalgOps.begin(),
+                                             linalgOps.end());
+    if (failed(launchConfig.init(context, options, linalgOpsVec))) {
       funcOp.emitError("unable to find launch configuration");
       return signalPassFailure();
     }
@@ -406,11 +457,24 @@
       }
     });
 
-    OwningRewritePatternList firstLevelTilingPatterns;
-    populateTilingToWorkgroupPatterns(context, launchConfig,
-                                      firstLevelTilingPatterns);
-    applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
-    applyCanonicalizationPatterns(context, funcOp);
+    {
+      // Compute the Linalg Dependence Graph.
+      linalg::Aliases aliases;
+      linalg::LinalgDependenceGraph dependenceGraph =
+          linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
+
+      OwningRewritePatternList firstLevelTilingPatterns;
+      populateTilingToWorkgroupPatterns(context, dependenceGraph, launchConfig,
+                                        firstLevelTilingPatterns);
+      applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
+      applyCanonicalizationPatterns(context, funcOp);
+
+      // Delete the ops that are marked for deletion.
+      funcOp.walk([](linalg::LinalgOp linalgOp) {
+        if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
+          linalgOp.getOperation()->erase();
+      });
+    }
 
     if (options.useWorkgroupMemory) {
       // The promotion patterns are put separate from the tiling patterns to
@@ -434,6 +498,13 @@
                                     vectorizationPatterns);
       applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
     }
+
+    launchConfig.finalize(funcOp);
+    SmallVector<linalg::LinalgOp, 1> toDelete;
+    funcOp.walk([&](linalg::LinalgOp linalgOp) {
+      if (hasMarker(linalgOp.getOperation(), getDeleteMarker()))
+        linalgOp.erase();
+    });
   }
 }
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index 51d3585..dff5292 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -27,6 +27,8 @@
 const StringLiteral VectorTransforms::kVectorTransformMarker =
     "__internal_vector_transform__";
 
+StringRef getFusedMarker() { return "fused_numprocs_ge_numiters"; }
+
 StringRef getWorkgroupMarker() { return "workgroup"; }
 
 StringRef getWorkgroupMemoryMarker() { return "workgroup_memory"; }
@@ -45,6 +47,8 @@
 
 StringRef getVectorizeMarker() { return "vectorize"; }
 
+StringRef getDeleteMarker() { return "delete"; }
+
 bool hasMarker(Operation *op, ArrayRef<StringRef> marker) {
   StringAttr attr = op->getAttrOfType<StringAttr>(
       linalg::LinalgTransforms::kLinalgTransformMarker);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index 72d5385..78d4304 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -44,6 +44,12 @@
 /// Marker for operations that are going to be vectorized.
 StringRef getVectorizeMarker();
 
+/// Marker for tagging an operation for deletion. Tile and fuse pattern does not
+/// delete the original operation to not invalidate the
+/// `linalg::LinalgDependenceGraph` data structure. Instead it is marked with a
+/// marker that can be used later to delete these operations.
+StringRef getDeleteMarker();
+
 /// Returns true if an operation has the specified `marker`. When `marker` is
 /// empty, returns true if the operation has any marker.
 bool hasMarker(Operation *, ArrayRef<StringRef> markers = {});
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 04044a7..d751065 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -85,21 +85,17 @@
   //   - All Linalg ops have buffer semantics.
   //
   // Post-conditions:
+  //   - The operations that cannot be fused at buffer levels are split into
+  //     separate entry points.
   //   - If the input Linalg ops are tilable:
   //     - loop.parallel ops are generated for mapping to workgroups.
   //     - Linalg ops are nested inside loop.parallel ops and ready for mapping
   //       to workitems.
+  //     - If multiple linalg operations are present they get tiled and fused to
+  //       get outer loop.parallel ops which can be mapped to workitems.
   //   - Otherwise:
   //     - The Linalg op is kept untouched.
-  //   - Dispatch functions might be split into multiple ones.
   //
-  // Note:
-  //   We first try to tile and fuse the dispatch function as a whole. If there
-  //   are multiple Linalg ops inside, they may not share any number of common
-  //   outer parallel iterators. Then the first tile and fuse pass will do
-  //   nothing. We split all the Linalg ops into their own dispatch functions
-  //   afterwards. This gives each Linalg op a second chance to be tiled,
-  //   with the second tile and fuse pass.
   //===--------------------------------------------------------------------===//
   pm.addPass(createSplitDispatchFunctionPass());
   pm.addPass(createLinalgTileAndFusePass(options));
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 94436c6..7e8d1aa 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -55,37 +56,94 @@
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-namespace {
+/// Returns true if an op can be fused with the list of ops that are to be put
+/// in the same entry point function. This should be consistent with what
+/// the downstream passes can handle.
+static bool isFusableWithCurrentOpsList(
+    Operation *nextOp, ArrayRef<Operation *> currOpsList,
+    const linalg::LinalgDependenceGraph &dependenceGraph) {
+  if (currOpsList.empty()) return true;
 
-/// Returns true if the Linalg ops can be separated to multiple kernels.
-bool canSeparateOps(ArrayRef<Operation *> ops) {
-  if (llvm::any_of(ops, [](Operation *op) {
-        if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
-          return !linalgOp.hasBufferSemantics();
-        return false;
-      }))
-    return false;
+  linalg::LinalgOp dstOp = dyn_cast<linalg::LinalgOp>(nextOp);
+  linalg::LinalgOp srcOp = dyn_cast<linalg::LinalgOp>(currOpsList.back());
+  if (dstOp && srcOp) {
+    // TODO(#2963): This splits independent linalg operations into their own
+    // dispatch, but in reality if the iteration domain of the ops are the same,
+    // and they have all iterator types parallel, they could be put in the same
+    // dispatch region.
+    if (!dependenceGraph.hasDependenceFrom(srcOp, dstOp)) return false;
 
-  // Require no other non-metadata ops interleave with Linalg structured ops for
-  // now. This is the common case and it simplifies further analysis.
+#define ADD_FUSABLE_PAIR(SrcOpTy, DstOpTy, DependenceTy)             \
+  if (isa<SrcOpTy>(srcOp.getOperation()) &&                          \
+      isa<DstOpTy>(dstOp.getOperation()) &&                          \
+      dependenceGraph.hasDependenceFrom(srcOp, dstOp, DependenceTy)) \
+    return true;
+
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMaxOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingMinOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+    ADD_FUSABLE_PAIR(linalg::FillOp, linalg::PoolingSumOp,
+                     linalg::LinalgDependenceGraph::DependenceType::WAW)
+
+#undef ADD_FUSABLE_PAIR
+  }
+  return false;
+}
+
+/// For the list of operations in `ops` returns a list of lists where each list
+/// contains the operations that need to be put in a separate dispatch function.
+static LogicalResult separateOps(
+    ArrayRef<Operation *> ops,
+    const linalg::LinalgDependenceGraph &dependenceGraph,
+    SmallVectorImpl<SmallVector<Operation *, 1>> &fusedOpList) {
+  assert(!ops.empty() &&
+         "expected at least one separable op for splitting dispatch function");
+  SmallVector<Operation *, 1> currList;
   for (auto currOp = ops.begin(), nextOp = std::next(ops.begin());
        nextOp != ops.end(); ++currOp, ++nextOp) {
+    // Check that the operation has buffer semantics.
+    if (auto linalgOp = dyn_cast<linalg::LinalgOp>(*currOp)) {
+      if (!linalgOp.hasBufferSemantics()) return failure();
+    }
+
+    // Require no other non-metadata ops interleave with Linalg structured ops
+    // for now. This is the common case and it simplifies further analysis.
     Operation *iter = (*currOp)->getNextNode();
     while (iter != *nextOp && (MemoryEffectOpInterface::hasNoEffect(iter) ||
                                isa<IREE::PlaceholderOp>(iter)))
       iter = iter->getNextNode();
-    if (iter != *nextOp) return false;
-  }
+    if (iter != *nextOp) return failure();
 
-  return true;
+    currList.push_back(*currOp);
+
+    // If the nextOp is not fusible with the currOp, then record the list of ops
+    // so far, and start a new list.
+    if (isFusableWithCurrentOpsList(*nextOp, currList, dependenceGraph)) {
+      continue;
+    }
+
+    // Push the current list of ops into the list of lists `fusedOpList` and
+    // start a new list.
+    fusedOpList.emplace_back();
+    std::swap(fusedOpList.back(), currList);
+  }
+  currList.push_back(ops.back());
+  fusedOpList.emplace_back(std::move(currList));
+  return success();
 }
 
 /// Recursively collects all the operations that are referenced by given
 /// `rootOp` into `closure`.
-void collectAllReferencedOps(Operation *rootOp,
-                             llvm::SmallPtrSetImpl<Operation *> &closure) {
+static void collectAllReferencedOps(
+    ArrayRef<Operation *> rootOps,
+    llvm::SmallPtrSetImpl<Operation *> &closure) {
   llvm::SmallVector<Operation *, 8> workList;
-  workList.push_back(rootOp);
+  workList.assign(rootOps.begin(), rootOps.end());
 
   while (!workList.empty()) {
     Operation *curOp = workList.pop_back_val();
@@ -104,8 +162,6 @@
   }
 }
 
-}  // namespace
-
 //===----------------------------------------------------------------------===//
 // Pass and patterns
 //===----------------------------------------------------------------------===//
@@ -161,10 +217,16 @@
       separableOps.push_back(&op);
 
   if (separableOps.size() <= 1) return success();
-  if (!canSeparateOps(separableOps)) {
+
+  linalg::Aliases aliases;
+  linalg::LinalgDependenceGraph dependenceGraph =
+      linalg::LinalgDependenceGraph::buildDependenceGraph(aliases, oldFn);
+  SmallVector<SmallVector<Operation *, 1>, 1> fusedOpsList;
+  if (failed(separateOps(separableOps, dependenceGraph, fusedOpsList))) {
     return oldFn.emitError(
         "cannot separate Linalg/Parallel ops into multiple kernels");
   }
+  if (fusedOpsList.size() <= 1) return success();
 
   ModuleOp moduleOp = cast<ModuleOp>(oldFn.getParentOp());
   Block &oldFnBlock = oldFn.getBlocks().front();
@@ -174,10 +236,11 @@
   splitKernels.reserve(separableOps.size());
   llvm::SmallPtrSet<Operation *, 16> closure;
 
-  for (const auto &separableOp : llvm::enumerate(separableOps)) {
+  for (const auto &fusedOps : llvm::enumerate(fusedOpsList)) {
+    if (fusedOps.value().empty()) continue;
     // Create a new function for hosting this op.
-    splitKernels.emplace_back(llvm::formatv("{0}_dispatch_{1}", oldFn.getName(),
-                                            separableOp.index()));
+    splitKernels.emplace_back(
+        llvm::formatv("{0}_dispatch_{1}", oldFn.getName(), fusedOps.index()));
     StringRef newFnName = splitKernels.back();
     builder.setInsertionPointToStart(moduleOp.getBody());
     auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType());
@@ -210,7 +273,7 @@
 
     // Collect the closure for the current Linalg op.
     closure.clear();
-    collectAllReferencedOps(separableOp.value(), closure);
+    collectAllReferencedOps(fusedOps.value(), closure);
 
     // Clone all ops in the closure to the new function.
     Block *newFnBlock = newFn.addEntryBlock();
@@ -219,7 +282,7 @@
     for (Operation &op : oldFnBlock) {
       if (closure.count(&op) == 0) continue;
       builder.insert(op.clone(remapper));
-      if (&op == separableOp.value()) break;
+      if (&op == fusedOps.value().back()) break;
     }
     builder.insert(oldFnBlock.getTerminator()->clone(remapper));
   }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index 9737bb6..52aeff7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -1,13 +1,13 @@
 // RUN: iree-opt -split-input-file -iree-codegen-convert-to-spirv %s | IreeFileCheck %s
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  // CHECK: spv.globalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
+  // CHECK: spv.globalVariable @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
   // CHECK: spv.func @push_constant()
   func @push_constant() {
     // CHECK: %[[INDEX_0:.+]] = spv.constant 0 : i32
     // CHECK: %[[INDEX_1:.+]] = spv.constant 2 : i32
-    // CHECK: %[[ADDR:.+]] = spv._address_of @__push_constant_var__ : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
-    // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<!spv.array<5 x i32, stride=4> [0]>, PushConstant>
+    // CHECK: %[[ADDR:.+]] = spv._address_of @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
+    // CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
     // CHECK: spv.Load "PushConstant" %[[AC]] : i32
     %0 = hal.interface.load.constant offset = 2 : index
     return
@@ -23,12 +23,12 @@
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
 
-  // CHECK: spv.globalVariable @__resource_var_3_4__ bind(3, 4) : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
-  // CHECK: spv.globalVariable @__resource_var_1_2__ bind(1, 2) : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
+  // CHECK: spv.globalVariable @__resource_var_3_4__ bind(3, 4) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK: spv.globalVariable @__resource_var_1_2__ bind(1, 2) : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
   // CHECK: spv.func @resource_variable()
   func @resource_variable() {
-    // CHECK: spv._address_of @__resource_var_1_2__ : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
-    // CHECK: spv._address_of @__resource_var_3_4__ : !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK: spv._address_of @__resource_var_1_2__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK: spv._address_of @__resource_var_3_4__ : !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, StorageBuffer>
     %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
     return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 0d5adc6..580d32a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -302,3 +302,114 @@
 //       CHECK:   %[[T1:.+]] = addi %[[DIM0]], %[[C3]]
 //   CHECK-DAG:   %[[NBY:.+]] = divi_signed %[[T1]], %[[C4]]
 //       CHECK:   return %[[NBX]], %[[NBY]], %[[C1]]
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  // Input: a zero fill of the result buffer followed by a matmul into it.
+  func @matmul_fusion() attributes {vkspv.num_workgroups_fn = @matmul_fusion__num_workgroups__} {
+    %lhs = iree.placeholder for "interace buffer"
+      {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
+    %rhs = iree.placeholder for "interace buffer"
+      {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?xf32>
+    %out = iree.placeholder for "interace buffer"
+      {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?xf32>
+    %zero = constant 0.000000e+00 : f32
+    linalg.fill(%out, %zero) : memref<?x?xf32>, f32
+    linalg.matmul ins(%lhs, %rhs : memref<?x?xf32>, memref<?x?xf32>) outs(%out : memref<?x?xf32>)
+    return
+  }
+  func @matmul_fusion__num_workgroups__(
+      !shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+      !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+      attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
+//       CHECK: func @matmul_fusion()
+//  CHECK-SAME:   local_size = dense<[16, 8, 1]>
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//   CHECK-NOT:   scf.parallel
+//   CHECK-NOT:   scf.for
+//       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+//       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+//       CHECK:   %[[VIEW3:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+//       CHECK:   linalg.fill(%[[VIEW3]], %{{.+}})
+//  CHECK-SAME:     "workgroup"
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     "workgroup"
+//  CHECK-SAME:     ins(%[[VIEW0]], %[[VIEW1]]
+//  CHECK-SAME:     outs(%[[VIEW2]]
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  // Input: a zero fill of the result buffer followed by an unpadded conv.
+  func @conv_no_padding_fusion() {
+    %in0 = iree.placeholder for "interace buffer"
+      {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
+    %in1 = iree.placeholder for "interace buffer"
+      {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?x?x?xf32>
+    %out = iree.placeholder for "interace buffer"
+      {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?x?x?xf32>
+    %zero = constant 0.000000e+00 : f32
+    linalg.fill(%out, %zero) : memref<?x?x?x?xf32>, f32
+    linalg.conv(%in0, %in1, %out) {dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+    return
+  }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
+//       CHECK: func @conv_no_padding_fusion()
+//  CHECK-SAME:   local_size = dense<[32, 4, 1]>
+//   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+//   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+//   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//   CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+//       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]]
+//  CHECK-SAME:     [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
+//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]]
+//  CHECK-SAME:     [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
+//       CHECK:   %[[VIEW3:.+]] = subview %[[RET0]]
+//  CHECK-SAME:     [%[[BIDZ]], %[[LBY_2]], %[[LBX_2]], 0]
+//       CHECK:   linalg.fill(%[[VIEW3]], %{{.*}})
+//  CHECK-SAME:     "workgroup"
+//       CHECK:   linalg.conv
+//  CHECK-SAME:     %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
+//  CHECK-SAME:     "workgroup"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 5d3ff91..e2e8d7e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,8 +1,117 @@
 // RUN: iree-opt -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
 
+module {
+  // CHECK: func @kernel_fusable_fill_conv_ops
+  // CHECK:   linalg.fill
+  // CHECK:   linalg.conv
+  // The num-workgroups symbol named below is the private declaration that follows.
+  func @kernel_fusable_fill_conv_ops()
+  attributes {vkspv.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %dim = hal.interface.load.constant offset = 0 : index
+    %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+    %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+    %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
+    linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
+    linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+    return
+  }
+  func @kernel_fusable_fill_conv_ops_num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+                                                      !shapex.ranked_shape<[3,3,512,1]>,
+                                                      !shapex.ranked_shape<[?,1,1,512]>)
+                                                     -> (index, index, index)
+  attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// -----
+
+module {
+  // CHECK: func @kernel_fusable_fill_matmul_ops
+  // CHECK:   linalg.fill
+  // CHECK:   linalg.matmul
+  // The num-workgroups symbol named below is the private declaration that follows.
+  func @kernel_fusable_fill_matmul_ops()
+  attributes {vkspv.num_workgroups_fn = @kernel_fusable_fill_matmul_ops_num_workgroups__} {
+    %cst = constant 0.000000e+00 : f32
+    %dimM = hal.interface.load.constant offset = 0 : index
+    %dimN = hal.interface.load.constant offset = 1 : index
+    %shape1 = shapex.make_ranked_shape %dimM : (index) -> !shapex.ranked_shape<[?,512]>
+    %shape2 = shapex.make_ranked_shape %dimN : (index) -> !shapex.ranked_shape<[512,?]>
+    %shape3 = shapex.make_ranked_shape %dimM, %dimN : (index, index) -> !shapex.ranked_shape<[?,?]>
+    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x512xf32>
+    %ts1 = shapex.tie_shape %0, %shape1 : memref<?x512xf32>, !shapex.ranked_shape<[?,512]>
+    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<512x?xf32>
+    %ts2 = shapex.tie_shape %1, %shape2 : memref<512x?xf32>, !shapex.ranked_shape<[512, ?]>
+    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
+    %ts3 = shapex.tie_shape %2, %shape3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
+    linalg.fill(%ts3, %cst) : memref<?x?xf32>, f32
+    linalg.matmul ins(%ts1, %ts2 : memref<?x512xf32>, memref<512x?xf32>)
+                  outs(%ts3 : memref<?x?xf32>)
+    return
+  }
+  func @kernel_fusable_fill_matmul_ops_num_workgroups__(!shapex.ranked_shape<[?,512]>,
+                                                        !shapex.ranked_shape<[512,?]>,
+                                                        !shapex.ranked_shape<[?,?]>)
+                                                       -> (index, index, index)
+  attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// -----
+
+module {
+  // CHECK: func @kernel_fusable_pooling()
+  // CHECK:   linalg.fill
+  // CHECK:   linalg.pooling
+  // Input: a fill of the result buffer immediately followed by a sum pooling
+  // whose output operand is that same buffer.
+  func @kernel_fusable_pooling() attributes {vkspv.num_workgroups_fn = @kernel_fusable_pooling__num_workgroups__} {
+    %zero = constant 0.000000e+00 : f32
+    %buf0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
+    %buf1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<?x?xf32>
+    %out = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
+    linalg.fill(%out, %zero) : memref<?x?xf32>, f32
+    linalg.pooling_sum(%buf1, %buf0, %out) {dilations = [1, 1], strides = [1, 1]} : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+    return
+  }
+  func @kernel_fusable_pooling__num_workgroups__(
+      !shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+      !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+  attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+
+// -----
+
 // CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1"]}
 module {
   // CHECK: func @kernel_dispatch_1()
+  // CHECK:   %[[ZERO:.+]] = constant
+  // CHECK:   %[[DIM:.+]] = hal.interface.load.constant
+  // CHECK:   %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
+  // CHECK:   %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+  // CHECK:   %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
+  // CHECK:   linalg.fill(%[[TS]], %[[ZERO]])
+  // CHECK:   return
+
+  // CHECK: func @kernel_dispatch_0()
   // CHECK:   %[[DIM:.+]] = hal.interface.load.constant
   // CHECK:   %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
   // CHECK:   %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
@@ -14,15 +123,6 @@
   // CHECK:   linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
   // CHECK:   return
 
-  // CHECK: func @kernel_dispatch_0()
-  // CHECK:   %[[ZERO:.+]] = constant
-  // CHECK:   %[[DIM:.+]] = hal.interface.load.constant
-  // CHECK:   %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
-  // CHECK:   %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
-  // CHECK:   %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
-  // CHECK:   linalg.fill(%[[TS]], %[[ZERO]])
-  // CHECK:   return
-
   func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
     %cst = constant 0.000000e+00 : f32
     %dim = hal.interface.load.constant offset = 0 : index
@@ -33,8 +133,8 @@
     %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
     %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
     %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
-    linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
     linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+    linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
     return
   }
   func @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index 62d9433..f2b0b27 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -349,8 +349,9 @@
           dispatchOp, dispatchOp.executable()));
 
   // TODO(benvanik): support multiple interfaces. We'd probably want to
-  // store each executable+interface as a variable.
-  auto interfaceOp = executableOp.getInterfaceOp();
+  // store each executable+interface as a variable, or follow interface
+  // references stored on entry points.
+  auto interfaceOp = executableOp.getFirstInterfaceOp();
   auto executableLayout =
       rewriter.createOrFold<IREE::HAL::ExecutableLayoutLookupOp>(
           dispatchOp.getLoc(),
@@ -415,9 +416,9 @@
   dispatchState.results = resultAdaptors;
 
   // Ask each target backend to record their dispatch logic.
-  IREE::HAL::DeviceSwitchBuilder switchBuilder(dispatchOp.getLoc(),
-                                               /*resultTypes=*/TypeRange{},
-                                               device, rewriter);
+  IREE::HAL::DeviceSwitchRewriter switchRewriter(dispatchOp.getLoc(),
+                                                 /*resultTypes=*/TypeRange{},
+                                                 device, rewriter);
   for (auto targetOp :
        executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
     for (auto &targetBackend : IREE::HAL::matchTargetBackends(
@@ -432,15 +433,15 @@
       // sequence them together during the call to |recordDispatch| below.
       dispatchState.entryPointOp = *entryPointOps.begin();
 
-      if (failed(targetBackend->recordDispatch(dispatchOp.getLoc(),
-                                               dispatchState, switchBuilder))) {
+      if (failed(targetBackend->recordDispatch(
+              dispatchOp.getLoc(), dispatchState, switchRewriter))) {
         return dispatchOp.emitError()
                << "unable to record dispatch for target backend "
                << targetBackend->name();
       }
     }
   }
-  switchBuilder.build();
+  switchRewriter.build();
 
   // Full barriers for now as we aren't scheduling things.
   // TODO(benvanik): don't add at the end of the command buffer (we could
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 7550964..bc5e21f 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1134,9 +1134,10 @@
 // hal.executable
 //===----------------------------------------------------------------------===//
 
-InterfaceOp ExecutableOp::getInterfaceOp() {
+InterfaceOp ExecutableOp::getFirstInterfaceOp() {
   auto interfaceOps = llvm::to_vector<1>(getBlock().getOps<InterfaceOp>());
-  assert(interfaceOps.size() == 1 && "executable must have one interface");
+  assert(!interfaceOps.empty() &&
+         "executable must have at least one interface");
   return interfaceOps.front();
 }
 
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 6234e49..1541a46 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1601,7 +1601,7 @@
 
   let extraClassDeclaration = [{
     Block& getBlock() { return body().front(); }
-    IREE::HAL::InterfaceOp getInterfaceOp();
+    IREE::HAL::InterfaceOp getFirstInterfaceOp();
   }];
 
   let verifier = [{ return verifyExecutableOp(*this); }];
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
index 2b18457..25cea52 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -125,7 +125,7 @@
 
 LogicalResult SPIRVTargetBackend::recordDispatch(
     Location loc, DispatchState dispatchState,
-    DeviceSwitchBuilder &switchBuilder) {
+    DeviceSwitchRewriter &switchRewriter) {
   // Multiple entry points might be generated for a single dispatch function.
   // Under such circumstances, we will have a special attribute indicating the
   // schedule of the split entry points. Try to see if we can find such
@@ -177,7 +177,7 @@
     }
   }
 
-  auto *region = switchBuilder.addConditionRegion(
+  auto *region = switchRewriter.addConditionRegion(
       IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
       {
           dispatchState.workload,
@@ -185,7 +185,7 @@
       });
 
   auto &entryBlock = region->front();
-  ConversionPatternRewriter &rewriter = switchBuilder.getRewriter();
+  ConversionPatternRewriter &rewriter = switchRewriter.getRewriter();
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToEnd(&entryBlock);
   auto commandBuffer = entryBlock.getArgument(1);
@@ -214,7 +214,7 @@
              << " that computes the number of workgroups to use";
     }
     workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
-        loc, numWorkgroupsFn, dispatchState.executableOp.getInterfaceOp(),
+        loc, numWorkgroupsFn, dispatchState.executableOp.getFirstInterfaceOp(),
         dispatchState.operands, dispatchState.results, rewriter);
 
     if (llvm::any_of(workgroupCount,
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
index d24409b..83950b0 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h
@@ -40,7 +40,7 @@
                                     OpPassManager &passManager) override;
 
   LogicalResult recordDispatch(Location loc, DispatchState dispatchState,
-                               DeviceSwitchBuilder &switchBuilder) override;
+                               DeviceSwitchRewriter &switchRewriter) override;
 
   // Finds the spv.ExecutionMode operation to get the workgroup size from.
   std::array<Value, 3> calculateDispatchWorkgroupSize(
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index b126991..882c7aa 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -133,8 +133,8 @@
 
 LogicalResult TargetBackend::recordDispatch(
     Location loc, DispatchState dispatchState,
-    DeviceSwitchBuilder &switchBuilder) {
-  auto *region = switchBuilder.addConditionRegion(
+    DeviceSwitchRewriter &switchRewriter) {
+  auto *region = switchRewriter.addConditionRegion(
       IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
       {
           dispatchState.workload,
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 2f0a8c5..1cd4b4f 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -237,7 +237,7 @@
   // such as by inserting an  `hal.command_buffer.execution_barrier`.
   virtual LogicalResult recordDispatch(Location loc,
                                        DispatchState dispatchState,
-                                       DeviceSwitchBuilder &switchBuilder);
+                                       DeviceSwitchRewriter &switchRewriter);
 
   // Inserts passes used to translate the `hal.executable.target` op contents.
   // The pass manager will be nested on `hal.executable` such that the pipeline
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index ae429f5..752a6fd 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -26,9 +26,11 @@
 #include "iree/schemas/vmla_executable_def_generated.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -37,6 +39,29 @@
 namespace IREE {
 namespace HAL {
 
+namespace {
+
+// Returns true if |lhs| and |rhs| declare pairwise-equivalent bindings.
+bool AreInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
+                             IREE::HAL::InterfaceOp rhs) {
+  auto lhsRange = lhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
+  auto rhsRange = rhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
+  auto lhsPos = lhsRange.begin();
+  auto rhsPos = rhsRange.begin();
+  const auto lhsStop = lhsRange.end();
+  const auto rhsStop = rhsRange.end();
+  // Walk both binding lists in lockstep; bindings are assumed to be in order.
+  while (lhsPos != lhsStop && rhsPos != rhsStop) {
+    if (!OperationEquivalence::isEquivalentTo(*lhsPos, *rhsPos)) return false;
+    ++lhsPos;
+    ++rhsPos;
+  }
+  // A leftover binding on either side means the binding counts differ.
+  return lhsPos == lhsStop && rhsPos == rhsStop;
+}
+
+}  // namespace
+
 VMLATargetOptions getVMLATargetOptionsFromFlags() {
   VMLATargetOptions targetOptions;
   // TODO(benvanik): flags.
@@ -66,6 +91,169 @@
         passManager, IREE::VM::getTargetOptionsFromFlags());
   }
 
+  // Merges every target matching this backend into one linked executable.
+  LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override {
+    // --- Linking overview ---
+    //
+    // We start with a `module` containing multiple `hal.executable`s, each with
+    // potentially multiple `hal.executable.target`s. We want to move all
+    // compatible VMLA functions into a new "linked" executable, de-duping
+    // symbols, and updating references as we go.
+    //
+    // Sample IR after:
+    //   hal.executable @linked_vmla {
+    //     hal.interface @legacy_io_0 { ... }
+    //     hal.interface @legacy_io_1 { ... }
+    //     hal.executable.target @vmla, filter="vmla" {
+    //       hal.executable.entry_point @main_dispatch_0 attributes { ... }
+    //       hal.executable.entry_point @main_dispatch_1 attributes { ... }
+    //       hal.executable.entry_point @main_dispatch_2 attributes { ... }
+    //       module {
+    //         vm.module @linked_module {
+    //           vm.func @main_0(...) { ... }
+    //           vm.func @main_1(...) { ... }
+    //           vm.func @main_2(...) { ... }
+    //         }
+    //       }
+    //     }
+    //   }
+    //   hal.executable @main_dispatch_0 {
+    //     hal.interface @legacy_io { ... }
+    //     hal.executable.target @other, filter="other" {
+    //       hal.executable.entry_point @main_dispatch_0 attributes { ... }
+    //       module { ... }
+    //     }
+    //   }
+
+    OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+    auto executableOps = moduleOp.getOps<IREE::HAL::ExecutableOp>();
+
+    // Create our new "linked" hal.executable.
+    auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
+        moduleOp.getLoc(), "linked_vmla");
+    SymbolTable::setSymbolVisibility(linkedExecutableOp,
+                                     SymbolTable::Visibility::Private);
+    // Add our VMLA hal.executable.target with an empty module.
+    builder.setInsertionPointToStart(linkedExecutableOp.getBody());
+    auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
+        moduleOp.getLoc(), name(), filter_pattern());
+    builder.setInsertionPoint(&linkedTargetOp.getBlock().back());
+    auto linkedModuleOp = builder.create<ModuleOp>(moduleOp.getLoc());
+    // Add an empty vm.module to that module.
+    builder.setInsertionPointToStart(linkedModuleOp.getBody());
+    auto linkedVmModuleOp =
+        builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");
+
+    int executablesLinked = 0;
+    llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
+    int nextEntryPointOrdinal = 0;
+    for (auto executableOp : executableOps) {
+      auto targetOps = llvm::to_vector<4>(
+          executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
+      for (auto targetOp : targetOps) {
+        // Only process targets matching our pattern.
+        if (!matchPattern(targetOp.target_backend_filter(), filter_pattern())) {
+          continue;
+        }
+
+        // Reuse an equivalent interface if one was already linked (first match).
+        auto sourceInterfaceOp = executableOp.getFirstInterfaceOp();
+        IREE::HAL::InterfaceOp interfaceOpForExecutable;
+        for (auto interfaceOp : interfaceOps) {
+          if (AreInterfacesEquivalent(interfaceOp, sourceInterfaceOp)) {
+            interfaceOpForExecutable = interfaceOp;
+            break;
+          }
+        }
+        // No equivalent found: clone this one under a unique legacy_io_N name.
+        if (!interfaceOpForExecutable) {
+          builder.setInsertionPoint(linkedTargetOp);
+          interfaceOpForExecutable = cast<IREE::HAL::InterfaceOp>(
+              builder.clone(*sourceInterfaceOp));
+          interfaceOpForExecutable.setName(
+              llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
+          interfaceOps.push_back(interfaceOpForExecutable);
+        }
+
+        // Clone entry point ops, remapping ordinals and updating symbol refs.
+        builder.setInsertionPoint(linkedModuleOp);
+        for (auto entryPointOp :
+             targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
+          auto newEntryPointOp =
+              builder.create<IREE::HAL::ExecutableEntryPointOp>(
+                  entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
+                  builder.getI32IntegerAttr(nextEntryPointOrdinal++),
+                  builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
+                  entryPointOp.signatureAttr());
+
+          // Update references to @executable::@target::@entry symbols.
+          // SymbolTable::replaceAllSymbolUses only looks at root symbols,
+          // which we can't blindly replace (other targets will map to other
+          // linked executables).
+          auto executableUses =
+              SymbolTable::getSymbolUses(executableOp, moduleOp);
+          if (!executableUses.hasValue()) continue;
+          for (auto executableUse : executableUses.getValue()) {
+            auto executableUser = executableUse.getUser();
+            // Only process symbols for this @target::@entry.
+            auto nestedRefs =
+                executableUse.getSymbolRef().getNestedReferences();
+            if (nestedRefs.size() != 2 ||
+                nestedRefs[0].getValue() != targetOp.sym_name() ||
+                nestedRefs[1].getValue() != entryPointOp.sym_name()) {
+              continue;
+            }
+            if (auto dispatchOp =
+                    dyn_cast<IREE::HAL::CommandBufferDispatchSymbolOp>(
+                        executableUser)) {
+              // New nested reference to the linked exe/target/entry.
+              StringRef newExecutableOpSymName =
+                  linkedExecutableOp
+                      .getAttrOfType<StringAttr>(
+                          SymbolTable::getSymbolAttrName())
+                      .getValue();
+              auto newSymbolRefAttr = builder.getSymbolRefAttr(
+                  newExecutableOpSymName,
+                  {builder.getSymbolRefAttr(linkedTargetOp),
+                   builder.getSymbolRefAttr(newEntryPointOp)});
+              dispatchOp.setAttr("entry_point", newSymbolRefAttr);
+            }
+          }
+        }
+
+        // Clone vm.module ops, including their contents.
+        auto vmModuleOps =
+            targetOp.getInnerModule().getOps<IREE::VM::ModuleOp>();
+        if (vmModuleOps.empty()) {
+          return targetOp.getInnerModule().emitError()
+                 << "target's outer module does not contain a vm.module op";
+        }
+        auto vmModuleOp = *vmModuleOps.begin();
+        builder.setInsertionPoint(&linkedVmModuleOp.getBlock().back());
+        // Use a SymbolTable to guard against inserting duplicate symbols.
+        SymbolTable symbolTable(linkedVmModuleOp.getOperation());
+
+        for (auto &op : vmModuleOp.getBody()->getOperations()) {
+          if (isa<IREE::VM::ModuleTerminatorOp>(op)) continue;
+          // Skip symbols whose names are already defined in the linked module.
+          auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+          if (symbolOp && symbolTable.lookup(symbolOp.getName())) continue;
+          builder.clone(op);
+        }
+
+        // Now that we're done cloning its ops, delete the original target op.
+        targetOp.erase();
+
+        executablesLinked++;
+      }
+    }
+
+    // No target matched this backend: drop the empty linked executable.
+    if (executablesLinked == 0) linkedExecutableOp.erase();
+
+    return success();
+  }
+
   LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
                                     OpBuilder &executableBuilder) override {
     // Serialize the VM module to bytes.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
index 4d5c661..f9be39e 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false},canonicalize' -iree-hal-target-backends=vmla %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false link-executables=false},canonicalize' -iree-hal-target-backends=vmla %s | IreeFileCheck %s
 
 // CHECK-LABEL: @i1_op_usage(%arg0: !hal.buffer) -> !hal.buffer
 func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> {
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir
new file mode 100644
index 0000000..052bd7e
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir
@@ -0,0 +1,204 @@
+// RUN: iree-opt -split-input-file -iree-hal-link-executables -iree-hal-target-backends=vmla %s | IreeFileCheck %s
+
+module {
+  hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
+    hal.interface @legacy_io {
+      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+    }
+    hal.executable.target @vmla, filter="vmla" {
+      hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+      module {
+        vm.module @module {
+          vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+            vm.return
+          }
+          vm.export @dispatch_0
+        }
+      }
+    }
+  }
+  hal.executable @dispatch_1 attributes {sym_visibility = "private"} {
+    hal.interface @legacy_io {
+      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+    }
+    hal.executable.target @vmla, filter="vmla" {
+      hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+      module {
+        vm.module @module {
+          vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+            vm.return
+          }
+          vm.export @dispatch_1
+        }
+      }
+    }
+  }
+  hal.executable @dispatch_2 attributes {sym_visibility = "private"} {
+    hal.interface @legacy_io {
+      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+    }
+    hal.executable.target @vmla, filter="vmla" {
+      hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+      module {
+        vm.module @module {
+          vm.func @dispatch_2(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) {
+            vm.return
+          }
+          vm.export @dispatch_2
+        }
+      }
+    }
+  }
+  func @main() -> () {
+    %dev = hal.ex.shared_device : !hal.device
+    %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer
+    %c1 = constant 1 : index
+    hal.command_buffer.dispatch.symbol %cmd, @dispatch_0::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+    hal.command_buffer.dispatch.symbol %cmd, @dispatch_1::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1]
+    hal.command_buffer.dispatch.symbol %cmd, @dispatch_2::@vmla::@dispatch_2, workgroup_xyz = [%c1, %c1, %c1]
+    return
+  }
+}
+
+// All executables (including their interfaces and entry points) should be linked together into @linked_vmla
+// CHECK-NOT: hal.executable @dispatch_0
+// CHECK-NOT: hal.executable @dispatch_1
+// CHECK-NOT: hal.executable @dispatch_2
+// CHECK:       hal.executable @linked_vmla attributes {sym_visibility = "private"} {
+// CHECK-NEXT:    hal.interface @legacy_io_0 {
+// CHECK-NEXT:      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT:    }
+// CHECK-NEXT:    hal.interface @legacy_io_1 {
+// CHECK-NEXT:      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT:    }
+// CHECK-NEXT:    hal.executable.target @vmla, filter="vmla" {
+// CHECK-NEXT:      hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT:      hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io_0, ordinal = 1 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT:      hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io_1, ordinal = 2 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT:      module {
+// CHECK-NEXT:        vm.module @linked_module {
+// CHECK-NEXT:          vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+// CHECK-NEXT:            vm.return
+// CHECK-NEXT:          }
+// CHECK-NEXT:          vm.export @dispatch_0
+// CHECK-NEXT:          vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+// CHECK-NEXT:            vm.return
+// CHECK-NEXT:          }
+// CHECK-NEXT:          vm.export @dispatch_1
+// CHECK-NEXT:          vm.func @dispatch_2(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) {
+// CHECK-NEXT:            vm.return
+// CHECK-NEXT:          }
+// CHECK-NEXT:          vm.export @dispatch_2
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+//
+// CHECK:       func @main() {
+// CHECK:         hal.command_buffer.dispatch.symbol %cmd, @linked_vmla::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT:    hal.command_buffer.dispatch.symbol %cmd, @linked_vmla::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT:    hal.command_buffer.dispatch.symbol %cmd, @linked_vmla::@vmla::@dispatch_2, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// -----
+
+module {
+  hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
+    hal.interface @legacy_io {
+      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+    }
+    hal.executable.target @vmla, filter="vmla" {
+      hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+      module {
+        vm.module @module {
+          vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+            vm.return
+          }
+          vm.export @dispatch_0
+        }
+      }
+    }
+    hal.executable.target @othertarget, filter="othertarget" {
+      module {
+      }
+    }
+  }
+  func @main() -> () {
+    %dev = hal.ex.shared_device : !hal.device
+    %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer
+    hal.device.switch(%dev : !hal.device)
+    #hal.device.match.id<"vmla">(%arg1 = %cmd : !hal.command_buffer) {
+      %c1 = constant 1 : index
+      hal.command_buffer.dispatch.symbol %arg1, @linked_vmla::@vmla::@main_ex_dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+      hal.return
+    },
+    #hal.device.match.id<"othertarget">(%arg1 = %cmd : !hal.command_buffer) {
+      %c1 = constant 1 : index
+      hal.command_buffer.dispatch.symbol %arg1, @dispatch_0::@otherdispatch::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+      hal.return
+    }
+    return
+  }
+}
+
+// VMLA target should be pulled out from @dispatch_0
+// CHECK:       hal.executable @linked_vmla attributes {sym_visibility = "private"} {
+// CHECK-NEXT:    hal.interface @legacy_io_0 {
+// CHECK-NEXT:      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT:    }
+// CHECK-NEXT:    hal.executable.target @vmla, filter="vmla" {
+// CHECK-NEXT:      hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
+// CHECK-NEXT:      module {
+// CHECK-NEXT:        vm.module @linked_module {
+// CHECK-NEXT:          vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
+// CHECK-NEXT:            vm.return
+// CHECK-NEXT:          }
+// CHECK-NEXT:          vm.export @dispatch_0
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+//
+// @dispatch_0 should remain, with just @othertarget
+// CHECK:       hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
+// CHECK-NEXT:    hal.interface @legacy_io {
+// CHECK-NEXT:      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+// CHECK-NEXT:      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+// CHECK-NEXT:    }
+// CHECK-NEXT:    hal.executable.target @othertarget, filter="othertarget" {
+// CHECK-NEXT:      module {
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+//
+// CHECK:       func @main() {
+// CHECK:         hal.device.switch(%dev : !hal.device)
+// CHECK-NEXT:    #hal.device.match.id<"vmla">(%arg0 = %cmd : !hal.command_buffer) {
+// CHECK-NEXT:      %c1 = constant 1 : index
+// CHECK-NEXT:      hal.command_buffer.dispatch.symbol %arg0, @linked_vmla::@vmla::@main_ex_dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT:      hal.return
+// CHECK-NEXT:    },
+// CHECK-NEXT:    #hal.device.match.id<"othertarget">(%arg0 = %cmd : !hal.command_buffer) {
+// CHECK-NEXT:      %c1 = constant 1 : index
+// CHECK-NEXT:      hal.command_buffer.dispatch.symbol %arg0, @dispatch_0::@otherdispatch::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK-NEXT:      hal.return
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index 48c889a..939f06a 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -12,15 +12,15 @@
   }
 }
 
-// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
-//  CHECK-NEXT:   hal.interface @legacy_io {
+// CHECK-LABEL: hal.executable @linked_vmla
+//  CHECK-NEXT:   hal.interface @legacy_io_0 {
 //  CHECK-NEXT:     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
 //  CHECK-NEXT:     hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   hal.executable.target @vmla, filter="vmla" {
-//  CHECK-NEXT:     hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>}
+//  CHECK-NEXT:     hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>}
 //  CHECK-NEXT:     module {
-//  CHECK-NEXT:       vm.module @module {
+//  CHECK-NEXT:       vm.module @linked_module {
 //  CHECK-NEXT:         vm.func @simpleMath_rgn_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
 //   CHECK-DAG:           %zero = vm.const.i32.zero : i32
 //   CHECK-DAG:           %c16 = vm.const.i32 16 : i32
@@ -55,15 +55,15 @@
   }
 }
 
-// CHECK-LABEL: hal.executable @shaped_dispatch
-//  CHECK-NEXT:   hal.interface @legacy_io attributes {push_constants = 1 : i32} {
+// CHECK-LABEL: hal.executable @linked_vmla
+//  CHECK-NEXT:   hal.interface @legacy_io_0 attributes {push_constants = 1 : i32} {
 //  CHECK-NEXT:     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
 //  CHECK-NEXT:     hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   hal.executable.target @vmla, filter="vmla" {
-//  CHECK-NEXT:     hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>}
+//  CHECK-NEXT:     hal.executable.entry_point @entry attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>}
 //  CHECK-NEXT:     module {
-//  CHECK-NEXT:       vm.module @module {
+//  CHECK-NEXT:       vm.module @linked_module {
 //  CHECK-NEXT:         vm.func @entry(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
 //   CHECK-DAG:           %zero = vm.const.i32.zero : i32
 //   CHECK-DAG:           %c16 = vm.const.i32 16 : i32
@@ -97,15 +97,15 @@
   }
 }
 
-// CHECK-LABEL: hal.executable @reduction_ex_dispatch_0
-//  CHECK-NEXT:   hal.interface @legacy_io {
+// CHECK-LABEL: hal.executable @linked_vmla
+//  CHECK-NEXT:   hal.interface @legacy_io_0 {
 //  CHECK-NEXT:     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
 //  CHECK-NEXT:     hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   hal.executable.target @vmla, filter="vmla" {
-//  CHECK-NEXT:     hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>}
+//  CHECK-NEXT:     hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>}
 //  CHECK-NEXT:     module {
-//  CHECK-NEXT:       vm.module @module {
+//  CHECK-NEXT:       vm.module @linked_module {
 //  CHECK-NEXT:         vm.rodata @reduction_ex_dispatch_0_const_0 dense<0.000000e+00> : tensor<f32>
 //  CHECK-NEXT:         vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
 //  CHECK-NEXT:           %zero = vm.const.i32.zero : i32
diff --git a/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp b/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp
index 3f2891e..e2e5a64 100644
--- a/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/LinkExecutables.cpp
@@ -47,6 +47,17 @@
         return signalPassFailure();
       }
     }
+
+    // Backends may move target ops from executables into linked executables.
+    // If an executable ends up with no targets, remove it.
+    auto executableOps =
+        llvm::to_vector<4>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+    for (auto executableOp : executableOps) {
+      auto targetOps = executableOp.getOps<IREE::HAL::ExecutableTargetOp>();
+      if (targetOps.empty()) {
+        executableOp.erase();
+      }
+    }
   }
 
  private:
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 94b567d..d40113e 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -16,7 +16,9 @@
 
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
 #include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -55,10 +57,12 @@
     // other nice thing is that we get ordering similar to the executable
     // variables above.
     for (auto executableOp : executableOps) {
-      auto interfaceOp = executableOp.getInterfaceOp();
-      defineExecutableLayoutOp(interfaceOp.getLoc(),
-                               interfaceOp.getExecutableSetLayoutsAttr(),
-                               interfaceOp.push_constantsAttr());
+      for (auto interfaceOp :
+           executableOp.getBlock().getOps<IREE::HAL::InterfaceOp>()) {
+        defineExecutableLayoutOp(interfaceOp.getLoc(),
+                                 interfaceOp.getExecutableSetLayoutsAttr(),
+                                 interfaceOp.push_constantsAttr());
+      }
     }
 
     // Generate cached resource singletons and replace lookup ops with direct
@@ -221,33 +225,71 @@
             loc, executableCacheType, deviceValue,
             blockBuilder.getStringAttr("default"));
 
-    // TODO(benvanik): use targetOptions_ to determine these flags.
-    auto cachingMode = ExecutableCachingModeBitfield::AliasProvidedData |
-                       ExecutableCachingModeBitfield::AllowPersistentCaching |
-                       ExecutableCachingModeBitfield::AllowOptimization;
-    for (auto executableOp : executableOps) {
-      auto executableIt = executableCache_.find(executableOp.sym_name());
-      assert(executableIt != executableCache_.end() &&
-             "executable must have been cached");
-      auto executableVariableOp = executableIt->second;
+    // Create a switch statement with a case for each backend.
+    // Each case should then cache only executables which contain a matching
+    // ExecutableTargetOp.
+    // Afterwards, we could inline and de-dup across switch cases.
+    DeviceSwitchBuilder switchBuilder(loc, /*resultTypes=*/TypeRange{},
+                                      deviceValue, blockBuilder);
+    auto targetBackends = matchTargetBackends(targetOptions_.targets);
+    for (auto &targetBackend : targetBackends) {
+      auto *region = switchBuilder.addConditionRegion(
+          IREE::HAL::DeviceMatchIDAttr::get(targetBackend->filter_pattern(),
+                                            blockBuilder.getContext()),
+          {
+              executableCacheValue,
+          });
+      auto &entryBlock = region->front();
+      auto executableCache = entryBlock.getArgument(0);
+      auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock);
 
-      // TODO(benvanik): support multiple interfaces. We'd probably want to
-      // store each executable+interface as a variable.
-      auto interfaceOp = executableOp.getInterfaceOp();
+      // TODO(benvanik): use targetOptions_ to determine these flags.
+      auto cachingMode = ExecutableCachingModeBitfield::AliasProvidedData |
+                         ExecutableCachingModeBitfield::AllowPersistentCaching |
+                         ExecutableCachingModeBitfield::AllowOptimization;
+      for (auto executableOp : executableOps) {
+        // Skip executables with no matching target ops.
+        auto executableTargetOps =
+            executableOp.getOps<IREE::HAL::ExecutableTargetOp>();
+        bool hasMatchingTarget = false;
+        for (auto executableTargetOp : executableTargetOps) {
+          if (TargetBackend::matchPattern(
+                  executableTargetOp.target_backend_filter(),
+                  targetBackend->filter_pattern())) {
+            hasMatchingTarget = true;
+          }
+        }
+        if (!hasMatchingTarget) continue;
 
-      auto executableLayoutVariableOp = defineExecutableLayoutOp(
-          executableOp.getLoc(), interfaceOp.getExecutableSetLayoutsAttr(),
-          interfaceOp.push_constantsAttr());
-      auto executableLayoutValue = blockBuilder.createOrFold<VariableLoadOp>(
-          loc, ExecutableLayoutType::get(loc.getContext()),
-          executableLayoutVariableOp.sym_name());
-      auto executableValue =
-          blockBuilder.createOrFold<ExecutableCachePrepareOp>(
-              loc, ExecutableType::get(loc.getContext()), executableCacheValue,
-              executableLayoutValue, cachingMode, executableOp.sym_name());
-      blockBuilder.create<VariableStoreOp>(loc, executableValue,
-                                           executableVariableOp.sym_name());
+        auto executableIt = executableCache_.find(executableOp.sym_name());
+        assert(executableIt != executableCache_.end() &&
+               "executable must have been cached");
+        auto executableVariableOp = executableIt->second;
+
+        // TODO(benvanik): support multiple interfaces. We'd probably want to
+        // store each executable+interface as a variable.
+        //
+        // This is *only* safe now because any backends that support multiple
+        // interfaces during compilation do *not* use layouts during executable
+        // cache preparation.
+        auto interfaceOp = executableOp.getFirstInterfaceOp();
+
+        auto executableLayoutVariableOp = defineExecutableLayoutOp(
+            executableOp.getLoc(), interfaceOp.getExecutableSetLayoutsAttr(),
+            interfaceOp.push_constantsAttr());
+        auto executableLayoutValue = caseBuilder.createOrFold<VariableLoadOp>(
+            loc, ExecutableLayoutType::get(loc.getContext()),
+            executableLayoutVariableOp.sym_name());
+        auto executableValue =
+            caseBuilder.createOrFold<ExecutableCachePrepareOp>(
+                loc, ExecutableType::get(loc.getContext()), executableCache,
+                executableLayoutValue, cachingMode, executableOp.sym_name());
+        caseBuilder.create<VariableStoreOp>(loc, executableValue,
+                                            executableVariableOp.sym_name());
+      }
+      caseBuilder.create<IREE::HAL::ReturnOp>(loc);
     }
+    switchBuilder.build();
 
     blockBuilder.create<mlir::ReturnOp>(loc, executableCacheValue);
 
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 1a9440d..fc05535 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -35,6 +35,10 @@
       llvm::cl::desc("Whether to serialize hal.executable.target ops to "
                      "hal.executable.binary ops."),
       llvm::cl::init(true)};
+  Option<bool> linkExecutables{
+      *this, "link-executables",
+      llvm::cl::desc("Whether to link hal.executable ops together."),
+      llvm::cl::init(true)};
 };
 
 }  // namespace
@@ -52,12 +56,6 @@
   // this pass.
   passManager.addPass(createTranslateExecutablesPass(targetOptions));
 
-  // After all executables are translated we allow the backends to link them
-  // together. For example, the LLVM AOT backend may combine all executable
-  // targets for the same architecture into a single executable and link it as
-  // a shared library.
-  passManager.addPass(createLinkExecutablesPass(targetOptions));
-
   passManager.addPass(createConvertFlowToHALPass());
 
   // Phase ordering note: Before this pass, functions signatures will be based
@@ -75,6 +73,17 @@
   // been expanded to primitives.
   passManager.addPass(createPublicABIGenerationPass());
 
+  // After all executables are translated and before resolving entry point
+  // ordinals, we allow the backends to link executables together. For example,
+  // the LLVM AOT backend may combine all executable targets for the same
+  // architecture into a single executable and link it as a shared library.
+  // TODO(scotttodd): Move this to just after createTranslateExecutablesPass.
+  //   * ConvertStreamOps under ConvertFlowToHALPass assumes one entry point.
+  //     Once it handles multiple entry points, this pass can move up.
+  if (transformOptions.linkExecutables) {
+    passManager.addPass(createLinkExecutablesPass(targetOptions));
+  }
+
   // Resolve entry point ordinals from nested symbol references prior to
   // serialization. As this pass creates lookup ops it should run before
   // MaterializeResourceCachesPass.
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
index d88af3e..590df15 100644
--- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
+++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-hal-materialize-resource-caches %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-hal-materialize-resource-caches %s -iree-hal-target-backends=vmla | IreeFileCheck %s
 
 //      CHECK: hal.variable @_descriptor_set_layout_0 init(@_descriptor_set_layout_0_initializer) : !hal.descriptor_set_layout
 // CHECK-NEXT: func @_descriptor_set_layout_0_initializer() -> !hal.descriptor_set_layout attributes {sym_visibility = "private"} {
@@ -87,12 +87,14 @@
 
 // -----
 
+// TODO(scotttodd): Test without depending on a specific HAL target? Or move to HAL/Target/*/test/?
+//   - If there is no matching hal.executable.target then the executable will not be cached
 hal.executable @exe {
   hal.interface @interface {
     hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write"
   }
-  hal.executable.target @target, filter="target" {
+  hal.executable.target @vmla, filter="vmla" {
     hal.executable.entry_point @entry attributes {
       interface = @interface,
       ordinal = 0 : i32,
@@ -119,8 +121,12 @@
 //      CHECK: hal.variable @_executable_cache init(@_executable_cache_initializer) : !hal.executable_cache
 // CHECK-NEXT: func @_executable_cache_initializer
 //      CHECK: %[[CACHE:.+]] = hal.executable_cache.create %dev, identifier = "default" : !hal.executable_cache
-// CHECK-NEXT: %[[LAYOUT:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout
-// CHECK-NEXT: %[[EXE:.+]] = hal.executable_cache.prepare %[[CACHE]], layout = %[[LAYOUT]], caching_mode = "AliasProvidedData|AllowPersistentCaching|AllowOptimization", @exe : !hal.executable
+// CHECK-NEXT: hal.device.switch(%dev : !hal.device)
+// CHECK-NEXT: #hal.device.match.id<"vmla">(%[[CACHE_CAPTURE:.+]] = %executable_cache_default : !hal.executable_cache) {
+// CHECK-NEXT:   %[[LAYOUT:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout
+// CHECK-NEXT:   %[[EXE:.+]] = hal.executable_cache.prepare %[[CACHE_CAPTURE]], layout = %[[LAYOUT]], caching_mode = "AliasProvidedData|AllowPersistentCaching|AllowOptimization", @exe : !hal.executable
+// CHECK-NEXT:   hal.variable.store %[[EXE]], @_executable_exe : !hal.executable
+// CHECK-NEXT:   hal.return
 
 // CHECK-LABEL: @exeLookup
 func @exeLookup(%arg0 : !hal.device) -> !hal.executable {
diff --git a/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h b/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
index b1a2d23..15fc7b0 100644
--- a/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
+++ b/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
@@ -94,8 +94,9 @@
 // Builder for hal.device.switch ops that allows for nesting of conditions.
 //
 // Example:
-//   DeviceSwitchBuilder b0(.../*initialCondition=*/Z);
-//   b0.addRegion();   // condition: Z
+//   DeviceSwitchBuilder builder(...);
+//   auto b0 = builder.nest(Z);
+//   b0.addRegion();            // condition: Z
 //   b0.addConditionRegion(A);  // condition: Z && A
 //   auto b1 = b0.nest(B);
 //   b1.addConditionRegion(C);  // condition: Z && B && C
@@ -111,7 +112,61 @@
 class DeviceSwitchBuilder {
  public:
   DeviceSwitchBuilder(Location loc, TypeRange resultTypes, Value device,
-                      ConversionPatternRewriter &rewriter)
+                      OpBuilder builder)
+      : loc_(loc),
+        resultTypes_(resultTypes),
+        device_(device),
+        builder_(builder) {
+    // FIXME: Keep the same listener as the provided builder (dropped for now).
+    builder_.setListener(nullptr);
+  }
+
+  // Pushes a new condition onto the stack and returns a builder that must have
+  // all previously nested conditions met in order to execute any conditions.
+  DeviceSwitchCaseBuilder nest(Attribute conditionAttr) {
+    return DeviceSwitchCaseBuilder(loc_, resultTypes_, device_, conditionAttr,
+                                   caseOps_, builder_);
+  }
+
+  // Adds a new condition region that must satisfy |conditionAttr| and all
+  // parent conditions. The region will have a single entry block with the
+  // given |args|.
+  Region *addConditionRegion(Attribute conditionAttr,
+                             const SmallVector<Value, 4> &args) {
+    return nest(conditionAttr).addRegion(args);
+  }
+
+  // Constructs a single hal.device.switch from all added regions.
+  IREE::HAL::DeviceSwitchOp build() {
+    SmallVector<Attribute, 4> conditionAttrs;
+    SmallVector<SmallVector<Value, 4>, 4> conditionArgs;
+    for (auto caseOp : caseOps_) {
+      conditionAttrs.push_back(caseOp.conditions().getValue()[0]);
+      conditionArgs.push_back(caseOp.args());
+    }
+    auto switchOp = builder_.create<IREE::HAL::DeviceSwitchOp>(
+        loc_, resultTypes_, device_, conditionAttrs, conditionArgs);
+    // Move each staged case body into the final switch op, then drop the stub.
+    for (size_t i = 0; i < caseOps_.size(); ++i) {
+      switchOp.getRegion(i).takeBody(caseOps_[i].getRegion(0));
+      caseOps_[i].erase();
+    }
+    return switchOp;
+  }
+
+ private:
+  Location loc_;
+  SmallVector<Type, 4> resultTypes_;
+  Value device_;
+  SmallVector<IREE::HAL::DeviceSwitchOp, 4> caseOps_;
+  OpBuilder builder_;
+};
+
+// Rewriter-compatible version of DeviceSwitchBuilder.
+class DeviceSwitchRewriter {
+ public:
+  DeviceSwitchRewriter(Location loc, TypeRange resultTypes, Value device,
+                       ConversionPatternRewriter &rewriter)
       : loc_(loc),
         resultTypes_(resultTypes),
         device_(device),
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 9337788..9b3c2a7 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 93377888ae89560ba6d3976e2762d3d4724c4dfd
+Subproject commit 9b3c2a72e4cb3b0ae27f87064c11f728452b2af9
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 2e56481..090b691 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 2e56481abcef1dd1625fba465a5d02ee6b347842
+Subproject commit 090b691fbf7b7823c41345004d12eddaa6c86118