Merge pull request #2742 from google/benvanik-vm64-flags

diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 879a1cf..90e5f67 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -4,15 +4,15 @@
 a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
 f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-950f1bf976b332eca60267b25bf759e2ad564e0c third_party/llvm-project
+30c1633386e7cfb01c0a54b31ccf4c3a3873e71b third_party/llvm-project
 17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
-80885f899e12d55a45561ef758eea47bb340dbf1 third_party/mlir-emitc
+a3479bbf9161df8c8cac55a08205864e6f371491 third_party/mlir-emitc
 d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
 9f53ba413e6fc879236dcaa3e008915973d67a4f third_party/ruy
 a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
 f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
 57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-8a4ffe2e1ae722cff5306778df0cfca8b7f503fe third_party/tensorflow
+86efb18ca5812c76dd52c8536f336e6962b7f8ca third_party/tensorflow
 864d86e8b6d21449474db5e9313dbff90aa9c24f third_party/tracy
 9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
 909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/bindings/python/build_tools/python/generate_build.py b/bindings/python/build_tools/python/generate_build.py
index 6705cfd..a5c0bda 100644
--- a/bindings/python/build_tools/python/generate_build.py
+++ b/bindings/python/build_tools/python/generate_build.py
@@ -18,10 +18,6 @@
 # Debugging hint: Just run this with python to see what it prints.
 """Generates a bazel BUILD file for the repo."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import json
 import os
 import sys
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
index 726cabc..78a2bd4 100644
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -31,6 +31,7 @@
 #include "iree/tools/init_mlir_dialects.h"
 #include "iree/tools/init_mlir_passes.h"
 #include "iree/tools/init_targets.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/Signals.h"
@@ -185,71 +186,61 @@
 // doesn't do any path shortening, which seems to make long Python stack traces
 // a bit easier to scan.
 void PrintLocation(Location loc, raw_ostream& out) {
-  switch (loc->getKind()) {
-    case StandardAttributes::OpaqueLocation:
-      PrintLocation(loc.cast<OpaqueLoc>().getFallbackLocation(), out);
-      break;
-    case StandardAttributes::UnknownLocation:
-      out << "  [unknown location]\n";
-      break;
-    case StandardAttributes::FileLineColLocation: {
-      auto line_col_loc = loc.cast<FileLineColLoc>();
-      StringRef this_filename = line_col_loc.getFilename();
-      auto slash_pos = this_filename.find_last_of("/\\");
-      // We print both the basename and extended names with a structure like
-      // `foo.py:35:4`. Even though technically the line/col
-      // information is redundant to include in both names, having it on both
-      // makes it easier to paste the paths into an editor and jump to the exact
-      // location.
-      std::string line_col_suffix =
-          ":" + std::to_string(line_col_loc.getLine()) + ":" +
-          std::to_string(line_col_loc.getColumn());
-      bool has_basename = false;
-      StringRef basename = this_filename;
-      if (slash_pos != StringRef::npos) {
-        has_basename = true;
-        basename = this_filename.substr(slash_pos + 1);
-      }
-      out << "  at: " << basename << line_col_suffix;
-      if (has_basename) {
-        // When running through bazel, such as in our e2e test suite,
-        // the paths involved can be quite large, and will have a very long
-        // prefix before the sandboxed "runfiles" directory that the program
-        // runs in. Trim off that long prefix. By convention, the path names
-        // with this prefix dropped will correspond to the path in the source
-        // directory, which is probably what we want anyway.
-        StringRef kRunfiles(".runfiles/");
-        StringRef extended_name = this_filename;
-        auto runfiles_pos = extended_name.rfind(kRunfiles);
-        if (runfiles_pos != StringRef::npos) {
-          extended_name =
-              extended_name.drop_front(runfiles_pos + kRunfiles.size());
+  TypeSwitch<Location>(loc)
+      .Case<OpaqueLoc>(
+          [&](OpaqueLoc loc) { PrintLocation(loc.getFallbackLocation(), out); })
+      .Case<UnknownLoc>([&](UnknownLoc) { out << "  [unknown location]\n"; })
+      .Case<FileLineColLoc>([&](FileLineColLoc line_col_loc) {
+        StringRef this_filename = line_col_loc.getFilename();
+        auto slash_pos = this_filename.find_last_of("/\\");
+        // We print both the basename and extended names with a structure like
+        // `foo.py:35:4`. Even though technically the line/col
+        // information is redundant to include in both names, having it on both
+        // makes it easier to paste the paths into an editor and jump to the
+        // exact location.
+        std::string line_col_suffix =
+            ":" + std::to_string(line_col_loc.getLine()) + ":" +
+            std::to_string(line_col_loc.getColumn());
+        bool has_basename = false;
+        StringRef basename = this_filename;
+        if (slash_pos != StringRef::npos) {
+          has_basename = true;
+          basename = this_filename.substr(slash_pos + 1);
         }
-        // Print out two tabs, as basenames usually vary in length by more than
-        // one tab width.
-        out << "\t\t( " << extended_name << line_col_suffix << " )";
-      }
-      out << "\n";
-      break;
-    }
-    case StandardAttributes::NameLocation: {
-      auto nameLoc = loc.cast<NameLoc>();
-      out << "  @'" << nameLoc.getName() << "':\n";
-      auto childLoc = nameLoc.getChildLoc();
-      if (!childLoc.isa<UnknownLoc>()) {
-        out << "(...\n";
-        PrintLocation(childLoc, out);
-        out << ")\n";
-      }
-      break;
-    }
-    case StandardAttributes::CallSiteLocation: {
-      auto call_site = loc.cast<CallSiteLoc>();
-      PrintLocation(call_site.getCaller(), out);
-      PrintLocation(call_site.getCallee(), out);
-      break;
-    }
-  }
+        out << "  at: " << basename << line_col_suffix;
+        if (has_basename) {
+          // When running through bazel, such as in our e2e test suite,
+          // the paths involved can be quite large, and will have a very long
+          // prefix before the sandboxed "runfiles" directory that the program
+          // runs in. Trim off that long prefix. By convention, the path names
+          // with this prefix dropped will correspond to the path in the source
+          // directory, which is probably what we want anyway.
+          StringRef kRunfiles(".runfiles/");
+          StringRef extended_name = this_filename;
+          auto runfiles_pos = extended_name.rfind(kRunfiles);
+          if (runfiles_pos != StringRef::npos) {
+            extended_name =
+                extended_name.drop_front(runfiles_pos + kRunfiles.size());
+          }
+          // Print out two tabs, as basenames usually vary in length by more
+          // than one tab width.
+          out << "\t\t( " << extended_name << line_col_suffix << " )";
+        }
+        out << "\n";
+      })
+      .Case<NameLoc>([&](NameLoc name_loc) {
+        out << "  @'" << name_loc.getName() << "':\n";
+        auto child_loc = name_loc.getChildLoc();
+        if (!child_loc.isa<UnknownLoc>()) {
+          out << "(...\n";
+          PrintLocation(child_loc, out);
+          out << ")\n";
+        }
+      })
+      .Case<CallSiteLoc>([&](CallSiteLoc call_site) {
+        PrintLocation(call_site.getCaller(), out);
+        PrintLocation(call_site.getCallee(), out);
+      });
 }
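
For reference, a minimal sketch of the llvm::TypeSwitch pattern adopted above;
each Case<> both tests for and binds the cast subtype, replacing the old switch
over location kinds plus manual casts. The helper name here is hypothetical:

```
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Location.h"

// Hypothetical helper: prints a one-line description of `loc`.
void DescribeLocation(mlir::Location loc, llvm::raw_ostream &out) {
  llvm::TypeSwitch<mlir::Location>(loc)
      .Case<mlir::FileLineColLoc>([&](mlir::FileLineColLoc file_loc) {
        out << file_loc.getFilename() << ":" << file_loc.getLine() << "\n";
      })
      .Case<mlir::UnknownLoc>([&](mlir::UnknownLoc) { out << "<unknown>\n"; })
      .Default([&](mlir::Location) { out << "<other location kind>\n"; });
}
```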
 
 std::string DiagnosticCapture::ConsumeDiagnosticsAsString(
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
index 6d77a7f..a89363a 100644
--- a/bindings/python/pyiree/rt/BUILD
+++ b/bindings/python/pyiree/rt/BUILD
@@ -117,10 +117,6 @@
     name = "function_abi_test",
     srcs = ["function_abi_test.py"],
     python_version = "PY3",
-    # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
-    tags = [
-        "nokokoro",
-    ],
     deps = NUMPY_DEPS + [
         "//bindings/python:pathsetup",  # build_cleaner: keep
         "@absl_py//absl/testing:absltest",
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index 58a0295..7434653 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -356,7 +356,7 @@
                                   VmVariantList& f_args,
                                   VmVariantList& f_results) {
   if (f_args.size() != raw_config().inputs.size()) {
-    throw RaiseValueError("Mismatched AllocatResults() input arity");
+    throw RaiseValueError("Mismatched AllocateResults() input arity");
   }
 
   for (size_t i = 0, e = descs.size(); i < e; ++i) {
diff --git a/bindings/python/pyiree/rt/function_abi_test.py b/bindings/python/pyiree/rt/function_abi_test.py
index cb8c804..6e23f10 100644
--- a/bindings/python/pyiree/rt/function_abi_test.py
+++ b/bindings/python/pyiree/rt/function_abi_test.py
@@ -16,6 +16,7 @@
 
 import re
 
+from absl import logging
 from absl.testing import absltest
 
 import numpy as np
@@ -40,7 +41,7 @@
 
   def test_baseclass(self):
     htf = rt.HostTypeFactory()
-    print(htf)
+    logging.info("HostTypeFactory: %s", htf)
 
 
 class FunctionAbiTest(absltest.TestCase):
@@ -50,12 +51,12 @@
     super().setUpClass()
     driver_names = rt.HalDriver.query()
     for driver_name in driver_names:
-      print("Try create driver:", driver_name)
+      logging.info("Try to create driver: %s", driver_name)
       try:
         cls.driver = rt.HalDriver.create(driver_name)
         cls.device = cls.driver.create_default_device()
       except Exception:
-        print("Could not create driver:", driver_name)
+        logging.error("Could not create driver: %s", driver_name)
       else:
         break
 
@@ -66,7 +67,7 @@
   def test_static_arg_success(self):
     fabi = rt.FunctionAbi(self.device, self.htf,
                           ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
-    print(fabi)
+    logging.info("fabi: %s", fabi)
     self.assertEqual(
         "<FunctionAbi (Buffer<float32[10x128x64]>) -> "
         "(Buffer<sint32[32x8x64]>)>", repr(fabi))
@@ -75,7 +76,7 @@
 
     arg = np.zeros((10, 128, 64), dtype=np.float32)
     packed = fabi.raw_pack_inputs([arg])
-    print(packed)
+    logging.info("packed: %s", packed)
     self.assertEqual("<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>",
                      repr(packed))
 
@@ -85,7 +86,7 @@
     arg = np.zeros((10, 128, 64), dtype=np.float32)
     f_args = fabi.raw_pack_inputs([arg])
     f_results = fabi.allocate_results(f_args)
-    print(f_results)
+    logging.info("f_results: %s", f_results)
     self.assertEqual("<VmVariantList(1): [HalBufferView(32x8x64:0x1000020)]>",
                      repr(f_results))
     py_result, = fabi.raw_unpack_results(f_results)
@@ -98,13 +99,13 @@
     arg = np.zeros((10, 128, 64), dtype=np.float32)
     f_args = fabi.raw_pack_inputs([arg])
     f_results = fabi.allocate_results(f_args, static_alloc=False)
-    print(f_results)
+    logging.info("f_results: %s", f_results)
     self.assertEqual("<VmVariantList(0): []>", repr(f_results))
 
   def test_dynamic_arg_success(self):
     fabi = rt.FunctionAbi(self.device, self.htf,
                           ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
-    print(fabi)
+    logging.info("fabi: %s", fabi)
     self.assertEqual(
         "<FunctionAbi (Buffer<float32[?x128x64]>) -> "
         "(Buffer<sint32[?x8x64]>)>", repr(fabi))
@@ -113,14 +114,14 @@
 
     arg = np.zeros((10, 128, 64), dtype=np.float32)
     packed = fabi.raw_pack_inputs([arg])
-    print(packed)
+    logging.info("packed: %s", packed)
     self.assertEqual("<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>",
                      repr(packed))
 
   def test_static_arg_rank_mismatch(self):
     fabi = rt.FunctionAbi(self.device, self.htf,
                           ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
-    print(fabi)
+    logging.info("fabi: %s", fabi)
     arg = np.zeros((10,), dtype=np.float32)
     with self.assertRaisesRegex(
         ValueError,
@@ -130,7 +131,7 @@
   def test_static_arg_eltsize_mismatch(self):
     fabi = rt.FunctionAbi(self.device, self.htf,
                           ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
-    print(fabi)
+    logging.info("fabi: %s", fabi)
     arg = np.zeros((10, 128, 64), dtype=np.float64)
     with self.assertRaisesRegex(
         ValueError,
@@ -140,7 +141,7 @@
   def test_static_arg_dtype_mismatch(self):
     fabi = rt.FunctionAbi(self.device, self.htf,
                           ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
-    print(fabi)
+    logging.info("fabi: %s", fabi)
     arg = np.zeros((10, 128, 64), dtype=np.int32)
     with self.assertRaisesRegex(
         ValueError,
@@ -150,7 +151,7 @@
   def test_static_arg_static_dim_mismatch(self):
     fabi = rt.FunctionAbi(self.device, self.htf,
                           ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
-    print(fabi)
+    logging.info("fabi: %s", fabi)
     arg = np.zeros((10, 32, 64), dtype=np.float32)
     with self.assertRaisesRegex(
         ValueError,
diff --git a/bindings/python/pyiree/rt/hal_test.py b/bindings/python/pyiree/rt/hal_test.py
index b7ab59b..a7f1a12 100644
--- a/bindings/python/pyiree/rt/hal_test.py
+++ b/bindings/python/pyiree/rt/hal_test.py
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from absl import logging
 from absl.testing import absltest
 
 import numpy as np
@@ -25,8 +22,8 @@
 class HalTest(absltest.TestCase):
 
   def testEnums(self):
-    print("MemoryType =", rt.MemoryType)
-    print("HOST_VISIBLE =", int(rt.MemoryType.HOST_VISIBLE))
+    logging.info("MemoryType: %s", rt.MemoryType)
+    logging.info("HOST_VISIBLE: %s", int(rt.MemoryType.HOST_VISIBLE))
 
   def testAllocateHeap(self):
     b = rt.HalBuffer.allocate_heap(
diff --git a/bindings/python/pyiree/rt/system_api.py b/bindings/python/pyiree/rt/system_api.py
index aaea01f..327883bc 100644
--- a/bindings/python/pyiree/rt/system_api.py
+++ b/bindings/python/pyiree/rt/system_api.py
@@ -138,7 +138,7 @@
   def __call__(self, *args):
     # NOTE: This is just doing sync dispatch right now. In the future,
     # this should default to async and potentially have some kind of policy
-    # flag that can allow it to be overriden.
+    # flag that can allow it to be overridden.
     inputs = self._abi.raw_pack_inputs(args)
     results = self._abi.allocate_results(inputs, static_alloc=False)
     self._context._vm_context.invoke(self._vm_function, inputs, results)
@@ -269,8 +269,7 @@
 def load_modules(*modules, config: Optional[Config] = None):
   """Loads modules into a new or shared context and returns them."""
   context = SystemContext(modules=modules, config=config)
-  context_modules = context.modules
-  bound_modules = [context_modules[m.name] for m in modules]
+  bound_modules = [context.modules[m.name] for m in modules]
   return bound_modules
 
 
diff --git a/bindings/python/pyiree/rt/system_api_test.py b/bindings/python/pyiree/rt/system_api_test.py
index 9d670ce..ca47439 100644
--- a/bindings/python/pyiree/rt/system_api_test.py
+++ b/bindings/python/pyiree/rt/system_api_test.py
@@ -17,6 +17,7 @@
 
 import re
 
+from absl import logging
 from absl.testing import absltest
 import numpy as np
 from pyiree import compiler
@@ -68,7 +69,7 @@
     self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
     f = ctx.modules.arithmetic["simple_mul"]
     f_repr = repr(f)
-    print(f_repr)
+    logging.info("f_repr: %s", f_repr)
     self.assertRegex(
         f_repr,
         re.escape(
diff --git a/bindings/python/pyiree/rt/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
index 5a3c1ec..da05a6c 100644
--- a/bindings/python/pyiree/rt/vm_test.py
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -15,6 +15,7 @@
 
 # pylint: disable=unused-variable
 
+from absl import logging
 from absl.testing import absltest
 import numpy as np
 from pyiree import compiler
@@ -70,7 +71,7 @@
   def setUpClass(cls):
     super().setUpClass()
     driver_names = rt.HalDriver.query()
-    print("DRIVER_NAMES =", driver_names)
+    logging.info("driver_names: %s", driver_names)
     cls.driver = rt.HalDriver.create("vmla")
     cls.device = cls.driver.create_default_device()
     cls.hal_module = rt.create_hal_module(cls.device)
@@ -78,7 +79,7 @@
 
   def test_variant_list(self):
     l = rt.VmVariantList(5)
-    print(l)
+    logging.info("variant_list: %s", l)
     self.assertEqual(l.size, 0)
 
   def test_context_id(self):
@@ -102,19 +103,19 @@
 
   def test_static_module_context(self):
     m = create_simple_static_mul_module()
-    print(m)
+    logging.info("module: %s", m)
     instance = rt.VmInstance()
-    print(instance)
+    logging.info("instance: %s", instance)
     context = rt.VmContext(instance, modules=[self.hal_module, m])
-    print(context)
+    logging.info("context: %s", context)
 
   def test_dynamic_shape_compile(self):
     m = create_simple_dynamic_abs_module()
-    print(m)
+    logging.info("module: %s", m)
     instance = rt.VmInstance()
-    print(instance)
+    logging.info("instance: %s", instance)
     context = rt.VmContext(instance, modules=[self.hal_module, m])
-    print(context)
+    logging.info("context: %s", context)
 
   def test_add_scalar(self):
     m = create_add_scalar_module()
@@ -122,18 +123,19 @@
     context = rt.VmContext(instance, modules=[self.hal_module, m])
     f = m.lookup_function("add_scalar")
     abi = context.create_function_abi(self.device, self.htf, f)
-    print("INVOKING:", abi)
-    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
-    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    logging.info("abi: %s", abi)
+
     inputs = abi.raw_pack_inputs((5, 6))
-    print("INPUTS:", inputs)
+    logging.info("inputs: %s", inputs)
+
     allocated_results = abi.allocate_results(inputs, static_alloc=False)
-    print("ALLOCATED RESULTS:", allocated_results)
-    print("--- INVOKE:")
+    logging.info("allocated_results: %s", allocated_results)
+    logging.info("Invoking...")
     context.invoke(f, inputs, allocated_results)
-    print("--- DONE.")
+    logging.info("...done")
+
     results = abi.raw_unpack_results(allocated_results)
-    print("RESULTS:", results)
+    logging.info("results: %s", results)
     self.assertEqual(results[0], 11)
 
   def test_synchronous_dynamic_shape_invoke_function(self):
@@ -142,17 +144,20 @@
     context = rt.VmContext(instance, modules=[self.hal_module, m])
     f = m.lookup_function("simple_mul")
     abi = context.create_function_abi(self.device, self.htf, f)
-    print("INVOKING:", abi)
+    logging.info("abi: %s", abi)
+
     arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
     inputs = abi.raw_pack_inputs((arg0,))
-    print("INPUTS:", inputs)
+    logging.info("inputs: %s", inputs)
+
     allocated_results = abi.allocate_results(inputs, static_alloc=False)
-    print("ALLOCATED RESULTS:", allocated_results)
-    print("--- INVOKE:")
+    logging.info("allocated_results: %s", allocated_results)
+    logging.info("Invoking...")
     context.invoke(f, inputs, allocated_results)
-    print("--- DONE.")
+    logging.info("...done")
+
     results = abi.raw_unpack_results(allocated_results)
-    print("RESULTS:", results)
+    logging.info("results: %s", results)
     np.testing.assert_allclose(results[0], [[1., 2.], [3., 4.]])
 
   def test_synchronous_invoke_function(self):
@@ -161,18 +166,21 @@
     context = rt.VmContext(instance, modules=[self.hal_module, m])
     f = m.lookup_function("simple_mul")
     abi = context.create_function_abi(self.device, self.htf, f)
-    print("INVOKING:", abi)
+    logging.info("abi: %s", abi)
+
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
     arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
     inputs = abi.raw_pack_inputs((arg0, arg1))
-    print("INPUTS:", inputs)
+    logging.info("inputs: %s", inputs)
+
     allocated_results = abi.allocate_results(inputs, static_alloc=False)
-    print("ALLOCATED RESULTS:", allocated_results)
-    print("--- INVOKE:")
+    logging.info("allocated_results: %s", allocated_results)
+    logging.info("Invoking...")
     context.invoke(f, inputs, allocated_results)
-    print("--- DONE.")
+    logging.info("...done")
+
     results = abi.raw_unpack_results(allocated_results)
-    print("RESULTS:", results)
+    logging.info("results: %s", results)
     np.testing.assert_allclose(results[0], [4., 10., 18., 28.])
 
 
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 7563c1f..8ac8c2d 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -249,6 +249,8 @@
   ${PROJECT_SOURCE_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
   ${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow
   ${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
+  ${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
+  ${PROJECT_BINARY_DIR}/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms
 )
 
 #-------------------------------------------------------------------------------
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index b31c663..2a2fb6b 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -60,6 +60,10 @@
     "workgroup-size", llvm::cl::desc("Workgroup size to use"),
     llvm::cl::CommaSeparated);
 
+static llvm::cl::list<int> tileSizes("tile-sizes",
+                                     llvm::cl::desc("Tile sizes to use"),
+                                     llvm::cl::CommaSeparated);
+
 using namespace mlir;                    // NOLINT
 using namespace mlir::edsc;              // NOLINT
 using namespace mlir::edsc::intrinsics;  // NOLINT
@@ -96,9 +100,10 @@
   SmallVector<Type, 3> args = {typeA, typeB, typeC};
   SmallVector<int64_t, 4> vWorkgroupSizes(workgroupSize.begin(),
                                           workgroupSize.end());
+  SmallVector<int64_t, 4> vTileSizes(tileSizes.begin(), tileSizes.end());
   auto lowering = [&](mlir::PassManager &pm) {
     pm.addPass(mlir::iree_compiler::createLinalgTileAndFusePass(
-        vWorkgroupSizes, useWorkgroupMemory));
+        vWorkgroupSizes, vTileSizes, useWorkgroupMemory));
     pm.addPass(mlir::iree_compiler::createConvertToGPUPass());
     pm.addPass(mlir::createLowerAffinePass());
     pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 87a087b..54b79b9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -58,14 +58,7 @@
     parent_dir = os.path.join(tempfile.gettempdir(), "iree", "modules")
   artifacts_dir = os.path.join(parent_dir, module_name)
   logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir)
-
-  # If the artifacts already exist then we overwrite/update them.
-  try:
-    # Use try/except instead of os.path.exists to address a race condition
-    # between multiple tests targets.
-    os.makedirs(artifacts_dir)
-  except IOError:
-    pass
+  tf_utils._makedirs(artifacts_dir)
   return artifacts_dir
 
 
@@ -314,9 +307,9 @@
     return True
 
   def _get_trace_dir(self, artifacts_dir):
-    trace_dir = os.path.join(artifacts_dir, "traces")
-    if not os.path.exists(trace_dir):
-      os.makedirs(trace_dir)
+    trace_dir = os.path.join(artifacts_dir, self.backend, "traces",
+                             self.function_name)
+    tf_utils._makedirs(trace_dir)
     return trace_dir
 
   def save_plaintext(self, artifacts_dir, summarize=True):
@@ -335,7 +328,7 @@
         edgeitems=10)  # Can show more items since they won't clutter the logs.
 
     trace_dir = self._get_trace_dir(artifacts_dir)
-    path = os.path.join(trace_dir, f"{self.function_name}__{self.backend}.txt")
+    path = os.path.join(trace_dir, "log.txt")
     with open(path, "w") as f:
       f.write(str(self))
       f.write("\n")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index d1cd14c..b93541a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -48,16 +48,6 @@
   return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
 
 
-def backends_to_str(backend_infos):
-  """Creates a normalized string representing the provided backends."""
-  normalized_names = []
-  for backend_info in backend_infos:
-    # Remove unusual characters and ensure names don't end or start in "_".
-    name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
-    normalized_names.append(name.strip("_"))
-  return "__".join(normalized_names)
-
-
 def to_mlir_type(dtype):
   """Returns a string that denotes the type `dtype` in MLIR style."""
   bits = dtype.itemsize * 8
@@ -97,6 +87,37 @@
   return result
 
 
+def backends_to_str(backend_infos):
+  """Creates a normalized string representing the provided backends."""
+  normalized_names = []
+  for backend_info in backend_infos:
+    # Remove unusual characters and ensure names don't end or start in "_".
+    name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
+    normalized_names.append(name.strip("_"))
+  return "__".join(normalized_names)
+
+
+def _get_backends_path(artifact_name, backend_infos, artifacts_dir):
+  backends_string = backends_to_str(backend_infos)
+  # Put the artifact in a directory if there's only one backend.
+  if len(backend_infos) == 1:
+    backend_dir = os.path.join(artifacts_dir, backends_string)
+    _makedirs(backend_dir)
+    return os.path.join(artifacts_dir, backends_string, artifact_name)
+  else:
+    return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
+
+
+def _makedirs(path):
+  # If the artifacts already exist then we overwrite/update them.
+  try:
+    # Use try/except instead of os.path.exists to address any race conditions
+    # that might arise between multiple test targets.
+    os.makedirs(path)
+  except IOError:
+    pass
+
+
 def compile_tf_module(tf_module,
                       backend_infos=(),
                       exported_names=(),
@@ -107,16 +128,21 @@
   that returns a module that can be called without any further steps.
 
   If artifacts_dir is provided then the following artifacts will be saved:
-    saved_model:
+    backend_name/saved_model:
       A TF SavedModel directory containing the files used to translate the
-      tf.Module into an IREE module.
+      tf.Module into an IREE module. Only saved if '--keep_saved_model=True'.
     tf_input.mlir:
       MLIR for the module in TF's input dialect.
     iree_input.mlir:
       The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
-    compiled__backends.vmfb:
+    backend_name/compiled.vmfb:
       A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
-  Here 'backends' is a '__' delimited list of iree backends (e.g. vmla__llvm_ir)
+
+  If multiple backends are specified, then instead of saving the SavedModel and
+  compiled 'vmfb' under 'backend_name/', they will be saved as follows:
+    - 'saved_model__{backends}'
+    - 'compiled__{backends}.vmfb'
+  where 'backends' is a '__' delimited list (e.g. iree_vmla__iree_llvmjit).
 
   Args:
     tf_module: A tf.Module.
@@ -168,8 +194,9 @@
       compiled_module = compiler_module.compile(target_backends=target_backends)
 
       if artifacts_dir is not None:
-        compiled_name = f"compiled__{backends_string}.vmfb"
-        compiled_path = os.path.join(artifacts_dir, compiled_name)
+        compiled_path = _get_backends_path("compiled", backend_infos,
+                                           artifacts_dir)
+        compiled_path = f"{compiled_path}.vmfb"
         logging.info("Saving compiled IREE module to: %s", compiled_path)
         with open(compiled_path, "wb") as f:
           f.write(compiled_module)
@@ -187,7 +214,7 @@
     # Create a saved model for these target backends to avoid a race condition
     # when running a test suite.
     # TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
-    sm_path = os.path.join(artifacts_dir, f"saved_model__{backends_string}")
+    sm_path = _get_backends_path("saved_model", backend_infos, artifacts_dir)
     tf.saved_model.save(tf_module, sm_path, options=options)
     return _compile_from_path(sm_path)
   else:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index aa1df8e..4f4084e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -19,9 +19,9 @@
 
 from absl import logging
 from absl.testing import parameterized
+import numpy as np
 from pyiree.tf.support import tf_utils
 import tensorflow as tf
-import numpy as np
 
 
 class ConstantModule(tf.Module):
@@ -67,10 +67,13 @@
       iree_compiled_module = tf_utils.compile_tf_module(
           tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
 
+      compiled_path = tf_utils._get_backends_path('compiled', backend_infos,
+                                                  artifacts_dir)
+      compiled_path = f'{compiled_path}.vmfb'
       artifacts_to_check = [
           'tf_input.mlir',
           'iree_input.mlir',
-          f'compiled__{tf_utils.backends_to_str(backend_infos)}.vmfb',
+          compiled_path,
       ]
       for artifact in artifacts_to_check:
         artifact_path = os.path.join(artifacts_dir, artifact)
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp b/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp
index 0f2a336..031bc0c 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/ir/dialect.cpp
@@ -54,13 +54,14 @@
 }
 
 void TFStringsDialect::printType(Type type, DialectAsmPrinter& os) const {
-  switch (type.getKind()) {
-    case TFStringsTypes::String:
-      os << "string";
-      break;
-    default:
-      llvm_unreachable("unhandled string type");
-  }
+  if (type.isa<tf_strings::StringType>())
+    os << "string";
+  else
+    llvm_unreachable("unhandled string type");
+}
+
+bool TFStringsType::classof(Type type) {
+  return llvm::isa<TFStringsDialect>(type.getDialect());
 }
 
 }  // namespace tf_strings
diff --git a/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h b/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h
index a7f4fe3..c61e0b8 100644
--- a/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h
+++ b/integrations/tensorflow/compiler/dialect/tf_strings/ir/types.h
@@ -41,10 +41,7 @@
  public:
   using Type::Type;
 
-  static bool classof(Type type) {
-    return type.getKind() >= TFStringsTypes::FIRST_USED_STRINGS_TYPE &&
-           type.getKind() <= TFStringsTypes::LAST_USED_STRINGS_TYPE;
-  }
+  static bool classof(Type type);
 };
 
 class StringType
@@ -54,8 +51,6 @@
   static StringType get(MLIRContext* context) {
     return Base::get(context, TFStringsTypes::String);
   }
-
-  static bool kindof(unsigned kind) { return kind == TFStringsTypes::String; }
 };
 
 }  // namespace tf_strings
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h
index f9602cb..e31bc95 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/ir/tf_tensorlist_types.h
@@ -30,7 +30,6 @@
     : public Type::TypeBase<TensorListType, Type, TypeStorage> {
  public:
   using Base::Base;
-  static bool kindof(unsigned kind) { return kind == TypeKind::kTensorList; }
   static TensorListType get(MLIRContext *context) {
     return Base::get(context, TypeKind::kTensorList);
   }
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 300cadf..89b3bb1 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -13,7 +13,7 @@
 See [Install TensorFlow with pip](https://www.tensorflow.org/install/pip) for
 instructions.
 
-## Vulkan setup
+## Vulkan Setup
 
 If you do not have your environment setup to use IREE with Vulkan (see
 [the doc](../../../docs/vulkan_and_spirv.md)), then you can run the manual test
@@ -48,7 +48,7 @@
 By default the TensorFlow SavedModels will not be kept. This can be overridden
 via the `--keep_saved_model` flag.
 
-## Running tests
+## Running Tests
 
 For locally running tests and iterating on backend development, `bazel run` is
 preferred.
@@ -150,7 +150,33 @@
 bazel test :e2e_tests_failing_broadcasting_test__tf__iree_vulkan
 ```
 
-## Debugging tests
+## Generated Artifacts
+
+By default, running an E2E test generates a number of compilation, debugging and
+benchmarking artifacts in `/tmp/iree/modules/`. The location of these artifacts
+can be changed via the `--artifacts_dir` flag. The generated directory structure
+for each module is as follows:
+
+```
+/tmp/iree/modules/ModuleName
+├── tf_input.mlir          # MLIR for ModuleName in TF's input dialect
+├── iree_input.mlir        # tf_input.mlir translated to IREE MLIR
+├── backend_name           # e.g. iree_vmla, tf or tf_ref
+│   ├── compiled.vmfb      # flatbuffer of ModuleName compiled to this backend
+│   ├── saved_model
+│   └── traces
+│       ├── trace_function
+│       │   └── log.txt    # A more detailed version of the test logs
+│       └── trace_function
+│           └── log.txt
+└── backend_name
+    └── ...
+```
+
+The `saved_model` directory is only created if `--keep_saved_model` is
+specified.
+
+## Debugging Tests
 
 If the compiler fails to compile the program, then it will create a crash
 reproducer (see [MLIR documentation](https://mlir.llvm.org/docs/WritingAPass/)),
@@ -159,7 +185,7 @@
 
 TODO(silvasean): debugging miscompiles
 
-## Test harnesses
+## Test Harnesses
 
 ### Simple function tests
 
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
index 31c7d29..1372f3e 100644
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
@@ -228,10 +228,16 @@
       linalg::getLinalgTilingCanonicalizationPatterns(context);
   stage2Patterns.insert<AffineMinCanonicalizationPattern>(context);
 
-  auto stage3Transforms = [this](Operation *op) {
-    // Some of these may be too aggressive as a stage 3 that is applied on each
-    // stage 1 application and may have to be split out to post staged patterns
-    // application (in which case they could just be passes, TBD).
+  auto stage3Transforms = [](Operation *op) {
+    promoteSingleIterationLoops(cast<FuncOp>(op));
+    return success();
+  };
+  linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
+                              stage3Transforms);
+
+  auto postStageTransforms = [this](Operation *op) {
+    // Run LICM and hoisting patterns after all the stages, as we want
+    // unrolling to happen before moving transfer ops out of the loop.
     if (hoistInvariantCode) {
       PassManager pm(op->getContext());
       pm.addPass(createLoopInvariantCodeMotionPass());
@@ -241,11 +247,8 @@
       hoistRedundantVectorTransfers(cast<FuncOp>(op));
       hoistRedundantCopies(cast<FuncOp>(op));
     }
-    promoteSingleIterationLoops(cast<FuncOp>(op));
-    return success();
   };
-  linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
-                              stage3Transforms);
+  postStageTransforms(func);
   if (lowering != nullptr) lowering(func);
 }
 
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 3d754a7..35f8bdb 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -66,10 +66,11 @@
   if (!matchPattern(init, m_Constant(&attr))) return {};
   auto type = attr.getType().dyn_cast<ShapedType>();
   if (!type || type.getRank() != 0) return {};
-  if (auto intType = type.getElementType().dyn_cast<IntegerType>())
+  if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
     return IntegerAttr::get(intType, attr.getValue<APInt>({}));
-  else if (auto floatType = type.getElementType().dyn_cast<FloatType>())
+  } else if (auto floatType = type.getElementType().dyn_cast<FloatType>()) {
     return FloatAttr::get(floatType, attr.getValue<APFloat>({}));
+  }
   return {};
 }
 
@@ -254,14 +255,17 @@
            a == b;
   };
   if (lhsShape.size() == 1 && rhsShape.size() == 1 &&
-      shapeMatches(lhsShape[0], rhsShape[0]))
+      shapeMatches(lhsShape[0], rhsShape[0])) {
     return DotOperationType::VectorDot;
+  }
   if (lhsShape.size() == 2 && rhsShape.size() == 1 &&
-      shapeMatches(lhsShape[1], rhsShape[0]))
+      shapeMatches(lhsShape[1], rhsShape[0])) {
     return DotOperationType::MatrixVector;
+  }
   if (lhsShape.size() == 2 && rhsShape.size() == 2 &&
-      shapeMatches(lhsShape[1], rhsShape[0]))
+      shapeMatches(lhsShape[1], rhsShape[0])) {
     return DotOperationType::MatrixMatrix;
+  }
   return DotOperationType::Unsupported;
 }
 
@@ -317,8 +321,9 @@
     // batch_count, spatial_dims..., input_feature_count.
     if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
         dimensionNumbers.input_feature_dimension().getInt() !=
-            (inputSpatialRank + 1))
+            (inputSpatialRank + 1)) {
       return failure();
+    }
 
     const int kernelSpatialRank =
         llvm::size(dimensionNumbers.kernel_spatial_dimensions());
@@ -327,8 +332,9 @@
     if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
             kernelSpatialRank ||
         dimensionNumbers.kernel_output_feature_dimension().getInt() !=
-            (kernelSpatialRank + 1))
+            (kernelSpatialRank + 1)) {
       return failure();
+    }
 
     const int outputSpatialRank =
         llvm::size(dimensionNumbers.output_spatial_dimensions());
@@ -336,12 +342,14 @@
     // batch_count, spatial_dims.., output_feature_count.
     if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
         dimensionNumbers.output_feature_dimension().getInt() !=
-            (outputSpatialRank + 1))
+            (outputSpatialRank + 1)) {
       return failure();
+    }
 
     if (inputSpatialRank != outputSpatialRank ||
-        inputSpatialRank != kernelSpatialRank)
+        inputSpatialRank != kernelSpatialRank) {
       return failure();
+    }
 
     auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
     auto kernelSpatialDim =
@@ -353,8 +361,9 @@
       const int dim = i + 1;
       if ((*inputSpatialDim++).getZExtValue() != dim ||
           (*outputSpatialDim++).getZExtValue() != dim ||
-          (*kernelSpatialDim++).getZExtValue() != i)
+          (*kernelSpatialDim++).getZExtValue() != i) {
         return failure();
+      }
     }
   }
 
@@ -669,8 +678,9 @@
     ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
   auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
-  if (!argType || !argType.hasRank())
+  if (!argType || !argType.hasRank()) {
     return op.emitError("expected known-rank args");
+  }
 
   SmallVector<Value, 3> offsets, sizes, strides;
   for (int i = 0, e = argType.getRank(); i < e; ++i) {
@@ -731,9 +741,12 @@
   int rank = output.getType().cast<ShapedType>().getRank();
   SmallVector<Attribute, 2> indexingMaps;
   SmallVector<AffineExpr, 4> exprs;
-  for (int i = 0; i < batch; ++i) exprs.push_back(rewriter.getAffineDimExpr(i));
-  for (int i = 0, e = nIndices - batch; i < e; ++i)
+  for (int i = 0; i < batch; ++i) {
+    exprs.push_back(rewriter.getAffineDimExpr(i));
+  }
+  for (int i = 0, e = nIndices - batch; i < e; ++i) {
     exprs.push_back(rewriter.getAffineDimExpr(axis + i));
+  }
   indexingMaps.emplace_back(AffineMapAttr::get(
       AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext())));
   indexingMaps.emplace_back(
@@ -763,10 +776,13 @@
   SmallVector<Value, 4> indices;
   Value castedValue = rewriter.create<IndexCastOp>(
       loc, block->getArgument(rank), rewriter.getIndexType());
-  for (int i = 0; i < axis; ++i) indices.push_back(block->getArgument(i));
-  indices.push_back(castedValue);
-  for (int i = axis + nIndices - batch; i < rank; ++i)
+  for (int i = 0; i < axis; ++i) {
     indices.push_back(block->getArgument(i));
+  }
+  indices.push_back(castedValue);
+  for (int i = axis + nIndices - batch; i < rank; ++i) {
+    indices.push_back(block->getArgument(i));
+  }
 
   Value res = rewriter.create<LoadOp>(loc, adaptor.input(), indices);
   rewriter.create<linalg::YieldOp>(loc, res);
@@ -822,8 +838,9 @@
 
   // Create a fake window dimension.
   SmallVector<int64_t, 4> shapes;
-  for (auto dim : op.window_dimensions().getValues<int64_t>())
+  for (auto dim : op.window_dimensions().getValues<int64_t>()) {
     shapes.push_back(dim);
+  }
   Type type = rewriter.getIntegerType(32);
   auto memrefType = MemRefType::get(shapes, type);
   auto fakeWindowDims = rewriter.create<AllocOp>(loc, memrefType);
@@ -889,8 +906,9 @@
   for (auto dim : reductionDims) s.insert(dim);
 
   SmallVector<unsigned, 4> permutation;
-  for (int i = 0; i < rank; ++i)
+  for (int i = 0; i < rank; ++i) {
     if (!s.count(i)) permutation.push_back(i);
+  }
   for (auto dim : reductionDims) permutation.push_back(dim);
 
   auto map = AffineMap::getPermutationMap(permutation, context);
@@ -1002,8 +1020,9 @@
   auto loc = reduceOp.getLoc();
   DenseIntElementsAttr dimensionsAttr = reduceOp.dimensions();
   SmallVector<int, 4> reductionDims;
-  for (const auto &dim : dimensionsAttr.getIntValues())
+  for (const auto &dim : dimensionsAttr.getIntValues()) {
     reductionDims.push_back(dim.getSExtValue());
+  }
 
   // Check if initVal is constant. If so, inline the value into the region.
   Attribute initConstVal = getInitValueAsConst(initVal);
@@ -1020,16 +1039,18 @@
   SmallVector<Attribute, 3> indexingMaps;
   indexingMaps.emplace_back(AffineMapAttr::get(getTransposeMapForReduction(
       rewriter.getContext(), nInputRank, reductionDims)));
-  if (!initConstVal)
+  if (!initConstVal) {
     indexingMaps.emplace_back(AffineMapAttr::get(
         AffineMap::get(nInputRank, /*symbolCount=*/0, rewriter.getContext())));
+  }
   // The indexing map of `dst` should drop the reduction loops. Since the
   // reduction loops now are all in the innermost, drops `reductionDims.size()`
   // dimensions. We don't need an inverse permutation here because they are the
   // same.
   SmallVector<AffineExpr, 4> exprs;
-  for (int i = 0, e = nInputRank - reductionDims.size(); i < e; ++i)
+  for (int i = 0, e = nInputRank - reductionDims.size(); i < e; ++i) {
     exprs.push_back(rewriter.getAffineDimExpr(i));
+  }
   indexingMaps.emplace_back(AffineMapAttr::get(
       exprs.empty()
           ? AffineMap::get(nInputRank, /*symbolCount=*/0, rewriter.getContext())
@@ -1038,7 +1059,9 @@
 
   SmallVector<Type, 2> resultTypes = {};
   SmallVector<Value, 2> linalgOpArgs = {inputBuffers[0]};
-  if (!initConstVal) linalgOpArgs.push_back(inputBuffers[1]);
+  if (!initConstVal) {
+    linalgOpArgs.push_back(inputBuffers[1]);
+  }
   linalgOpArgs.push_back(resultBuffers[0]);
   if (failed(zeroFillBuffer(loc, resultBuffers[0], rewriter))) {
     rewriter.notifyMatchFailure(reduceOp, "failed to zero fill result buffer");
@@ -1141,8 +1164,9 @@
     // type.
     TypeConverter::SignatureConversion signatureConverter(numIndices +
                                                           numTensorOperands);
-    for (int i = 0; i < numIndices; ++i)
+    for (int i = 0; i < numIndices; ++i) {
       signatureConverter.addInputs(i, rewriter.getIndexType());
+    }
     for (auto arg : llvm::enumerate(opArgs)) {
       if (arg.index() < numTensorOperands) {
         signatureConverter.addInputs(
@@ -1224,13 +1248,15 @@
       Shape::TieShapeOp shapeOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     Shape::TieShapeOp::Adaptor adaptor(operands);
-    if (Value buffer = resolveResult(shapeOp.operand(), adaptor.operand(),
-                                     shapeOp.result(), resultTensorToBufferMap))
+    if (Value buffer =
+            resolveResult(shapeOp.operand(), adaptor.operand(),
+                          shapeOp.result(), resultTensorToBufferMap)) {
       rewriter.replaceOp(shapeOp, buffer);
-    else
+    } else {
       rewriter.replaceOpWithNewOp<Shape::TieShapeOp>(
           shapeOp, getMemrefTypeForTensor(shapeOp.result()), adaptor.operand(),
           adaptor.shape());
+    }
     return success();
   }
 
@@ -1250,8 +1276,9 @@
   LogicalResult matchAndRewrite(IREE::HAL::InterfaceLoadTensorOp loadOp,
                                 ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
-    if (!matchPattern(loadOp.offset(), m_Zero()))
+    if (!matchPattern(loadOp.offset(), m_Zero())) {
       return loadOp.emitError("unhandled non-zero offset");
+    }
 
     // Get the corresponding memref type from the tensor type.
     auto tensorType = loadOp.result().getType().cast<RankedTensorType>();
@@ -1363,8 +1390,9 @@
 static LogicalResult createAndPropagateBufferUsedForResultTensor(
     IREE::HAL::InterfaceStoreTensorOp op, OutputBufferMap &outputBufferMap,
     TensorToBufferMap &resultTensorToBufferMap, OpBuilder &builder) {
-  if (!matchPattern(op.offset(), m_Zero()))
+  if (!matchPattern(op.offset(), m_Zero())) {
     return op.emitError("unhandled non-zero offset");
+  }
 
   // Get the corresponding memref type from the tensor type.
   Value tensor = op.operand();
@@ -1460,8 +1488,9 @@
   OutputBufferMap outputBufferMap;
   TensorToBufferMap resultTensorToBufferMap;
   if (failed(createAndPropagateBufferUsedForResultTensors(
-          funcOp, outputBufferMap, resultTensorToBufferMap)))
+          funcOp, outputBufferMap, resultTensorToBufferMap))) {
     return signalPassFailure();
+  }
 
   OwningRewritePatternList patterns;
   populateHLOToLinalgOnBuffersConversionPatterns(context, patterns,
@@ -1486,8 +1515,9 @@
   target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
       Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation *op) {
         // The generated structured Linalg ops should have buffer semantics.
-        if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
+        if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
           return linalgOp.hasBufferSemantics();
+        }
         // The other Linalg ops (like linalg.yield) are okay.
         return true;
       }));
diff --git a/iree/compiler/Conversion/HLOToLinalg/Passes.h b/iree/compiler/Conversion/HLOToLinalg/Passes.h
index 445bd51..b50d243 100644
--- a/iree/compiler/Conversion/HLOToLinalg/Passes.h
+++ b/iree/compiler/Conversion/HLOToLinalg/Passes.h
@@ -48,7 +48,7 @@
 using TensorToBufferMap = DenseMap<Value, Value>;
 void populateHLOToLinalgOnBuffersConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
-    TensorToBufferMap const &outputTensorToBuffer);
+    TensorToBufferMap const &resultTensorToBufferMap);
 
 /// Populates passes to convert from XLA-HLO to Linalg on buffers as well as
 /// handling some IREE specific conversions (like iree.interface.* and
diff --git a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp
index cd86250..1ddeaf4 100644
--- a/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/ResolveShapeOps.cpp
@@ -99,8 +99,9 @@
   ConversionTarget target(*context);
   target.addIllegalOp<DimOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
-  if (failed(applyFullConversion(getFunction(), target, dimPatterns)))
+  if (failed(applyFullConversion(getFunction(), target, dimPatterns))) {
     return signalPassFailure();
+  }
 
   OwningRewritePatternList shapePatterns;
   shapePatterns.insert<TieShapeElider>(context);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index 7a80343..7877459 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -26,6 +26,7 @@
         "CooperativeMatrixAnalysis.cpp",
         "LinalgTileAndFusePass.cpp",
         "MarkerUtils.cpp",
+        "MatMulVectorizationTest.cpp",
         "Passes.cpp",
         "SplitDispatchFunctionPass.cpp",
         "Utils.cpp",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 7bca4c8..af6fbe9 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -30,6 +30,7 @@
     "CooperativeMatrixAnalysis.cpp"
     "LinalgTileAndFusePass.cpp"
     "MarkerUtils.cpp"
+    "MatMulVectorizationTest.cpp"
     "Passes.cpp"
     "SplitDispatchFunctionPass.cpp"
     "Utils.cpp"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 6ec315c..d623eae 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -19,9 +19,11 @@
 //===----------------------------------------------------------------------===//
 
 #include <array>
+#include <numeric>
 
 #include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -30,6 +32,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
@@ -41,22 +44,21 @@
 namespace mlir {
 namespace iree_compiler {
 
+static constexpr int kNumDims = 3;
+
+/// In some cases the iterations of the loops, when partitioned to workgroups,
+/// need to be distributed in a cyclic manner. The main use case here is when
+/// the number of workgroups is constrained such that the number of iterations
+/// is greater than or equal to the number of processors (along any
+/// dimension). In those cases, distribute the iterations in a cyclic manner.
+/// This adds additional control flow, but isn't too detrimental to
+/// performance since the branches are convergent for the most part.
 static llvm::cl::opt<bool> isWorkgroupCountConstrained(
     "iree-codegen-constrained-workgroup-count",
     llvm::cl::desc("Specify whether the number of workgroups can be assumed to "
                    "be large enough to cover the entire workload"),
     llvm::cl::init(false));
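
Conceptually, the two distribution strategies compare as follows (an
illustrative C++ sketch of the generated control flow, not the actual IR
lowering; `body` is a hypothetical loop body):

```
#include <cstdio>

void body(int i) { std::printf("iteration %d\n", i); }

// One iteration per processor: needs no loop, but only covers the workload
// when the processor count is at least the iteration count.
void singleIterationPerProcessor(int procId, int numIterations) {
  if (procId < numIterations) body(procId);
}

// Cyclic distribution: a strided loop lets a fixed number of processors cover
// any iteration count, at the cost of extra (mostly convergent) control flow.
void distributeCyclically(int procId, int procCount, int numIterations) {
  for (int i = procId; i < numIterations; i += procCount) body(i);
}
```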
 
-// TODO(#2134): Remove this flag/set it to false always when issue with
-// convolution is resolved (see bug for more details).
-// TODO(#2346): Make this a pass specific option.
-llvm::cl::opt<bool> useLegacyConvLowering{
-    "iree-codegen-use-legacy-conv-lowering",
-    llvm::cl::desc("Use conv lowering that does not assume 1:1 mapping "
-                   "between threads within a block and iterations of "
-                   "parallel loops distributed to the block"),
-    llvm::cl::init(true)};
-
 //===----------------------------------------------------------------------===//
 // Loop utilities
 //===----------------------------------------------------------------------===//
@@ -420,7 +422,7 @@
                                         unsigned numDims,
                                         MutableArrayRef<Value> id,
                                         MutableArrayRef<Value> count) {
-  std::array<StringRef, 3> dims{"x", "y", "z"};
+  std::array<StringRef, kNumDims> dims{"x", "y", "z"};
   assert(id.size() == numDims);
   assert(count.size() == numDims);
   for (unsigned i = 0; i < numDims; ++i) {
@@ -431,6 +433,24 @@
   }
 }
 
+template <typename GPUIdOp, typename GPUCountOp>
+static ProcessorIdAndCount getLinearizedGPUProcessorIdAndCount(
+    Location loc, ConversionPatternRewriter &rewriter) {
+  std::array<Value, kNumDims> ids, counts;
+  getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(loc, rewriter, kNumDims, ids,
+                                                   counts);
+  ProcessorIdAndCount linearized;
+  linearized.id = ids[0];
+  linearized.count = counts[0];
+  for (unsigned i = 0; i < kNumDims - 1; ++i) {
+    linearized.id = rewriter.create<MulIOp>(loc, linearized.id, counts[i + 1]);
+    linearized.id = rewriter.create<AddIOp>(loc, linearized.id, ids[i + 1]);
+    linearized.count =
+        rewriter.create<MulIOp>(loc, linearized.count, counts[i + 1]);
+  }
+  return linearized;
+}
+
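
The linearization above is ordinary mixed-radix arithmetic; a plain-C++
rendering (illustrative only) for ids = {x, y, z} with counts = {cx, cy, cz}:

```
// linearId    = (x * cy + y) * cz + z
// linearCount = cx * cy * cz
// e.g. ids = {1, 2, 3}, counts = {4, 5, 6} -> linearId = 45, count = 120.
int linearizeId(const int id[3], const int count[3]) {
  int linearId = id[0];
  for (int i = 0; i < 2; ++i) linearId = linearId * count[i + 1] + id[i + 1];
  return linearId;
}
```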
 /// Distributes scf.parallel to processors where `IdOp` is used to get the
 /// processor ID and `DimOp` is used to get the number of processors along a
 /// dimension.
@@ -474,16 +494,22 @@
 static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
                                      scf::ParallelOp pLoopOp,
                                      bool useCyclicDistribution = false) {
-  if (useCyclicDistribution)
+  if (useCyclicDistribution) {
     return distributeCyclicallyToProcessors<gpu::BlockIdOp, gpu::GridDimOp>(
         rewriter, pLoopOp);
+  }
   return distributeSingleIterationPerProcessor<gpu::BlockIdOp, gpu::GridDimOp>(
       rewriter, pLoopOp, false);
 }
 
 /// Distributes scf.parallel to workitems using local invocation ID.
-static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
-                                            scf::ParallelOp pLoopOp) {
+static LogicalResult mapToLocalInvocationId(
+    ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp,
+    bool useCyclicDistribution = false) {
+  if (useCyclicDistribution) {
+    return distributeCyclicallyToProcessors<gpu::ThreadIdOp, gpu::BlockDimOp>(
+        rewriter, pLoopOp);
+  }
   return distributeSingleIterationPerProcessor<gpu::ThreadIdOp,
                                                gpu::BlockDimOp>(rewriter,
                                                                 pLoopOp);
@@ -498,22 +524,47 @@
       rewriter, pLoopOp);
 }
 
+/// Returns the number of elements copied when loading to/storing from
+/// workgroup memory. It is approximated as the size of the underlying
+/// allocation being copied into/from.
+static Optional<int64_t> getLinearizedCopySize(linalg::CopyOp copyOp) {
+  Value src = copyOp.input();
+  Value dst = copyOp.output();
+  MemRefType srcType = src.getType().cast<MemRefType>();
+  MemRefType dstType = dst.getType().cast<MemRefType>();
+
+  Value workgroupMemoryView;
+  MemRefType workgroupMemoryType;
+  if (srcType.getMemorySpace() == getWorkgroupMemorySpace()) {
+    workgroupMemoryView = src;
+    workgroupMemoryType = srcType;
+  } else if (dstType.getMemorySpace() == getWorkgroupMemorySpace()) {
+    workgroupMemoryView = dst;
+    workgroupMemoryType = dstType;
+  } else {
+    return {};
+  }
+
+  SubViewOp workgroupMemorySubviewOp =
+      dyn_cast_or_null<SubViewOp>(workgroupMemoryView.getDefiningOp());
+  if (!workgroupMemorySubviewOp) return {};
+  AllocOp allocOp = dyn_cast_or_null<AllocOp>(
+      workgroupMemorySubviewOp.source().getDefiningOp());
+  if (!allocOp) return {};
+
+  MemRefType allocOpType = allocOp.getType();
+  if (!allocOpType.hasStaticShape()) return {};
+  return allocOpType.getNumElements();
+}
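+// For example, a copy whose source is a subview of `alloc() : memref<8x4xf32, 3>`
+// returns 32 here, independent of the subview's runtime extents; copies not
+// rooted at a statically shaped workgroup-memory allocation yield llvm::None.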
+
 //===----------------------------------------------------------------------===//
 // Pass and patterns.
 //===----------------------------------------------------------------------===//
 
-/// In some cases the iterations of the loops when partitioned to workgroups
-/// need to be distributed in a cyclic manner. The main use cases here is when
-/// the number of workgroups is constrained such that the number of iterations
-/// is greater than equal to number of processors (along any dimension). In
-/// those cases, distribute the iterations in a cyclic manner. This adds
-/// additional control flow, but isn't too detrimental to performance since they
-/// are convergent for the most part.
 // TODO(#2134): Mapping iterations to processors directly by assuming number of
 // iterations <= number of processors again seems to have an issue with
 // convolution/pooling. Needs further investigation.
 static bool useCyclicLoopDistribution(scf::ParallelOp pLoopOp) {
-  if (!useLegacyConvLowering) return false;
   auto walkResult = pLoopOp.walk([](Operation *op) -> WalkResult {
     if (isa<linalg::ConvOp>(op) || isa<linalg::PoolingMaxOp>(op) ||
         isa<linalg::PoolingMinOp>(op) || isa<linalg::PoolingSumOp>(op))
@@ -545,6 +596,70 @@
   }
 };
 
+/// Implementation of the mapping of tiled linalg op to workitems within a
+/// workgroup.
+template <typename LinalgOpTy>
+static LogicalResult mapLinalgOpToLocalInvocationIdImpl(
+    LinalgOpTy linalgOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  // Check for marker that specifies that the linalg op is to be partitioned
+  // across threads within a workgroup.
+  if (!hasMarker(linalgOp)) return failure();
+  Optional<linalg::LinalgLoops> loops =
+      linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
+  if (!loops) return failure();
+  if (loops.getValue().empty()) return success();
+
+  scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
+  if (!pLoopOp) return success();
+
+  return mapToLocalInvocationId(
+      rewriter, pLoopOp,
+      hasMarker(linalgOp, {getWorkgroupMarker(), getWorkgroupMemoryMarker()}));
+}
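+// Note on the marker dispatch above: ops carrying the "workgroup" or
+// "workgroup_memory" marker get the cyclic distribution (no assumption about
+// the workgroup size), while ops carrying the *_numprocs_ge_numiters markers
+// take the single-iteration-per-workitem path in mapToLocalInvocationId.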
+
+/// CopyOps that load to/store from workgroup memory are special-cased to use
+/// all workitems to perform the copy. This is done by linearizing the copy
+/// operation.
+// TODO(ravishankarm): This linearization is achieved through collapsing the
+// generated parallel loops from a multi-dimensional copy. Such lowering results
+// in mods/divs in the collapsed loop body. This can be removed by reshaping the
+// copy to be a 1D copy. This seems to be hitting an error in reshape
+// canonicalization. Investigate this further.
+template <>
+LogicalResult mapLinalgOpToLocalInvocationIdImpl<linalg::CopyOp>(
+    linalg::CopyOp copyOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  if (!hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) return failure();
+  Optional<linalg::LinalgLoops> loops =
+      linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, copyOp);
+  if (!loops) return failure();
+  if (loops.getValue().empty()) return success();
+
+  scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
+  if (!pLoopOp) return success();
+  pLoopOp = collapseParallelLoops(rewriter, pLoopOp);
+  if (!pLoopOp) return failure();
+
+  Optional<int64_t> copyLength = getLinearizedCopySize(copyOp);
+  ProcessorIdAndCount idAndCount =
+      getLinearizedGPUProcessorIdAndCount<gpu::ThreadIdOp, gpu::BlockDimOp>(
+          copyOp.getLoc(), rewriter);
+  auto workgroupSize =
+      spirv::lookupLocalWorkGroupSize(copyOp).getValues<APInt>();
+  // Accumulate with an int64_t init value so std::accumulate's deduced
+  // accumulator type is int64_t rather than int.
+  int64_t linearizedWorkgroupSize = std::accumulate(
+      workgroupSize.begin(), workgroupSize.end(), int64_t(1),
+      [](int64_t total, APInt value) { return total * value.getSExtValue(); });
+
+  if (copyLength.hasValue() && !workgroupSize.empty() &&
+      copyLength.getValue() <= linearizedWorkgroupSize) {
+    return distributeSingleIterationPerProcessor(
+        rewriter, pLoopOp, idAndCount.id, /*generateGuard=*/true);
+  }
+  return distributeCyclicallyToProcessors(rewriter, pLoopOp, idAndCount.id,
+                                          idAndCount.count);
+}
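+// Sketch of the dispatch decision above: with workgroup size [32, 1, 1] the
+// linearized workgroup size is 32, so a copy of at most 32 elements takes the
+// guarded single-iteration-per-workitem path, while a larger copy (or one of
+// unknown size) falls back to the cyclic distribution.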
+
 /// Map tiled linalg op to workitems by lowering it to scf.parallel and
 /// partitioning it to workitems.
 template <typename LinalgOpTy>
@@ -553,41 +668,20 @@
   LogicalResult matchAndRewrite(
       LinalgOpTy linalgOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    // Check for marker that specifies that the linalg op is to be partitioned
-    // across threads within a workgroup.
-    if (!hasWorkGroupMarker(linalgOp)) return failure();
-    Optional<linalg::LinalgLoops> loops =
-        linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
-    if (!loops) return failure();
-    if (!loops.getValue().empty()) {
-      scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
-      if (!pLoopOp || failed(mapToLocalInvocationId(rewriter, pLoopOp)))
-        return failure();
-    }
-    rewriter.eraseOp(linalgOp);
-    return success();
-  }
-};
-
-/// Legacy path for lowering tiled conv/pooling op to loops.
-// TODO(#2134): Remove this pattern. The default path of using
-// `MapLinalgOpToLocalInvocationId` seems to have a bug. It only shows up
-// currently on Resnet50. Remove this pattern after the bug is triaged/fixed.
-template <typename LinalgOpTy>
-struct MapConvPoolToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
-  using OpConversionPattern<LinalgOpTy>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      LinalgOpTy linalgOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    if (!hasWorkGroupMarker(linalgOp)) return failure();
-    Optional<linalg::LinalgLoops> loops =
-        linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
-    if (!loops) return failure();
-    scf::ParallelOp pLoopOp = cast<scf::ParallelOp>(loops.getValue()[0]);
     if (failed(
-            distributeCyclicallyToProcessors<gpu::ThreadIdOp, gpu::BlockDimOp>(
-                rewriter, pLoopOp)))
+            mapLinalgOpToLocalInvocationIdImpl(linalgOp, operands, rewriter)))
       return failure();
+
+    // If the `linalgOp` writes to workgroup memory, insert a barrier after
+    // the op.
+    if (llvm::any_of(linalgOp.getOperands(), [](Value output) {
+          return output.getType().cast<MemRefType>().getMemorySpace() ==
+                 getWorkgroupMemorySpace();
+        })) {
+      rewriter.create<spirv::ControlBarrierOp>(
+          linalgOp.getLoc(), spirv::Scope::Workgroup, spirv::Scope::Workgroup,
+          spirv::MemorySemantics::AcquireRelease);
+    }
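+    // (In SPIR-V assembly the barrier above materializes as
+    //  spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease".)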
     rewriter.eraseOp(linalgOp);
     return success();
   }
@@ -674,39 +768,17 @@
 
   OwningRewritePatternList patterns;
 
-  // clang-format off
-  patterns.insert<
-
-#define ADD_ALL_LINALG_PATTERNS(OP_NAME)                                       \
-    MapLinalgOpToGlobalInvocationId<OP_NAME>,                                  \
-    MapLinalgOpToLocalInvocationId<OP_NAME>
-
-    ADD_ALL_LINALG_PATTERNS(linalg::CopyOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::FillOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::GenericOp),
-    ADD_ALL_LINALG_PATTERNS(linalg::IndexedGenericOp),
-
-#undef ADD_ALL_LINALG_PATTERNS
-
-#define ADD_ALL_CONV_POOL_PATTERNS(OP_NAME)                                    \
-    MapConvPoolToLocalInvocationId<OP_NAME>,                                   \
-    MapLinalgOpToGlobalInvocationId<OP_NAME>
-
-    ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMaxOp),
-    ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingMinOp),
-    ADD_ALL_CONV_POOL_PATTERNS(linalg::PoolingSumOp),
-
-#undef ADD_ALL_CONV_POOL_PATTERNS
-
-    MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
-    PartitionPLoopToWorkgroups, RemoveLinalgRange>(context);
-  // clang-format on
-
-  patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::ConvOp>>(context);
-  if (useLegacyConvLowering)
-    patterns.insert<MapConvPoolToLocalInvocationId<linalg::ConvOp>>(context);
-  else
-    patterns.insert<MapLinalgOpToLocalInvocationId<linalg::ConvOp>>(context);
+  patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
+                  MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
+                  MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
+                  MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
+                  MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>,
+                  PartitionPLoopToWorkgroups, RemoveLinalgRange>(context);
 
   if (failed(applyFullConversion(funcOp, target, patterns)))
     return signalPassFailure();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 934e5ae..f83d444 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -25,18 +25,16 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Identifier.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/FoldUtils.h"
 
 #define DEBUG_TYPE "iree-linalg-tile-and-fuse-buffer"
 
-static std::string PromotionMarker = "promotion";
-
 namespace mlir {
 namespace iree_compiler {
 
@@ -81,26 +79,24 @@
     workgroupSize.resize(3, 1);
   }
 
-  /// Compute the tile sizes based on workgroup size specified.
-  LogicalResult setTileSizesBasedOnWorkgroupSize(
-      ArrayRef<int64_t> vWorkGroupSize) {
-    if (!vWorkGroupSize.empty()) {
-      vWorkGroupSize = dropTrailingOnes(vWorkGroupSize);
-      workgroupSize.assign(vWorkGroupSize.begin(), vWorkGroupSize.end());
-      auto rev = reverse(workgroupSize);
-      tileSizes.assign(rev.begin(), rev.end());
-    }
-    return success();
+  /// Set tile sizes to use.
+  void setTileSizes(ArrayRef<int64_t> sizes) {
+    tileSizes.assign(sizes.begin(), sizes.end());
+  }
+
+  /// Set workgroup size to use.
+  void setWorkgroupSize(ArrayRef<int64_t> sizes) {
+    workgroupSize.assign(sizes.begin(), sizes.end());
   }
 
   /// Compute the tile sizes based on the Linalg Ops within the dispatch region.
-  LogicalResult setTileSizesBasedOnOps(ArrayRef<linalg::LinalgOp> linalgOps);
+  LogicalResult inferTileAndWorkgroupSize(ArrayRef<linalg::LinalgOp> linalgOps);
 
   /// Get the current tile size computed.
   ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
 
   /// Returns the workgroup size to use based on the tile sizes.
-  ArrayRef<int64_t> getWorkGroupSize() const { return workgroupSize; }
+  ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
 
  private:
   /// Current tile size configuration.
@@ -114,7 +110,7 @@
 };
 }  // namespace
 
-LogicalResult TileSizeCalculator::setTileSizesBasedOnOps(
+LogicalResult TileSizeCalculator::inferTileAndWorkgroupSize(
     ArrayRef<linalg::LinalgOp> linalgOps) {
   tileSizes.clear();
   if (linalgOps.empty()) {
@@ -134,11 +130,16 @@
   uint32_t opInfo = OpInfo::None;
   for (linalg::LinalgOp linalgOp : linalgOps) {
     Operation *op = linalgOp.getOperation();
-    if (isa<linalg::ConvOp>(op)) opInfo |= OpInfo::Convolution;
-    if (isa<linalg::MatmulOp>(op)) opInfo |= OpInfo::Matmul;
-    if (isa<linalg::PoolingMaxOp>(op)) opInfo |= OpInfo::Pooling;
-    if (isa<linalg::PoolingMinOp>(op)) opInfo |= OpInfo::Pooling;
-    if (isa<linalg::PoolingSumOp>(op)) opInfo |= OpInfo::Pooling;
+    if (isa<linalg::ConvOp>(op))
+      opInfo |= OpInfo::Convolution;
+    else if (isa<linalg::MatmulOp>(op))
+      opInfo |= OpInfo::Matmul;
+    else if (isa<linalg::PoolingMaxOp>(op))
+      opInfo |= OpInfo::Pooling;
+    else if (isa<linalg::PoolingMinOp>(op))
+      opInfo |= OpInfo::Pooling;
+    else if (isa<linalg::PoolingSumOp>(op))
+      opInfo |= OpInfo::Pooling;
   }
   // If there are no tilable ops, there is nothing to do here.
   if (!opInfo) return success();
@@ -155,10 +156,6 @@
   unsigned maxWorkgroupSize =
       resourceLimits.max_compute_workgroup_invocations().getInt();
   if (opInfo & OpInfo::Convolution) {
-    // TODO(ravishankarm): This tiling is meant to enable promotion to workgroup
-    // memory, but doesnt actually get us to a state where we can do this. The
-    // promotion is possible only when the subviews created are constant
-    // size. For now this doesnt really matter. Revisit this later.
     int64_t tileSizeX = 32;
     int64_t tileSizeY = maxWorkgroupSize / 32;
     tileSizes = {1, tileSizeY, tileSizeX};
@@ -187,18 +184,22 @@
 //===----------------------------------------------------------------------===//
 
 /// Allocation callback for allocation workgroup local memory.
-static Value allocateWorkgroupMemory(OpBuilder &b, SubViewOp subview,
-                                     ArrayRef<Value> boundingSubViewSize,
-                                     OperationFolder *folder) {
+static Optional<Value> allocateWorkgroupMemory(
+    OpBuilder &b, SubViewOp subview, ArrayRef<Value> boundingSubViewSize,
+    OperationFolder *folder) {
   // The bounding subview size is expected to be constant. This specifies the
   // shape of the allocation.
-  SmallVector<int64_t, 2> shape(boundingSubViewSize.size(),
-                                ShapedType::kDynamicSize);
-  return b.create<AllocOp>(
-      subview.getLoc(),
-      MemRefType::get(shape, subview.getType().getElementType(), {},
-                      getWorkgroupMemorySpace()),
-      boundingSubViewSize);
+  SmallVector<int64_t, 2> shape = llvm::to_vector<2>(
+      llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t {
+        APInt value;
+        if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue();
+        return -1;
+      }));
+  if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {};
+  MemRefType allocType = MemRefType::get(
+      shape, subview.getType().getElementType(), {}, getWorkgroupMemorySpace());
+  Value buffer = b.create<AllocOp>(subview.getLoc(), allocType);
+  return buffer;
 }
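+// For instance, bounding sizes that fold to the constants (8, 4) for an f32
+// subview produce `alloc() : memref<8x4xf32, 3>`; if any bounding size is not
+// a constant the function returns llvm::None and promotion of this subview is
+// skipped.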
 
 /// Deallocation callback for allocation workgroup local memory.
@@ -208,56 +209,50 @@
   return success();
 }
 
-/// Insert barrier after `op`.
-static void insertBarrierAfter(OpBuilder &b, Location loc, Operation *op) {
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPointAfter(op);
-  b.create<spirv::ControlBarrierOp>(loc, spirv::Scope::Workgroup,
-                                    spirv::Scope::Workgroup,
-                                    spirv::MemorySemantics::AcquireRelease);
-}
-
-/// Function used as callback for copyin/copyout in promotion pattern used to
-/// promote subviews to workgroup memory.
-static LogicalResult copyToFromWorkgroupMemory(
-    OpBuilder &b, Value src, Value dst, StringRef marker = PromotionMarker) {
-  auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
-  setMarker(copyOp, marker);
-  return success();
-}
-
 namespace {
 /// Function pass that implements tiling and fusion in Linalg on buffers.
 struct LinalgTileAndFusePass
     : public PassWrapper<LinalgTileAndFusePass, FunctionPass> {
-  LinalgTileAndFusePass(ArrayRef<int64_t> workGroupSize = {},
-                        bool useWorkgroupMem = false)
-      : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {
+  LinalgTileAndFusePass(ArrayRef<int64_t> workgroupSize = {},
+                        ArrayRef<int64_t> tileSizes = {},
+                        bool useWorkgroupMem = false) {
+    this->workgroupSize = workgroupSize;
+    this->tileSizes = tileSizes;
     this->useWorkgroupMemory = useWorkgroupMem;
   }
   LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
 
   void runOnFunction() override;
 
+ private:
   Option<bool> useWorkgroupMemory{
       *this, "use-workgroup-memory",
       llvm::cl::desc("Promote subviews to use workgroup memory"),
       llvm::cl::init(false)};
 
- private:
-  SmallVector<int64_t, 3> workGroupSize;
+  ListOption<int64_t> workgroupSize{
+      *this, "workgroup-size",
+      llvm::cl::desc("Override the default workgroup size"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
 };
 
 /// Pattern for tiling operations. Updates the workgroup size in the surrounding
 /// function operation if tiling succeeds.
-template <typename OpTy>
-struct TilingPattern : public linalg::LinalgTilingPattern<OpTy> {
-  using Base = linalg::LinalgTilingPattern<OpTy>;
-  TilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
-                ArrayRef<int64_t> workgroupSize,
-                linalg::LinalgMarker marker = linalg::LinalgMarker(),
-                PatternBenefit benefit = 1)
-      : Base(context, options, marker, benefit),
+struct TileMatmulPattern
+    : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
+  using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+  TileMatmulPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+                    ArrayRef<int64_t> workgroupSize, PatternBenefit benefit = 1)
+      : Base(context, options,
+             linalg::LinalgMarker(
+                 ArrayRef<Identifier>(),
+                 Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+                                 context)),
+             benefit),
         workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
 
   virtual LogicalResult matchAndRewrite(Operation *op,
@@ -280,9 +275,17 @@
 /// Pattern for tiling convolution and pooling operations. Currently this is
 /// just a way to avoid tiling when the operation has padding.
 template <typename OpTy>
-struct TileConvPoolPattern : public TilingPattern<OpTy> {
-  using Base = TilingPattern<OpTy>;
-  using Base::TilingPattern;
+struct TileConvPoolPattern : public linalg::LinalgTilingPattern<OpTy> {
+  using Base = linalg::LinalgTilingPattern<OpTy>;
+  TileConvPoolPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+                      ArrayRef<int64_t> workgroupSize,
+                      PatternBenefit benefit = 1)
+      : Base(context, options,
+             linalg::LinalgMarker(
+                 ArrayRef<Identifier>(),
+                 Identifier::get(getWorkgroupMarker(), context)),
+             benefit),
+        workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
@@ -296,28 +299,69 @@
                        WorkgroupCountMethodology::Default)));
     return success();
   }
+
+  SmallVector<int64_t, 3> workgroupSize;
 };
 
-/// Pattern to promote subviews to memory.
-// TODO(ravishankarm): Generalize this for other operations.
-struct PromoteSubviewsPattern
+//===----------------------------------------------------------------------===//
+// Patterns to promote subviews to workgroup memory
+//===----------------------------------------------------------------------===//
+
+/// Function used as the copyin/copyout callback in the promotion pattern used
+/// to promote subviews to workgroup memory when the number of threads is known
+/// to be greater than or equal to the number of iterations of the loops the
+/// copy is lowered to.
+static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
+  auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
+  setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
+  return success();
+}
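+// The marker set here is what routes the generated linalg.copy to the
+// linearized all-workitems lowering in ConvertToGPUPass (the linalg::CopyOp
+// specialization of mapLinalgOpToLocalInvocationIdImpl).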
+
+/// Pattern to promote matmul operands to workgroup memory.
+struct PromoteMatmulSubviewsPattern
     : public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
-  PromoteSubviewsPattern(MLIRContext *context,
-                         linalg::LinalgPromotionOptions options,
-                         linalg::LinalgMarker marker = linalg::LinalgMarker(),
-                         PatternBenefit benefit = 1)
+  PromoteMatmulSubviewsPattern(
+      MLIRContext *context, linalg::LinalgPromotionOptions options,
+      linalg::LinalgMarker marker = linalg::LinalgMarker(),
+      PatternBenefit benefit = 1)
       : linalg::LinalgPromotionPattern<linalg::MatmulOp>(
             context,
             options.setOperandsToPromote({0, 1}).setUseFullTileBuffers(
                 {false, false}),
-            marker, benefit) {}
+            linalg::LinalgMarker(
+                Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+                                context),
+                Identifier::get(getWorkgroupMemoryNumItemsGENumItersMarker(),
+                                context)),
+            benefit) {}
+};
 
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    if (!hasWorkGroupMarker(op)) return failure();
-    return linalg::LinalgPromotionPattern<linalg::MatmulOp>::matchAndRewrite(
-        op, rewriter);
-  }
+/// Patterns to promote convolution operands to workgroup memory.
+// TODO(ravishankarm): This pattern only promotes the image subview to
+// workgroup memory. In reality we should be able to promote the filter
+// subview to workgroup memory as well. Since none of the loops used to access
+// the filter are tiled, this would mean the entire filter is moved to
+// workgroup memory. Two reasons this is not done right now:
+// 1) When tiling, Linalg does not create a subview for the filter (since none
+//    of its dimensions are tiled). This needs to be relaxed (maybe by using an
+//    option).
+// 2) There may be better alternatives for handling the filter (e.g. using
+//    different StorageClasses, since for inference workloads these are model
+//    constants). This is TBD.
+struct PromoteConvolutionSubviewsPattern
+    : public linalg::LinalgPromotionPattern<linalg::ConvOp> {
+  PromoteConvolutionSubviewsPattern(
+      MLIRContext *context, linalg::LinalgPromotionOptions options,
+      linalg::LinalgMarker marker = linalg::LinalgMarker(),
+      PatternBenefit benefit = 1)
+      : linalg::LinalgPromotionPattern<linalg::ConvOp>(
+            context,
+            options.setOperandsToPromote({1}).setUseFullTileBuffers(
+                {false, false}),
+            linalg::LinalgMarker(
+                Identifier::get(getWorkgroupMarker(), context),
+                Identifier::get(getWorkgroupMemoryMarker(), context)),
+            benefit) {}
 };
 }  // namespace
 
@@ -334,28 +378,29 @@
   if (linalgOps.empty()) return;
 
   TileSizeCalculator tileSizeCalculator(funcOp);
-  if (workGroupSize.empty()) {
+  if (tileSizes.empty()) {
     // Infer the tile and workgroup sizes from the Linalg ops in the dispatch
     // region.
     SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(), linalgOps.end());
-    if (failed(tileSizeCalculator.setTileSizesBasedOnOps(opsVec)))
+    if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
       return signalPassFailure();
   } else {
-    tileSizeCalculator.setTileSizesBasedOnWorkgroupSize(workGroupSize);
+    tileSizeCalculator.setTileSizes(tileSizes);
+    if (!workgroupSize.empty())
+      tileSizeCalculator.setWorkgroupSize(workgroupSize);
   }
 
   LLVM_DEBUG({
     llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";
-    llvm::dbgs() << "# workgroup sizes at start: [";
-    interleaveComma(workGroupSize, llvm::dbgs());
+    llvm::dbgs() << "# workgroup sizes: [";
+    interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
     llvm::dbgs() << "]\ntile sizes: [";
     interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
     llvm::dbgs() << "]\n";
   });
 
   OwningRewritePatternList tilingPatterns;
-  tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
-                        TilingPattern<linalg::MatmulOp>,
+  tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>, TileMatmulPattern,
                         TileConvPoolPattern<linalg::PoolingMaxOp>,
                         TileConvPoolPattern<linalg::PoolingMinOp>,
                         TileConvPoolPattern<linalg::PoolingSumOp>>(
@@ -363,9 +408,7 @@
       linalg::LinalgTilingOptions()
           .setTileSizes(tileSizeCalculator.getTileSizes())
           .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
-      tileSizeCalculator.getWorkGroupSize(),
-      linalg::LinalgMarker(ArrayRef<Identifier>(),
-                           Identifier::get(getWorkGroupMarker(), context)));
+      tileSizeCalculator.getWorkgroupSize());
   applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
 
   if (useWorkgroupMemory) {
@@ -373,31 +416,15 @@
     // sure that the allocated scratchspace memory is constant sizes which
     // requires some folding to trigger.
     OwningRewritePatternList promotionPatterns;
-    promotionPatterns.insert<PromoteSubviewsPattern>(
+    promotionPatterns.insert<PromoteMatmulSubviewsPattern,
+                             PromoteConvolutionSubviewsPattern>(
         context,
         linalg::LinalgPromotionOptions()
             .setAllocationDeallocationFns(allocateWorkgroupMemory,
                                           deallocateWorkgroupMemory)
-            .setCopyInOutFns(
-                [&](OpBuilder &b, Value src, Value dst) -> LogicalResult {
-                  return copyToFromWorkgroupMemory(b, src, dst);
-                },
-                [&](OpBuilder &b, Value src, Value dst) -> LogicalResult {
-                  return copyToFromWorkgroupMemory(b, src, dst);
-                }),
-        linalg::LinalgMarker(Identifier::get(getWorkGroupMarker(), context),
-                             Identifier::get(PromotionMarker, context)));
+            .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
     applyPatternsAndFoldGreedily(getOperation(), promotionPatterns);
   }
-
-  // Add barrier after all linalg operations marked with workitem marker.
-  OpBuilder builder(context);
-  funcOp.walk([&builder](linalg::LinalgOp linalgOp) {
-    if (hasMarker(linalgOp, PromotionMarker)) {
-      setWorkGroupMarker(linalgOp);
-      insertBarrierAfter(builder, linalgOp.getLoc(), linalgOp);
-    }
-  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -405,8 +432,9 @@
 //===----------------------------------------------------------------------===//
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFusePass(
-    ArrayRef<int64_t> workGroupSize, bool useWorkgroupMemory) {
-  return std::make_unique<LinalgTileAndFusePass>(workGroupSize,
+    ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> tileSizes,
+    bool useWorkgroupMemory) {
+  return std::make_unique<LinalgTileAndFusePass>(workgroupSize, tileSizes,
                                                  useWorkgroupMemory);
 }
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
index 47747de..9c45419 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
@@ -26,22 +26,30 @@
 };
 const StringLiteral VectorTransforms::kVectorTransformMarker =
     "__internal_vector_transform__";
-/// Checks if the operation has the `marker` If `marker` is null string, checks
-/// if any marker is set.
-static bool checkMarkerValue(Operation *op, StringRef marker = "") {
+
+StringRef getWorkgroupMarker() { return "workgroup"; }
+
+StringRef getWorkgroupMemoryMarker() { return "workgroup_memory"; }
+
+StringRef getWorkgroupNumItemsGENumItersMarker() {
+  return "workgroup_numprocs_ge_numiters";
+}
+
+StringRef getWorkgroupMemoryNumItemsGENumItersMarker() {
+  return "workgroup_memory_numprocs_ge_numiters";
+}
+
+StringRef getCopyToWorkgroupMemoryMarker() {
+  return "copy_to_workgroup_memory";
+}
+
+bool hasMarker(Operation *op, ArrayRef<StringRef> markers) {
   StringAttr attr = op->getAttrOfType<StringAttr>(
       linalg::LinalgTransforms::kLinalgTransformMarker);
-  return attr && (marker.empty() || attr.getValue() == marker);
-}
-
-StringRef getWorkGroupMarker() { return "workgroup"; }
-
-bool hasMarker(Operation *op, StringRef marker) {
-  return checkMarkerValue(op, marker);
-}
-
-bool hasWorkGroupMarker(Operation *op) {
-  return checkMarkerValue(op, getWorkGroupMarker());
+  return attr && (markers.empty() ||
+                  llvm::any_of(markers, [&attr](StringRef markerValue) {
+                    return attr.getValue() == markerValue;
+                  }));
 }
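+// Usage sketch: hasMarker(op) is true if `op` carries any
+// __internal_linalg_transform__ marker, while, e.g.,
+// hasMarker(op, {getWorkgroupMarker(), getWorkgroupMemoryMarker()}) matches
+// only those two marker values.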
 
 void setMarker(Operation *op, StringRef marker) {
@@ -49,7 +57,5 @@
               StringAttr::get(marker, op->getContext()));
 }
 
-void setWorkGroupMarker(Operation *op) { setMarker(op, getWorkGroupMarker()); }
-
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
index e512ead..222291d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
@@ -22,7 +22,7 @@
 #ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
 #define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
 
-#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
@@ -30,24 +30,29 @@
 class Operation;
 namespace iree_compiler {
 
-/// Marker to denote that a linalg operation is to be partitioned to workitems.
-StringRef getWorkGroupMarker();
+/// Marker to denote that a linalg operation is to be partitioned to
+/// workitems. No assumption can be made about the number of workitems in the
+/// workgroup relative to the number of iterations, i.e. a cyclic distribution
+/// is required.
+StringRef getWorkgroupMarker();
+StringRef getWorkgroupMemoryMarker();
+
+/// Marker to denote that a linalg operation is to be partitioned to workitems
+/// with the assumption that the number of workitems in the workgroup is
+/// greater than or equal to the number of iterations.
+StringRef getWorkgroupNumItemsGENumItersMarker();
+StringRef getWorkgroupMemoryNumItemsGENumItersMarker();
+
+/// Marker for copy operations that move data from StorageClass to Workgroup
+/// memory.
+StringRef getCopyToWorkgroupMemoryMarker();
 
 /// Returns true if an operation has any of the specified `markers`. When
 /// `markers` is empty, returns true if the operation has any marker.
-bool hasMarker(Operation *, StringRef marker = "");
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workitems.
-bool hasWorkGroupMarker(Operation *);
+bool hasMarker(Operation *, ArrayRef<StringRef> markers = {});
 
 /// Sets a given marker on an operation.
 void setMarker(Operation *, StringRef);
 
-/// Sets marker to denote that a linalg operation is to be partitioned to
-/// workitems.
-void setWorkGroupMarker(Operation *);
-
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp b/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
new file mode 100644
index 0000000..9183933
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/MatMulVectorizationTest.cpp
@@ -0,0 +1,74 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+static llvm::cl::opt<int> wgTileSize(
+    "iree-codegen-linalg-to-gpu-wg-tile-size",
+    llvm::cl::desc(
+        "Specify the size of workgroup tile for matmul vector lowering"),
+    llvm::cl::init(32));
+
+static llvm::cl::list<uint32_t> unrollSize(
+    "iree-codegen-linalg-to-gpu-unroll-size",
+    llvm::cl::desc("Specify the size of the "), llvm::cl::CommaSeparated);
+
+static llvm::cl::opt<bool> enableLICM(
+    "iree-codegen-linalg-to-gpu-matmul-licm",
+    llvm::cl::desc(
+        "If true run LICM and hoisting passes after the staged transforms"),
+    llvm::cl::init(true));
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+struct MatMulTileAndVectorizeGPUPass
+    : PassWrapper<MatMulTileAndVectorizeGPUPass, FunctionPass> {
+  void runOnFunction() override;
+};
+}  // namespace
+
+void MatMulTileAndVectorizeGPUPass::runOnFunction() {
+  FuncOp fn = getFunction();
+  SmallVector<uint32_t, 3> vUnrollSize(unrollSize.begin(), unrollSize.end());
+  // The unroll sizes are indexed [0..2] below, so bail out early on invalid
+  // input.
+  if (vUnrollSize.size() != 3) {
+    signalPassFailure();
+    return;
+  }
+  MatmulCodegenStrategy strategy;
+  strategy
+      .tile<linalg::MatmulOp>(
+          linalg::LinalgTilingOptions()
+              // TODO(thomasraoux): Enable parallel loops once affine.min
+              // canonicalization supports it.
+              //.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+              .setTileSizes({wgTileSize, wgTileSize, wgTileSize}))
+      .setHoistInvariantCode(enableLICM)
+      .vectorize<linalg::MatmulOp>()
+      .unrollVector<vector::ContractionOp>(
+          {vUnrollSize[0], vUnrollSize[1], vUnrollSize[2]});
+  strategy.transform(fn);
+}
+
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass() {
+  return std::make_unique<MatMulTileAndVectorizeGPUPass>();
+}
+
+static PassRegistration<MatMulTileAndVectorizeGPUPass> pass(
+    "iree-codegen-linalg-to-gpu-matmul-vectorization-pass",
+    "Tile and vectorize linalg.matmul operation",
+    [] { return std::make_unique<MatMulTileAndVectorizeGPUPass>(); });
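+// Example invocation (mirroring the lit test for this pass; input.mlir is a
+// placeholder):
+//   iree-opt -iree-codegen-linalg-to-gpu-matmul-vectorization-pass \
+//     -iree-codegen-linalg-to-gpu-unroll-size=8,8,32 \
+//     -iree-codegen-linalg-to-gpu-matmul-licm input.mlir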
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index fd7b316..5c1b672 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -59,6 +59,9 @@
           "three integers standarding for the x, y, and z dimension; "
           "additional arguments will be ignored (used only for testing)"),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> useWorkgroupMemory{
       *this, "use-workgroup-memory",
       llvm::cl::desc(
@@ -97,8 +100,8 @@
   //   afterwards. This gives each Linalg op a second chance to be tiled,
   //   with the second tile and fuse pass.
   //===--------------------------------------------------------------------===//
-  pm.addPass(createLinalgTileAndFusePass(options.workgroupSize,
-                                         options.useWorkgroupMemory));
+  pm.addPass(createLinalgTileAndFusePass(
+      options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
   pm.addPass(createSplitDispatchFunctionPass());
   pm.addPass(createLinalgTileAndFusePass(
       options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
@@ -221,6 +224,7 @@
   SPIRVCodegenOptions options;
   options.workgroupSize.assign(clOpts.workgroupSize.begin(),
                                clOpts.workgroupSize.end());
+  options.tileSizes.assign(clOpts.tileSizes.begin(), clOpts.tileSizes.end());
   options.useWorkgroupMemory = clOpts.useWorkgroupMemory;
   return options;
 }
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 8267780..35de230 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -25,6 +25,7 @@
 // Options that can be used to configure SPIR-V codegeneration.
 struct SPIRVCodegenOptions {
   SmallVector<int64_t, 3> workgroupSize = {};
+  SmallVector<int64_t, 3> tileSizes = {};
   bool useWorkgroupMemory = false;
 };
 
@@ -35,7 +36,8 @@
 /// it exists) and along "z" for the next loop (if it exists). The workgroup
 /// size is expected to be of size at-most 3.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFusePass(
-    ArrayRef<int64_t> workGroupSize = {}, bool useWorkgroupMem = false);
+    ArrayRef<int64_t> workGroupSize = {}, ArrayRef<int64_t> tileSizes = {},
+    bool useWorkgroupMem = false);
 
 /// Pass to add the synchronizations and attributes needed to lower from PLoops
 /// to GPU dialect.
@@ -60,6 +62,9 @@
 /// vector size equal to subgroup size are distributed across the subgroup.
 std::unique_ptr<OperationPass<FuncOp>> createVectorToGPUPass();
 
+/// Pass to apply tiling and vectorization transformations on linalg::MatmulOp.
+std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass();
+
 /// Populates passes needed to lower a XLA HLO op to SPIR-V dialect via the
 /// structured ops path. The pass manager `pm` in here operate on the module
 /// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 64621f3..599af08 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -162,7 +162,7 @@
         %12 = dim %arg2, %c1 : memref<?x?xf32>
         %13 = affine.min #map0(%arg4)[%12]
         %14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
-        linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
+        linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup_numprocs_ge_numiters"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
       }
       scf.yield
     }
@@ -288,57 +288,6 @@
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (32, d1 - d2)>
-#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
-
-
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @conv_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
-    linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-    return
-  }
-}
-
-// CHECK-LABEL: func @conv_padding
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   local_size = dense<[32, 1, 1]>
-//  CHECK-SAME:   vkspv.workgroup_count_from_result_shape = 1
-//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//   CHECK-DAG:   %[[C3:.+]] = constant 3 : index
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG0]], %[[C1]]
-//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG0]], %[[C2]]
-//   CHECK-DAG:   %[[UB3:.+]] = dim %[[ARG0]], %[[C3]]
-//   CHECK-DAG:   %[[UB4:.+]] = dim %[[ARG1]], %[[C0]]
-//   CHECK-DAG:   %[[UB5:.+]] = dim %[[ARG2]], %[[C1]]
-//   CHECK-DAG:   %[[UB6:.+]] = dim %[[ARG2]], %[[C2]]
-//       CHECK:   %[[T7:.+]] = muli %[[UB3]], %[[UB6]]
-//       CHECK:   %[[T8:.+]] = muli %[[T7]], %[[UB5]]
-//       CHECK:   %[[UB:.+]] = muli %[[T8]], %[[UB4]]
-//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//   CHECK-DAG:   %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
-//   CHECK-DAG:   %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//       CHECK:   %[[T13:.+]] = muli %[[BIDX]], %[[NTHREADSX]]
-//       CHECK:   %[[PROCID:.+]] = addi %[[T13]], %[[TIDX]]
-//       CHECK:   %[[COND:.+]] = cmpi "slt", %[[PROCID]], %[[UB]]
-//       CHECK:   scf.if %[[COND]]
-//       CHECK:     %[[IV0:.+]] = divi_signed %[[PROCID]], %[[T8]]
-//       CHECK:     %[[T17:.+]] = remi_signed %[[PROCID]], %[[T8]]
-//       CHECK:     %[[IV1:.+]] = divi_signed %[[T17]], %[[T7]]
-//       CHECK:     %[[T19:.+]] = remi_signed %[[T17]], %[[T7]]
-//       CHECK:     %[[IV2:.+]] = divi_signed %[[T19]], %[[UB3]]
-//       CHECK:     %[[T21:.+]] = remi_signed %[[T19]], %[[UB3]]
-//       CHECK:     scf.for %[[IV3:.+]] = %[[C0]] to %[[UB2]] step %[[C1]]
-//       CHECK:       scf.for %[[IV4:.+]] = %[[C0]] to %[[UB0]] step %[[C1]]
-//       CHECK:         scf.for %[[IV5:.+]]= %[[C0]] to %[[UB1]] step %[[C1]]
-//   CHECK-NOT:           linalg.conv
-
-// -----
-
 #map0 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
 #map1 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
 #map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
deleted file mode 100644
index 63f8aa5..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu_option.mlir
+++ /dev/null
@@ -1,76 +0,0 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -iree-codegen-use-legacy-conv-lowering=false -canonicalize -cse -split-input-file %s | IreeFileCheck %s
-
-#map0 = affine_map<(d0)[s0] -> (1, -d0 + s0)>
-#map1 = affine_map<(d0)[s0, s1] -> (s0 + 4, -d0 + s1)>
-#map2 = affine_map<(d0)[s0, s1] -> (s0 + 32, -d0 + s1)>
-#map3 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-#map4 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map5 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @conv_no_padding(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
-    %c4 = constant 4 : index
-    %c32 = constant 32 : index
-    %c2 = constant 2 : index
-    %c0 = constant 0 : index
-    %c3 = constant 3 : index
-    %c1 = constant 1 : index
-    %0 = dim %arg1, %c0 : memref<?x?x?x?xf32>
-    %1 = dim %arg1, %c1 : memref<?x?x?x?xf32>
-    %2 = dim %arg1, %c2 : memref<?x?x?x?xf32>
-    %3 = dim %arg2, %c1 : memref<?x?x?x?xf32>
-    %4 = dim %arg2, %c2 : memref<?x?x?x?xf32>
-    scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%0, %3, %4) step (%c1, %c4, %c32) {
-      %5 = affine.min #map0(%arg3)[%0]
-      %6 = affine.min #map1(%arg4)[%1, %1]
-      %7 = affine.min #map2(%arg5)[%2, %2]
-      %8 = dim %arg1, %c3 : memref<?x?x?x?xf32>
-      %9 = subview %arg1[%arg3, %arg4, %arg5, 0] [%5, %6, %7, %8] [1, 1, 1, 1]  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
-      %10 = dim %arg2, %c0 : memref<?x?x?x?xf32>
-      %11 = affine.min #map0(%arg3)[%10]
-      %12 = affine.min #map4(%arg4)[%3]
-      %13 = affine.min #map5(%arg5)[%4]
-      %14 = dim %arg2, %c3 : memref<?x?x?x?xf32>
-      %15 = subview %arg2[%arg3, %arg4, %arg5, 0] [%11, %12, %13, %14] [1, 1, 1, 1]  : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map3>
-      linalg.conv(%arg0, %9, %15) {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map3>, memref<?x?x?x?xf32, #map3>
-      scf.yield
-    }
-    return
-  }
-}
-
-// CHECK-LABEL: func @conv_no_padding
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
-//   CHECK-DAG:   %[[C32:.+]] = constant 32 : index
-//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG1]], %[[C0]]
-//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
-//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG1]], %[[C2]]
-//   CHECK-DAG:   %[[UB3:.+]] = dim %[[ARG2]], %[[C1]]
-//   CHECK-DAG:   %[[UB4:.+]] = dim %[[ARG2]], %[[C2]]
-//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-//   CHECK-DAG:   %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
-//   CHECK-DAG:   %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C4]]
-//   CHECK-DAG:   %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C32]]
-//       CHECK:   %[[SV1:.+]] = subview %[[ARG1]][%[[BIDZ]], %[[BOFFSETY]], %[[BOFFSETX]], 0]
-//       CHECK:   %[[SV2:.+]] = subview %[[ARG2]][%[[BIDZ]], %[[BOFFSETY]], %[[BOFFSETX]], 0]
-//   CHECK-DAG:   %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//   CHECK-DAG:   %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//   CHECK-DAG:   %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
-//       CHECK:   %[[INBOUNDSZ:.+]] = cmpi "slt", %[[TIDZ]], %{{.+}}
-//       CHECK:   %[[INBOUNDSY:.+]] = cmpi "slt", %[[TIDY]], %{{.+}}
-//       CHECK:   %[[T35:.+]] = and %[[INBOUNDSZ]], %[[INBOUNDSY]]
-//       CHECK:   %[[INBOUNDSX:.+]] = cmpi "slt", %[[TIDX]], %{{.+}}
-//       CHECK:   %[[INBOUNDS:.+]] = and %[[T35]], %[[INBOUNDSX]]
-//       CHECK:   scf.if %[[INBOUNDS]]
-//       CHECK:     scf.for
-//       CHECK:       scf.for
-//       CHECK:         scf.for
-//       CHECK:           scf.for
-//   CHECK-NOT:             linalg.conv
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
deleted file mode 100644
index cac18ab..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/cyclic_to_workgroup.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -cse -split-input-file  -iree-codegen-constrained-workgroup-count %s | IreeFileCheck %s
-
-#map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
-    %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %c4 = constant 4 : index
-    %c8 = constant 8 : index
-    %0 = dim %arg0, %c0 : memref<?x?xf32>
-    %1 = dim %arg0, %c1 : memref<?x?xf32>
-    %2 = dim %arg1, %c1 : memref<?x?xf32>
-    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %2) step (%c8, %c8) {
-      scf.for %arg5 = %c0 to %1 step %c4 {
-        %3 = affine.min #map0(%arg3)[%0]
-        %4 = affine.min #map1(%arg5)[%1]
-        %5 = subview %arg0[%arg3, %arg5] [%3, %4] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
-        %6 = dim %arg1, %c0 : memref<?x?xf32>
-        %7 = affine.min #map1(%arg5)[%6]
-        %8 = affine.min #map0(%arg4)[%2]
-        %9 = subview %arg1[%arg5, %arg4] [%7, %8] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
-        %10 = dim %arg2, %c0 : memref<?x?xf32>
-        %11 = affine.min #map0(%arg3)[%10]
-        %12 = dim %arg2, %c1 : memref<?x?xf32>
-        %13 = affine.min #map0(%arg4)[%12]
-        %14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map2>
-        linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workgroup"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
-      }
-      scf.yield
-    }
-    return
-  }
-}
-
-// CHECK-LABEL: func @matmul
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
-//   CHECK-DAG:   %[[C8:.+]] = constant 8 : index
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
-//   CHECK-DAG:   %[[UB2:.+]] = dim %[[ARG0]], %[[C1]]
-//   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-//   CHECK-DAG:   %[[GDIMX:.+]] = "gpu.grid_dim"() {dimension = "x"}
-//   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-//   CHECK-DAG:   %[[GDIMY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-//       CHECK:   %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C8]]
-//       CHECK:   %[[BSTEPY:.+]] = muli %[[GDIMY]], %[[C8]]
-//       CHECK:   %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C8]]
-//       CHECK:   %[[BSTEPX:.+]] = muli %[[GDIMX]], %[[C8]]
-//       CHECK:   scf.for %[[BIV0:.+]] = %[[BOFFSETY]] to %[[UB0]] step %[[BSTEPY]]
-//       CHECK:     scf.for %[[BIV1:.+]] = %[[BOFFSETX]] to %[[UB1]] step %[[BSTEPX]]
-//       CHECK:       scf.for %[[BIV2:.+]] = %[[C0]] to %[[UB2]] step %[[C4]]
-//   CHECK-DAG:         %[[VIEWUB0:.+]] = affine.min #{{.*}}(%[[BIV0]])[%[[UB0]]]
-//   CHECK-DAG:         %[[VIEWUB1:.+]] = affine.min #{{.*}}(%[[BIV1]])[%[[UB1]]]
-//   CHECK-DAG:         %[[VIEWUB2:.+]] = affine.min #{{.*}}(%[[BIV2]])[%[[UB2]]]
-//   CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-//   CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-//       CHECK:         %[[INBOUNDY:.+]] = cmpi "slt", %[[TIDY]], %[[VIEWUB0]]
-//       CHECK:         %[[INBOUNDX:.+]] = cmpi "slt", %[[TIDX]], %[[VIEWUB1]]
-//       CHECK:         %[[COND:.+]] = and %[[INBOUNDY]], %[[INBOUNDX]]
-//       CHECK:         scf.if %[[COND]]
-//       CHECK:           scf.for %{{.*}} = %[[C0]] to %[[VIEWUB2]] step %[[C1]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 1728d35..7e7433a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -81,7 +81,7 @@
 //       CHECK:     %[[VIEW1:.+]] = subview %[[ARG1]]
 //       CHECK:     %[[VIEW2:.+]] = subview %[[ARG2]]
 //       CHECK:     linalg.matmul
-//  CHECK-SAME:       "workgroup"
+//  CHECK-SAME:       "workgroup_numprocs_ge_numiters"
 //  CHECK-SAME:       %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
 
 // -----
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
new file mode 100644
index 0000000..63ba65f
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-opt --iree-codegen-linalg-to-gpu-matmul-vectorization-pass \
+// RUN: -split-input-file %s --iree-codegen-linalg-to-gpu-unroll-size=8,8,32 \
+// RUN: -iree-codegen-linalg-to-gpu-matmul-licm | IreeFileCheck %s
+
+// CHECK-LABEL: func @matmul_128x128x128
+// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
+func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
+    linalg.matmul %arg0, %arg1, %arg2 : (memref<128x128xf32>, memref<128x128xf32>, memref<128x128xf32>)
+    return
+}
+
+// CHECK: %[[TILESIZE:.+]] = constant 32 : index
+// CHECK: %[[MATSIZE:.+]] = constant 128 : index
+// CHECK: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: scf.for %[[JL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVVIEWC:.+]] = subview %[[ARG2]][%[[IL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: scf.for %[[KL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVVIEWA:.+]] = subview %[[ARG0]][%[[IL]], %[[KL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: %[[SUBVVIEWB:.+]] = subview %[[ARG1]][%[[KL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index a24c77b..0f97d1e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -1,47 +1,55 @@
-// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-workgroup-memory %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-workgroup-memory -canonicalize -cse %s | IreeFileCheck %s
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @matmul_tile() {
-    %arg0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<96x96xf32>
-    %arg1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<96x96xf32>
-    %arg2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<96x96xf32>
+  func @matmul_tile(%arg0 : memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
     linalg.matmul %arg0, %arg1, %arg2 :
-      (memref<96x96xf32>, memref<96x96xf32>, memref<96x96xf32>)
+      (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
     return
   }
+}
+// CHECK-LABEL: func @matmul_tile
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//       CHECK: scf.parallel (%{{.*}}, %{{.*}})
+//       CHECK:   scf.for %{{.*}}
+//       CHECK:     %[[ARG0SV:.+]] = subview %[[ARG0]]
+//       CHECK:     %[[ARG1SV:.+]] = subview %[[ARG1]]
+//       CHECK:     %[[ARG2SV:.+]] = subview %[[ARG2]]
+//       CHECK:     %[[ALLOC1:.+]] = alloc() : memref<8x4xf32, 3>
+//       CHECK:     %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
+//       CHECK:     %[[ALLOC2:.+]] = alloc() : memref<4x8xf32, 3>
+//       CHECK:     %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
+//       CHECK:     linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
+//  CHECK-SAME:       "copy_to_workgroup_memory"
+//       CHECK:     linalg.copy(%[[ARG1SV]], %[[SUBVIEW2]])
+//  CHECK-SAME:       "copy_to_workgroup_memory"
+//       CHECK:     linalg.matmul
+//  CHECK-SAME:       "workgroup_memory_numprocs_ge_numiters"
+//  CHECK-SAME:       %[[SUBVIEW1]], %[[SUBVIEW2]], %[[ARG2SV]]
+//   CHECK-DAG:     dealloc %[[ALLOC1]] : memref<8x4xf32, 3>
+//   CHECK-DAG:     dealloc %[[ALLOC2]] : memref<4x8xf32, 3>
 
-  hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} {
-    hal.interface.binding @arg0, set=0, binding=1, type="StorageBuffer", access="Read"
-    hal.interface.binding @arg1, set=0, binding=2, type="StorageBuffer", access="Read"
-    hal.interface.binding @ret0, set=0, binding=3, type="StorageBuffer", access="Write"
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @conv_no_padding_tile(%arg0: memref<3x4x3x2xf32>, %arg1: memref<?x?x?x3xf32>, %arg2: memref<?x?x?x2xf32>) {
+    linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} : memref<3x4x3x2xf32>, memref<?x?x?x3xf32>, memref<?x?x?x2xf32>
+    return
   }
 }
-//  CHECK-DAG: %[[C4:.+]] = constant 4 : index
-//  CHECK-DAG: %[[C8:.+]] = constant 8 : index
-//  CHECK-DAG: %[[C96:.+]] = constant 96 : index
-//  CHECK-DAG: %[[C0:.+]] = constant 0 : index
-//      CHECK: %[[ARG0:.+]] = iree.placeholder
-// CHECK-SAME:               binding = @legacy_io::@arg0
-//      CHECK: %[[ARG1:.+]] = iree.placeholder
-// CHECK-SAME:               binding = @legacy_io::@arg1
-//      CHECK: %[[RET0:.+]] = iree.placeholder
-// CHECK-SAME:               binding = @legacy_io::@ret0
-//      CHECK: scf.parallel (%{{.*}}, %{{.*}})
-//      CHECK:   scf.for %{{.*}} = %[[C0]] to %{{.*}} step %[[C4]]
-//      CHECK:     %[[ARG0SV:.+]] = subview %[[ARG0]]
-//      CHECK:     %[[ARG1SV:.+]] = subview %[[ARG1]]
-//      CHECK:     %[[RET0SV:.+]] = subview %[[RET0]]
-//      CHECK:     %[[ALLOC1:.+]] = alloc(%[[C8]], %[[C4]]) : memref<?x?xf32, 3>
-//      CHECK:     %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
-//      CHECK:     %[[ALLOC2:.+]] = alloc(%[[C4]], %[[C8]]) : memref<?x?xf32, 3>
-//      CHECK:     %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
-//      CHECK:     linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
-// CHECK-SAME:       "workgroup"
-//      CHECK:     spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-//      CHECK:     linalg.copy(%[[ARG1SV]], %[[SUBVIEW2]])
-// CHECK-SAME:       "workgroup"
-//      CHECK:     spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-//      CHECK:     linalg.matmul {{.*}}"workgroup"{{.*}} %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
-//      CHECK:     spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
-//  CHECK-DAG:     dealloc %[[ALLOC1]] : memref<?x?xf32, 3>
-//  CHECK-DAG:     dealloc %[[ALLOC2]] : memref<?x?xf32, 3>
+// CHECK-LABEL: func @conv_no_padding_tile
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<3x4x3x2xf32>
+//  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?x?x3xf32>
+//  CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?x?x2xf32>
+//       CHECK: scf.parallel (%{{.*}}, %{{.*}}, %{{.*}})
+//       CHECK:    %[[ARG1SV:.+]] = subview %[[ARG1]]
+//       CHECK:    %[[ARG2SV:.+]] = subview %[[ARG2]]
+//       CHECK:    %[[ALLOC1:.+]] = alloc() : memref<1x7x36x3xf32, 3>
+//       CHECK:    %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
+//       CHECK:    linalg.copy(%[[ARG1SV]], %[[SUBVIEW1]])
+//  CHECK-SAME:       "copy_to_workgroup_memory"
+//       CHECK:    linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[ARG2SV]])
+//  CHECK-SAME:       "workgroup_memory"
+//       CHECK:    dealloc %[[ALLOC1]] : memref<1x7x36x3xf32, 3>
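Editor's note: both tests above check the same promotion pattern: stage each operand tile into workgroup memory, compute on the staged copies, then release them. A hedged C++ sketch of that alloc/copy/compute/dealloc shape (names hypothetical; the real pass emits linalg ops plus barriers on the SPIR-V side):

// Hypothetical host-side analog of the promotion the CHECKs describe. The
// 8x4 and 4x8 scratch shapes match the alloc() : memref<8x4xf32, 3> /
// memref<4x8xf32, 3> lines in the matmul test above.
#include <cstring>
#include <vector>

void promotedMatmulTile(const float* aTile, const float* bTile, float* cTile) {
  // Stand-ins for the workgroup-memory allocs.
  std::vector<float> wgA(8 * 4), wgB(4 * 8);
  // Stand-ins for linalg.copy(...) {"copy_to_workgroup_memory"}.
  std::memcpy(wgA.data(), aTile, wgA.size() * sizeof(float));
  std::memcpy(wgB.data(), bTile, wgB.size() * sizeof(float));
  // (On the device, a workgroup barrier separates the copies from the
  // compute; the added -canonicalize -cse in the RUN line does not change
  // that structure, only the redundant IR around it.)
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      for (int k = 0; k < 4; ++k)
        cTile[i * 8 + j] += wgA[i * 4 + k] * wgB[k * 8 + j];
  // Scratch buffers are released here, matching the dealloc pair.
}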
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index bc218e8..74eaf71 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -40,6 +40,7 @@
     createLinalgTileAndFusePass();
     createSplitDispatchFunctionPass();
     createVectorToGPUPass();
+    createMatMulTileAndVectorizeGPUPass();
     return true;
   }();
   (void)init_once;
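Editor's note: the hunk above relies on the function-local-static registration idiom; as a standalone sketch:

// Standalone sketch of the call-once registration idiom used above: the
// lambda runs exactly once, the first time the enclosing function is called,
// because initialization of a function-local static is one-shot (and
// thread-safe in C++11 and later).
inline void registerAllPassesOnce() {
  static bool init_once = [] {
    // ...invoke the create*Pass() registration hooks here...
    return true;
  }();
  (void)init_once;  // Silence unused-variable warnings.
}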
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h
index 3df35e1..790e8e6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -63,7 +63,6 @@
   static AllocatorType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Allocator);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Allocator; }
 };
 
 class BufferType : public Type::TypeBase<BufferType, Type, TypeStorage> {
@@ -72,7 +71,6 @@
   static BufferType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Buffer);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Buffer; }
 };
 
 class BufferViewType
@@ -82,7 +80,6 @@
   static BufferViewType get(MLIRContext *context) {
     return Base::get(context, TypeKind::BufferView);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::BufferView; }
 };
 
 class CommandBufferType
@@ -92,7 +89,6 @@
   static CommandBufferType get(MLIRContext *context) {
     return Base::get(context, TypeKind::CommandBuffer);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::CommandBuffer; }
 };
 
 class DescriptorSetType
@@ -102,7 +98,6 @@
   static DescriptorSetType get(MLIRContext *context) {
     return Base::get(context, TypeKind::DescriptorSet);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::DescriptorSet; }
 };
 
 class DescriptorSetLayoutType
@@ -112,9 +107,6 @@
   static DescriptorSetLayoutType get(MLIRContext *context) {
     return Base::get(context, TypeKind::DescriptorSetLayout);
   }
-  static bool kindof(unsigned kind) {
-    return kind == TypeKind::DescriptorSetLayout;
-  }
 };
 
 class DeviceType : public Type::TypeBase<DeviceType, Type, TypeStorage> {
@@ -123,7 +115,6 @@
   static DeviceType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Device);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Device; }
 };
 
 class EventType : public Type::TypeBase<EventType, Type, TypeStorage> {
@@ -132,7 +123,6 @@
   static EventType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Event);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Event; }
 };
 
 class ExecutableType
@@ -142,7 +132,6 @@
   static ExecutableType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Executable);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Executable; }
 };
 
 class ExecutableCacheType
@@ -152,9 +141,6 @@
   static ExecutableCacheType get(MLIRContext *context) {
     return Base::get(context, TypeKind::ExecutableCache);
   }
-  static bool kindof(unsigned kind) {
-    return kind == TypeKind::ExecutableCache;
-  }
 };
 
 class ExecutableLayoutType
@@ -164,9 +150,6 @@
   static ExecutableLayoutType get(MLIRContext *context) {
     return Base::get(context, TypeKind::ExecutableLayout);
   }
-  static bool kindof(unsigned kind) {
-    return kind == TypeKind::ExecutableLayout;
-  }
 };
 
 class RingBufferType
@@ -176,7 +159,6 @@
   static RingBufferType get(MLIRContext *context) {
     return Base::get(context, TypeKind::RingBuffer);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::RingBuffer; }
 };
 
 class SemaphoreType : public Type::TypeBase<SemaphoreType, Type, TypeStorage> {
@@ -185,7 +167,6 @@
   static SemaphoreType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Semaphore);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Semaphore; }
 };
 
 //===----------------------------------------------------------------------===//
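Editor's note: every kindof() deletion in this header (and in the headers below) follows from the llvm-project bump in this change: MLIR's cast machinery now keys on a per-class TypeID instead of a kind enum, so the per-type kindof(unsigned kind) hooks are dead code. A self-contained sketch of the mechanism (a simplified stand-in, not MLIR's actual implementation):

// Simplified stand-in for TypeID-based casting: each class gets a unique id
// from the address of a function-local static, so isa<> needs no per-class
// kindof(unsigned kind) hook.
struct TypeID {
  const void* ptr;
  template <typename T>
  static TypeID get() {
    static const char id = 0;  // unique address per instantiation of get<T>()
    return TypeID{&id};
  }
  bool operator==(TypeID other) const { return ptr == other.ptr; }
};

struct Type {
  TypeID id;
  explicit Type(TypeID id) : id(id) {}
  template <typename T>
  bool isa() const { return id == TypeID::get<T>(); }
};

struct BufferType : Type {
  BufferType() : Type(TypeID::get<BufferType>()) {}
};
// Usage: Type t = BufferType();  t.isa<BufferType>() == true.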
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index a0cafcc..957bb53 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -63,6 +63,11 @@
           "Workgroup size to use for XLA-HLO to Linalg to SPIR-V path"),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
 
+  static llvm::cl::list<unsigned> clTileSizes(
+      "iree-spirv-tile-size",
+      llvm::cl::desc("Tile size to use for tiling Linalg operations"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
+
   static llvm::cl::opt<std::string> clVulkanTargetEnv(
       "iree-vulkan-target-env",
       llvm::cl::desc(
@@ -70,9 +75,10 @@
       llvm::cl::init(Vulkan::swiftShaderTargetEnvAssembly));
 
   VulkanSPIRVTargetOptions targetOptions;
-  for (unsigned dim : clWorkgroupSize) {
-    targetOptions.codegenOptions.workgroupSize.push_back(dim);
-  }
+  targetOptions.codegenOptions.workgroupSize.assign(clWorkgroupSize.begin(),
+                                                    clWorkgroupSize.end());
+  targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
+                                                clTileSizes.end());
   targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
   targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
   return targetOptions;
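Editor's note: the new clTileSizes option reuses the comma-separated list plumbing already used for clWorkgroupSize. A minimal standalone sketch of how CommaSeparated parsing behaves (only the flag declaration mirrors the hunk; the driver main is hypothetical):

// Minimal standalone sketch of the cl::list + CommaSeparated behavior above.
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"

static llvm::cl::list<unsigned> clTileSizes(
    "iree-spirv-tile-size",
    llvm::cl::desc("Tile size to use for tiling Linalg operations"),
    llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);

int main(int argc, char** argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  // `-iree-spirv-tile-size=8,8,32` populates the list with {8, 8, 32};
  // assign() then copies it into the codegen options, as in the hunk.
  for (unsigned size : clTileSizes) llvm::outs() << size << "\n";
  return 0;
}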
diff --git a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
index cf95bff..216a258 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
+++ b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
@@ -65,21 +65,14 @@
 }
 
 void IREEDialect::printType(Type type, DialectAsmPrinter& os) const {
-  switch (type.getKind()) {
-    case IREE::TypeKind::Ptr: {
-      auto targetType = type.cast<IREE::PtrType>().getTargetType();
-      os << "ptr<" << targetType << ">";
-      break;
-    }
-    case IREE::TypeKind::ByteBuffer:
-      os << "byte_buffer";
-      break;
-    case IREE::TypeKind::MutableByteBuffer:
-      os << "mutable_byte_buffer";
-      break;
-    default:
-      llvm_unreachable("unhandled IREE type");
-  }
+  if (auto ptrType = type.dyn_cast<IREE::PtrType>())
+    os << "ptr<" << ptrType.getTargetType() << ">";
+  else if (type.isa<IREE::ByteBufferType>())
+    os << "byte_buffer";
+  else if (type.isa<IREE::MutableByteBufferType>())
+    os << "mutable_byte_buffer";
+  else
+    llvm_unreachable("unhandled IREE type");
 }
 
 }  // namespace iree_compiler
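Editor's note: since this change also pulls in llvm/ADT/TypeSwitch.h (see the compiler.cc hunk), the same dispatch could be expressed with llvm::TypeSwitch. A hedged alternative formulation, not the code this hunk actually commits (the free-function name is hypothetical):

// Equivalent TypeSwitch formulation of the if/else chain above (sketch).
#include "llvm/ADT/TypeSwitch.h"

void printIREEType(Type type, DialectAsmPrinter& os) {
  llvm::TypeSwitch<Type>(type)
      .Case<IREE::PtrType>(
          [&](IREE::PtrType t) { os << "ptr<" << t.getTargetType() << ">"; })
      .Case<IREE::ByteBufferType>([&](Type) { os << "byte_buffer"; })
      .Case<IREE::MutableByteBufferType>(
          [&](Type) { os << "mutable_byte_buffer"; })
      .Default([](Type) { llvm_unreachable("unhandled IREE type"); });
}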
diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h
index 7170a50..db40e1a 100644
--- a/iree/compiler/Dialect/IREE/IR/IREETypes.h
+++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h
@@ -143,7 +143,6 @@
  public:
   static PtrType get(Type targetType);
   static PtrType getChecked(Type targetType, Location location);
-  static bool kindof(unsigned kind) { return kind == TypeKind::Ptr; }
 
   using Base::Base;
 
@@ -156,8 +155,6 @@
  public:
   using Base::Base;
 
-  static bool kindof(unsigned kind) { return kind == TypeKind::ByteBuffer; }
-
   static ByteBufferType get(MLIRContext *context) {
     return Base::get(context, TypeKind::ByteBuffer);
   }
@@ -169,10 +166,6 @@
  public:
   using Base::Base;
 
-  static bool kindof(unsigned kind) {
-    return kind == TypeKind::MutableByteBuffer;
-  }
-
   static MutableByteBufferType get(MLIRContext *context) {
     return Base::get(context, TypeKind::MutableByteBuffer);
   }
diff --git a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
index b56be53..1e19cc7 100644
--- a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
+++ b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
@@ -106,7 +106,6 @@
   using Base::Base;
 
   static StringRef getKindName() { return "{2}"; }
-  static bool kindof(unsigned kind) { return kind == AttrKind::{1}; }
 
 )",
                 structAttr.getDescription(), structAttr.getStructClassName(),
diff --git a/iree/compiler/Dialect/Modules/Strings/IR/Types.h b/iree/compiler/Dialect/Modules/Strings/IR/Types.h
index d67078e..86854a2 100644
--- a/iree/compiler/Dialect/Modules/Strings/IR/Types.h
+++ b/iree/compiler/Dialect/Modules/Strings/IR/Types.h
@@ -30,7 +30,6 @@
   static StringType get(MLIRContext *context) {
     return Base::get(context, TypeKind::String);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::String; }
 };
 
 class StringTensorType
@@ -40,7 +39,6 @@
   static StringTensorType get(MLIRContext *context) {
     return Base::get(context, TypeKind::StringTensor);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::StringTensor; }
 };
 
 }  // namespace Strings
diff --git a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h
index bc76f9c..1a23958 100644
--- a/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h
+++ b/iree/compiler/Dialect/Modules/TensorList/IR/TensorListTypes.h
@@ -37,7 +37,6 @@
   static TensorListType get(MLIRContext *context) {
     return Base::get(context, TypeKind::kTensorList);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::kTensorList; }
 };
 
 }  // namespace TensorList
diff --git a/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp b/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp
index 20a640b..4e0a33d 100644
--- a/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp
+++ b/iree/compiler/Dialect/Sequence/IR/SequenceDialect.cpp
@@ -53,17 +53,8 @@
 }
 
 void SequenceDialect::printType(Type type, DialectAsmPrinter& os) const {
-  switch (type.getKind()) {
-    case TypeKind::Sequence: {
-      auto targetType = type.cast<SequenceType>().getTargetType();
-      os << "of<";
-      os.printType(targetType);
-      os << ">";
-      break;
-    }
-    default:
-      llvm_unreachable("unhandled sequence type");
-  }
+  if (auto sequenceType = type.dyn_cast<SequenceType>())
+    os << "of<" << sequenceType.getTargetType() << ">";
+  else
+    llvm_unreachable("unhandled sequence type");
 }
 
 }  // namespace Sequence
diff --git a/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h b/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h
index 3e98f08..21a0099 100644
--- a/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h
+++ b/iree/compiler/Dialect/Sequence/IR/SequenceTypes.h
@@ -32,7 +32,6 @@
  public:
   using Base::Base;
 
-  static bool kindof(unsigned kind) { return kind == TypeKind::Sequence; }
   static SequenceType get(Type targetType);
   static SequenceType getChecked(Type targetType, Location location);
 
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
index 451bc1d..90d592a 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
+++ b/iree/compiler/Dialect/Shape/IR/ShapeDialect.cpp
@@ -164,13 +164,9 @@
 }
 
 void ShapeDialect::printType(Type type, DialectAsmPrinter& os) const {
-  switch (type.getKind()) {
-    case Shape::TypeKind::RankedShape:
-      printRankedShape(type.cast<Shape::RankedShapeType>(), os);
-      break;
-    default:
-      llvm_unreachable("unhandled Shape type");
-  }
+  if (auto rankedShapeTy = type.dyn_cast<Shape::RankedShapeType>())
+    return printRankedShape(rankedShapeTy, os);
+  llvm_unreachable("unhandled Shape type");
 }
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
index 2946627..45ea8cf 100644
--- a/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
+++ b/iree/compiler/Dialect/Shape/IR/ShapeTypes.h
@@ -39,9 +39,6 @@
  public:
   using Base::Base;
 
-  /// Support method to enable LLVM-style type casting.
-  static bool kindof(unsigned kind) { return kind == TypeKind::RankedShape; }
-
   // Gets an instance of a RankedShapeType given an array of dimensions.
   // Any dynamic dim should be -1.
   static RankedShapeType get(ArrayRef<int64_t> dims, MLIRContext *context);
diff --git a/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp b/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
index 36ba5cc..ce0962d 100644
--- a/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
+++ b/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
@@ -340,19 +340,18 @@
       int indegree = 0;
       int outdegree = 0;
     };
-    SmallVector<FASNode, 8> nodeStorage;
-    llvm::SmallDenseMap<NodeID, FASNode *> nodes;
+    // This map must not be modified after it is populated by this loop: we
+    // take pointers to its entries, so a reallocation would invalidate them.
+    llvm::SmallDenseMap<NodeID, FASNode> nodes;
     for (auto &edge : inputEdges) {
       NodeID sourceID = edge.first.asBaseRegister();
       NodeID sinkID = edge.second.asBaseRegister();
       assert(sourceID != sinkID && "self-cycles not supported");
       if (nodes.count(sourceID) == 0) {
-        nodeStorage.push_back({sourceID, 0, 0});
-        nodes.insert({sourceID, &nodeStorage.back()});
+        nodes.insert({sourceID, {sourceID, 0, 0}});
       }
       if (nodes.count(sinkID) == 0) {
-        nodeStorage.push_back({sinkID, 0, 0});
-        nodes.insert({sinkID, &nodeStorage.back()});
+        nodes.insert({sinkID, {sinkID, 0, 0}});
       }
     }
 
@@ -366,13 +365,13 @@
     for (auto &edge : inputEdges) {
       NodeID sourceID = edge.first.asBaseRegister();
       NodeID sinkID = edge.second.asBaseRegister();
-      auto *sourceNode = nodes[sourceID];
-      ++sourceNode->outdegree;
-      maxOutdegree = std::max(maxOutdegree, sourceNode->outdegree);
-      auto *sinkNode = nodes[sinkID];
-      ++sinkNode->indegree;
-      maxIndegree = std::max(maxIndegree, sinkNode->indegree);
-      edges.push_back({sourceNode, sinkNode});
+      auto &sourceNode = nodes[sourceID];
+      ++sourceNode.outdegree;
+      maxOutdegree = std::max(maxOutdegree, sourceNode.outdegree);
+      auto &sinkNode = nodes[sinkID];
+      ++sinkNode.indegree;
+      maxIndegree = std::max(maxIndegree, sinkNode.indegree);
+      edges.push_back({&sourceNode, &sinkNode});
     }
 
     std::vector<SmallVector<FASNode *, 2>> buckets;
@@ -392,8 +391,10 @@
         buckets[index].erase(it);
       }
     };
+    llvm::SmallPtrSet<FASNode *, 8> remainingNodes;
     for (auto &nodeEntry : nodes) {
-      assignBucket(nodeEntry.second);
+      assignBucket(&nodeEntry.getSecond());
+      remainingNodes.insert(&nodeEntry.getSecond());
     }
 
     auto removeNode = [&](FASNode *node) {
@@ -416,6 +417,7 @@
         }
         removeBucket(edge.source);
         --edge.source->outdegree;
+        assert(edge.source->outdegree >= 0 && "outdegree has become negative");
         assignBucket(edge.source);
       }
       for (auto &edge : outEdges) {
@@ -425,10 +427,11 @@
         }
         removeBucket(edge.sink);
         --edge.sink->indegree;
+        assert(edge.sink->indegree >= 0 && "indegree has become negative");
         assignBucket(edge.sink);
       }
 
-      nodes.erase(node->id);
+      remainingNodes.erase(node);
       edges.erase(std::remove_if(edges.begin(), edges.end(),
                                  [&](const FASEdge &edge) {
                                    return edge.source == node ||
@@ -438,13 +441,13 @@
       return results;
     };
     auto ends = buckets.back();
-    while (!nodes.empty()) {
+    while (!remainingNodes.empty()) {
       while (!ends.empty()) {
         auto *node = ends.front();
         ends.erase(ends.begin());
         removeNode(node);
       }
-      if (nodes.empty()) break;
+      if (remainingNodes.empty()) break;
       for (ssize_t i = buckets.size() - 1; i >= 0; --i) {
         if (buckets[i].empty()) continue;
         auto *bucket = buckets[i].front();
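Editor's note: the switch from pointer indirection (nodeStorage plus a map of pointers) to by-value map entries is safe only because of the invariant stated in the new comment. A self-contained sketch of the pitfall it guards against (illustrative only):

// Illustrative sketch: pointers into a SmallDenseMap stay valid only while
// no insertion can trigger a rehash, so populate fully before pointing.
#include "llvm/ADT/DenseMap.h"
#include <vector>

struct FASNode {
  int id;
  int indegree;
  int outdegree;
};

void buildThenPoint() {
  llvm::SmallDenseMap<int, FASNode> nodes;
  // Phase 1: fully populate the map...
  for (int id = 0; id < 16; ++id) nodes.insert({id, {id, 0, 0}});
  // Phase 2: ...only then take pointers to its entries. Inserting into
  // `nodes` after this point could rehash and dangle every pointer here.
  std::vector<FASNode*> remaining;
  for (auto& entry : nodes) remaining.push_back(&entry.getSecond());
}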
diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
index 8c2c1db..ee0f5d6 100644
--- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -246,26 +246,21 @@
 }
 
 void VMDialect::printType(Type type, DialectAsmPrinter &os) const {
-  switch (type.getKind()) {
-    case IREE::VM::TypeKind::Ref: {
-      auto objectType = type.cast<IREE::VM::RefType>().getObjectType();
-      if (auto listType = objectType.dyn_cast<IREE::VM::ListType>()) {
-        printType(listType, os);
-      } else if (objectType.isa<IREE::VM::OpaqueType>()) {
-        os << "ref<?>";
-      } else {
-        os << "ref<" << objectType << ">";
-      }
-      break;
+  if (auto refType = type.dyn_cast<IREE::VM::RefType>()) {
+    auto objectType = refType.getObjectType();
+    if (auto listType = objectType.dyn_cast<IREE::VM::ListType>()) {
+      printType(listType, os);
+    } else if (objectType.isa<IREE::VM::OpaqueType>()) {
+      os << "ref<?>";
+    } else {
+      os << "ref<" << objectType << ">";
     }
-    case IREE::VM::TypeKind::Opaque:
-      os << "opaque";
-      break;
-    case IREE::VM::TypeKind::List:
-      os << "list<" << type.cast<IREE::VM::ListType>().getElementType() << ">";
-      break;
-    default:
-      llvm_unreachable("unhandled VM type");
+  } else if (type.isa<IREE::VM::OpaqueType>()) {
+    os << "opaque";
+  } else if (auto listType = type.dyn_cast<IREE::VM::ListType>()) {
+    os << "list<" << listType.getElementType() << ">";
+  } else {
+    llvm_unreachable("unhandled VM type");
   }
 }
 
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 3bd554c..819373f 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -46,26 +46,14 @@
 /// Creates a constant one attribute matching the given type.
 Attribute oneOfType(Type type) {
   Builder builder(type.getContext());
-  switch (type.getKind()) {
-    case StandardTypes::BF16:
-    case StandardTypes::F16:
-    case StandardTypes::F32:
-    case StandardTypes::F64:
-      return builder.getFloatAttr(type, 1.0);
-    case StandardTypes::Integer: {
-      auto width = type.cast<IntegerType>().getWidth();
-      if (width == 1) return builder.getBoolAttr(true);
-      return builder.getIntegerAttr(type, APInt(width, 1));
-    }
-    case StandardTypes::Vector:
-    case StandardTypes::RankedTensor: {
-      auto vtType = type.cast<ShapedType>();
-      auto element = oneOfType(vtType.getElementType());
-      if (!element) return {};
-      return DenseElementsAttr::get(vtType, element);
-    }
-    default:
-      break;
+  if (type.isa<FloatType>()) return builder.getFloatAttr(type, 1.0);
+  if (auto integerTy = type.dyn_cast<IntegerType>())
+    return builder.getIntegerAttr(integerTy, APInt(integerTy.getWidth(), 1));
+  if (type.isa<RankedTensorType, VectorType>()) {
+    auto vtType = type.cast<ShapedType>();
+    auto element = oneOfType(vtType.getElementType());
+    if (!element) return {};
+    return DenseElementsAttr::get(vtType, element);
   }
   return {};
 }
diff --git a/iree/compiler/Dialect/VM/IR/VMTypes.h b/iree/compiler/Dialect/VM/IR/VMTypes.h
index 67bc785..2a85ae6 100644
--- a/iree/compiler/Dialect/VM/IR/VMTypes.h
+++ b/iree/compiler/Dialect/VM/IR/VMTypes.h
@@ -64,8 +64,6 @@
   }
 
   Type getElementType();
-
-  static bool kindof(unsigned kind) { return kind == TypeKind::List; }
 };
 
 /// An opaque ref object that comes from an external source.
@@ -73,8 +71,6 @@
  public:
   using Base::Base;
 
-  static bool kindof(unsigned kind) { return kind == TypeKind::Opaque; }
-
   static OpaqueType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Opaque);
   }
@@ -106,8 +102,6 @@
   }
 
   Type getObjectType();
-
-  static bool kindof(unsigned kind) { return kind == TypeKind::Ref; }
 };
 
 }  // namespace VM
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLATypes.h b/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
index 634b903..0bebaf7 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
+++ b/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
@@ -49,7 +49,6 @@
   static BufferType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Buffer);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Buffer; }
 };
 
 class InterfaceType : public Type::TypeBase<InterfaceType, Type, TypeStorage> {
@@ -58,7 +57,6 @@
   static InterfaceType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Interface);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Interface; }
 };
 
 }  // namespace VMLA
diff --git a/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h b/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
index 0ae3957..f9dbc38 100644
--- a/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
+++ b/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h
@@ -70,8 +70,6 @@
   /// bits.
   CapabilitiesAttr getCapabilitiesAttr();
 
-  static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
-
   static LogicalResult verifyConstructionInvariants(
       Location loc, IntegerAttr version, IntegerAttr revision,
       ArrayAttr extensions, DictionaryAttr capabilities);
diff --git a/iree/samples/custom_modules/dialect/custom_dialect.h b/iree/samples/custom_modules/dialect/custom_dialect.h
index 7a5c9b1..fe35ed4 100644
--- a/iree/samples/custom_modules/dialect/custom_dialect.h
+++ b/iree/samples/custom_modules/dialect/custom_dialect.h
@@ -47,7 +47,6 @@
   static MessageType get(MLIRContext *context) {
     return Base::get(context, TypeKind::Message);
   }
-  static bool kindof(unsigned kind) { return kind == TypeKind::Message; }
 };
 
 #define GET_OP_CLASSES
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 481366f..9686098 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -27,7 +27,9 @@
     name = "check_vulkan-spirv_vulkan",
     srcs = glob(
         ["*.mlir"],
-        exclude = ["gemm.mlir"],
+        exclude = [
+            "gemm.mlir",
+        ],
     ),
     driver = "vulkan",
     target_backend = "vulkan-spirv",
@@ -35,7 +37,10 @@
 
 iree_check_single_backend_test_suite(
     name = "check_vulkan-spirv_vulkan_wgmem",
-    srcs = ["gemm.mlir"],
+    srcs = [
+        "conv.mlir",
+        "gemm.mlir",
+    ],
     compiler_flags = ["-iree-spirv-use-workgroup-memory"],
     driver = "vulkan",
     target_backend = "vulkan-spirv",
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index edf9415..4260b07 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -32,6 +32,7 @@
   NAME
     check_vulkan-spirv_vulkan_wgmem
   SRCS
+    "conv.mlir"
     "gemm.mlir"
   TARGET_BACKEND
     vulkan-spirv
diff --git a/iree/test/e2e/vulkan_specific/conv.mlir b/iree/test/e2e/vulkan_specific/conv.mlir
new file mode 100644
index 0000000..e2dc59d
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/conv.mlir
@@ -0,0 +1,83 @@
+func @conv() attributes { iree.module.export } {
+  %0 = iree.unfoldable_constant dense<
+      [[[[0.5       , 0.5212766 ],
+         [0.54255319, 0.56382979],
+         [0.58510638, 0.60638298],
+         [0.62765957, 0.64893617],
+         [0.67021277, 0.69148936],
+         [0.71276596, 0.73404255]],
+
+        [[0.75531915, 0.77659574],
+         [0.79787234, 0.81914894],
+         [0.84042553, 0.86170213],
+         [0.88297872, 0.90425532],
+         [0.92553191, 0.94680851],
+         [0.96808511, 0.9893617 ]],
+
+        [[1.0106383 , 1.03191489],
+         [1.05319149, 1.07446809],
+         [1.09574468, 1.11702128],
+         [1.13829787, 1.15957447],
+         [1.18085106, 1.20212766],
+         [1.22340426, 1.24468085]],
+
+        [[1.26595745, 1.28723404],
+         [1.30851064, 1.32978723],
+         [1.35106383, 1.37234043],
+         [1.39361702, 1.41489362],
+         [1.43617021, 1.45744681],
+         [1.4787234 , 1.5       ]]]]> : tensor<1x4x6x2xf32>
+  %1 = iree.unfoldable_constant dense<
+      [[[[0.5       , 0.52857143, 0.55714286],
+         [0.58571429, 0.61428571, 0.64285714]],
+
+        [[0.67142857, 0.7       , 0.72857143],
+         [0.75714286, 0.78571429, 0.81428571]],
+
+        [[0.84285714, 0.87142857, 0.9       ],
+         [0.92857143, 0.95714286, 0.98571429]]],
+
+
+       [[[1.01428571, 1.04285714, 1.07142857],
+         [1.1       , 1.12857143, 1.15714286]],
+
+        [[1.18571429, 1.21428571, 1.24285714],
+         [1.27142857, 1.3       , 1.32857143]],
+
+        [[1.35714286, 1.38571429, 1.41428571],
+         [1.44285714, 1.47142857, 1.5       ]]]]>
+   : tensor<2x3x2x3xf32>
+  %2 = "mhlo.convolution"(%0, %1) {
+       batch_group_count = 1 : i64,
+       dimension_numbers = {
+         input_batch_dimension = 0 : i64,
+         input_feature_dimension = 3 : i64,
+         input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+         kernel_input_feature_dimension = 2 : i64,
+         kernel_output_feature_dimension = 3 : i64,
+         kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+         output_batch_dimension = 0 : i64,
+         output_feature_dimension = 3 : i64,
+         output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+       feature_group_count = 1 : i64,
+       rhs_dilation = dense<1> : tensor<2xi64>,
+       window_strides = dense<1> : tensor<2xi64>}
+       : (tensor<1x4x6x2xf32>, tensor<2x3x2x3xf32>) -> (tensor<1x3x4x3xf32>)
+   check.expect_almost_eq_const(%2, dense<
+         [[[[ 8.39452888,  8.62796353,  8.86139818],
+         [ 8.89057751,  9.13860182,  9.38662614],
+         [ 9.38662614,  9.64924012,  9.9118541 ],
+         [ 9.88267477, 10.15987842, 10.43708207]],
+
+        [[11.37082067, 11.69179331, 12.01276596],
+         [11.8668693 , 12.20243161, 12.53799392],
+         [12.36291793, 12.71306991, 13.06322188],
+         [12.85896657, 13.22370821, 13.58844985]],
+
+        [[14.34711246, 14.7556231 , 15.16413374],
+         [14.84316109, 15.2662614 , 15.6893617 ],
+         [15.33920973, 15.7768997 , 16.21458967],
+         [15.83525836, 16.28753799, 16.73981763]]]]>
+        : tensor<1x3x4x3xf32>) : tensor<1x3x4x3xf32>
+   return
+}
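Editor's note: the expected constants can be cross-checked with a naive convolution matching the dimension_numbers above (NHWC input/output, HWIO kernel, stride 1, no padding or dilation). A hedged standalone sketch; the input and kernel data from the test are elided here:

// Naive NHWC convolution for the config above: input 1x4x6x2, kernel
// 2x3x2x3 (HWIO), output 1x3x4x3. Output spatial dims: 4-2+1 = 3, 6-3+1 = 4.
constexpr int kInH = 4, kInW = 6, kInC = 2;
constexpr int kKH = 2, kKW = 3, kOutC = 3;
constexpr int kOutH = kInH - kKH + 1;  // 3
constexpr int kOutW = kInW - kKW + 1;  // 4

void convNHWC(const float* input, const float* kernel, float* output) {
  for (int oh = 0; oh < kOutH; ++oh)
    for (int ow = 0; ow < kOutW; ++ow)
      for (int oc = 0; oc < kOutC; ++oc) {
        float acc = 0.f;
        for (int kh = 0; kh < kKH; ++kh)
          for (int kw = 0; kw < kKW; ++kw)
            for (int ic = 0; ic < kInC; ++ic)
              acc += input[((oh + kh) * kInW + (ow + kw)) * kInC + ic] *
                     kernel[((kh * kKW + kw) * kInC + ic) * kOutC + oc];
        output[(oh * kOutW + ow) * kOutC + oc] = acc;
      }
}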
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 950f1bf..30c1633 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 950f1bf976b332eca60267b25bf759e2ad564e0c
+Subproject commit 30c1633386e7cfb01c0a54b31ccf4c3a3873e71b
diff --git a/third_party/mlir-emitc b/third_party/mlir-emitc
index 80885f8..a3479bb 160000
--- a/third_party/mlir-emitc
+++ b/third_party/mlir-emitc
@@ -1 +1 @@
-Subproject commit 80885f899e12d55a45561ef758eea47bb340dbf1
+Subproject commit a3479bbf9161df8c8cac55a08205864e6f371491
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 8a4ffe2..86efb18 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 8a4ffe2e1ae722cff5306778df0cfca8b7f503fe
+Subproject commit 86efb18ca5812c76dd52c8536f336e6962b7f8ca