Merge pull request #6830 from hanhanW:main-to-google

PiperOrigin-RevId: 392697643
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 065e603..f877e5e 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -12,6 +12,11 @@
     "//build_tools:default_linkopts": [],
     "//build_tools:dl": ["${CMAKE_DL_LIBS}"],
 
+    # IREE llvm-external-projects
+    "//llvm-external-projects/iree-dialects:IREEDialect": [
+        "IREEDialectsIREEDialect"
+    ],
+
     # LLVM
     "@llvm-project//llvm:IPO": ["LLVMipo"],
     # MLIR
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE
index b7b0b6c..b7e8507 100644
--- a/integrations/tensorflow/WORKSPACE
+++ b/integrations/tensorflow/WORKSPACE
@@ -27,35 +27,6 @@
 load("@bazel_skylib//lib:paths.bzl", "paths")
 ################################################################################
 
-################################## TensorFlow ##################################
-maybe(
-    local_repository,
-    name = "org_tensorflow",
-    path = paths.join(IREE_PATH, "third_party/tensorflow"),
-)
-
-# Import all of the tensorflow dependencies. Note that we are deliberately
-# letting TensorFlow take control of all the dependencies it sets up, whereas
-# ours are initialized with `maybe`. Actually tracking this with Bazel is PITA
-# and for now this gets TF stuff building. This includes, for instance,
-# @llvm-project and @com_google_absl.
-load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
-
-tf_workspace3()
-
-load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
-
-tf_workspace2()
-
-load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
-
-tf_workspace1()
-
-load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
-
-tf_workspace0()
-################################################################################
-
 ##################################### IREE #####################################
 # We need a shim here to make this use the version of mlir-hlo present in
 # TensorFlow. This shim is just alias rules that forward to the TF rule. Note
@@ -79,4 +50,43 @@
     iree_path = IREE_PATH,
     iree_repo_alias = "@iree",
 )
+
+new_local_repository(
+    name = "llvm-raw",
+    build_file_content = "# empty",
+    path = paths.join(IREE_PATH, "third_party/llvm-project"),
+)
+
+load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps")
+
+llvm_configure(name = "llvm-project")
+
+llvm_disable_optional_support_deps()
+
+################################################################################
+
+################################## TensorFlow ##################################
+maybe(
+    local_repository,
+    name = "org_tensorflow",
+    path = paths.join(IREE_PATH, "third_party/tensorflow"),
+)
+
+# Import all of the tensorflow dependencies. Note that any repository previously
+# defined (e.g. from IREE submodules) will be skipped.
+load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
+
+tf_workspace3()
+
+load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
+
+tf_workspace2()
+
+load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
+
+tf_workspace1()
+
+load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
+
+tf_workspace0()
 ################################################################################
diff --git a/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index f00eca4..586b289 100644
--- a/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -41,7 +41,7 @@
   }
 
   StringRef getDescription() const override {
-    return " Wraps all entry points in a function that is compatible with the "
+    return "Wraps all entry points in a function that is compatible with the "
            "expected invocation semantics of bindings following the native "
            "IREE ABI.";
   }
@@ -51,7 +51,9 @@
 
     SmallVector<FuncOp, 4> entryFuncOps;
     for (auto funcOp : moduleOp.getOps<FuncOp>()) {
-      if (funcOp.isPublic()) entryFuncOps.push_back(funcOp);
+      if (funcOp.isPublic() && !funcOp->hasAttr("iree.abi.stub")) {
+        entryFuncOps.push_back(funcOp);
+      }
     }
 
     // Create a wrapper function for each entry point.
diff --git a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 9ab791a..78e47bb 100644
--- a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -26,3 +26,13 @@
   %1 = "mhlo.add"(%0, %arg0) : (tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>) -> tensor<?x8x8x3xf32>
   return %0, %1 : tensor<?x8x8x3xf32>, tensor<?x8x8x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @wrappedAlready
+//  CHECK-SAME: (%arg0: !hal.buffer_view) -> !hal.buffer_view
+//  CHECK-SAME: attributes {iree.abi.stub}
+func @wrappedAlready(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+  return %arg0 : !hal.buffer_view
+}
+// CHECK-NOT: func @_wrappedAlready
diff --git a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index 1c7cdbb..981cf82 100644
--- a/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -44,7 +44,9 @@
 
     SmallVector<FuncOp, 4> entryFuncOps;
     for (auto funcOp : moduleOp.getOps<FuncOp>()) {
-      if (funcOp.isPublic()) entryFuncOps.push_back(funcOp);
+      if (funcOp.isPublic() && !funcOp->hasAttr("iree.abi.stub")) {
+        entryFuncOps.push_back(funcOp);
+      }
     }
     if (entryFuncOps.size() == 0) {
       moduleOp.emitError()
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 8e9d20c..a069bed 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -103,7 +103,7 @@
       {tileX / workgroupSize[1], tileY / workgroupSize[0]});
   tileSizes.push_back(invocationLevelTs);  // Thread level.
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt,
       workgroupSize);
 }
@@ -184,7 +184,7 @@
   tileSizes.push_back({});                                // Subgroup level.
   tileSizes.emplace_back(std::move(threadTileSizes));     // Thread level
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::LLVMGPUVectorize, workgroupSize);
 }
 
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 84cfd52..01fd686 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -254,53 +254,51 @@
 std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUPipeliningPass();
 
 //------------------------------------------------------------------------------
-// SPIRV Passes
+// SPIR-V Passes
 //------------------------------------------------------------------------------
 
-/// Pass pipeline to lower executable obtained from Linalg tile + distribute to
-/// scalar + vector code. Does distribution to threads (no vectorization).
-void addSPIRVDistributePassPipeline(OpPassManager &pm);
+/// Pass pipeline to lower IREE HAL executables with workgroup tiled and
+/// distributed Linalg ops to SPIR-V scalar code. Additionally performs
+/// distribution to threads without vectorization.
+void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm);
 
-/// Pass pipeline to lower executables that contain operations that are not
-/// tiled + distributed.
-void addSPIRVDistributeToGlobalIDPipeline(OpPassManager &pm);
+/// Pass pipeline to lower IREE HAL executables that contain Linalg ops that are
+/// not tiled/distributed. Performs distribution to global invocations.
+void addSPIRVDistributeToGlobalIDPassPipeline(OpPassManager &pm);
 
-/// pipeline to lower executable obtained from Linalg tile + distribute to
-/// scalar + vector code. Does distribution to threads and vectorization.
-void addSPIRVVectorizationPassPipeline(OpPassManager &pm);
+/// Pass pipeline to lower IREE HAL executables with workgroup tiled and
+/// distributed Linalg ops to SPIR-V scalar and vector code. Additionally
+/// performs distribution to threads with vectorization.
+void addSPIRVTileAndVectorizePassPipeline(OpPassManager &pm);
 
 /// Pass to perform the final conversion to SPIR-V dialect.
+///
 /// This pass converts remaining interface ops into SPIR-V global variables,
 /// GPU processor ID ops into SPIR-V global variables, loop/standard ops into
 /// corresponding SPIR-V ops.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass();
 
-/// Pass to add the synchronizations and attributes needed to lower from PLoops
-/// to GPU dialect.
-std::unique_ptr<OperationPass<FuncOp>> createSPIRVConvertToGPUPass();
+/// Pass to distribute Linalg ops with buffer semantics to global invocations.
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVDistributeToGlobalIDPass();
 
 /// Creates a pass to fold processor ID uses where possible.
 std::unique_ptr<OperationPass<FuncOp>> createSPIRVFoldProcessorIDUsesPass();
 
-/// Main pass to lower executables to scalar + vector code on SPIR-V
-/// path. Invokes one of the pass pipelines that translate the executable to
+/// Main pass to lower executables to scalar + vector code on SPIR-V path.
+/// Invokes one of the pass pipelines that translate the executable to
 /// scalar + vector code.
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
 createSPIRVLowerExecutableTargetPass();
 
-/// Pass to remove loop generated at Flow for tile + distribute when the loop is
-/// known to have a single trip count. NOTE: DO NOT USE. This is a legacy pass
-/// that is to be deprecated.
+/// Pass to remove loop generated at flow for tiled and distributed Linalg ops
+/// when the loop is known to have a single trip count.
+/// WARNING: DO NOT USE. This is a legacy pass that is to be deprecated.
 std::unique_ptr<OperationPass<FuncOp>> createSPIRVRemoveOneTripTiledLoopPass();
 
-/// Pass to tile and distribute Linalg operations on buffers in a single
-/// workgroup.
+/// Pass to tile and distribute Linalg ops with buffer semantics to subgroups
+/// and invocations.
 std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndDistributePass();
 
-/// Pass to tile and vectorize Linalg operations on buffers in a single
-/// workgroup.
-std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndVectorizePass();
-
 /// Pass to convert vector read/write/arithmetic operations to the corresponding
 /// cooperative matrix ops when possible.
 std::unique_ptr<OperationPass<FuncOp>>
@@ -309,6 +307,9 @@
 /// Pass to lower linalg.copy for copying data to workgroup memory.
 std::unique_ptr<OperationPass<FuncOp>> createSPIRVCopyToWorkgroupMemoryPass();
 
+/// Pass to vectorize Linalg ops with buffer semantics.
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVVectorizePass();
+
 /// Converts memref of scalar to memref of vector of efficent size. This will
 /// allow to convert memory accesses to vector load/store in SPIR-V without
 /// having pointer bitcast.
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index c322788..ce1a556 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -210,56 +210,60 @@
   let summary = "Pass to do software pipelining.";
   let constructor = "mlir::iree_compiler::createLLVMGPUPipeliningPass()";
 }
+
 //------------------------------------------------------------------------------
-// SPIRV
+// SPIR-V
 //------------------------------------------------------------------------------
 
 // TODO: Rename argument to be fully qualified.
-def ConvertToSPIRV :
-    Pass<"iree-convert-to-spirv", "ModuleOp"> {
-  let summary = "Perform final conversion to SPIR-V dialect";
+def ConvertToSPIRV : Pass<"iree-convert-to-spirv", "ModuleOp"> {
+  let summary = "Perform the final conversion to SPIR-V dialect";
   let constructor = "mlir::iree_compiler::createConvertToSPIRVPass()";
 }
 
 // TODO: Rename argument to be fully qualified.
-def SPIRVConvertToGPU : Pass<"iree-spirv-convert-to-gpu", "FuncOp"> {
-  let summary = "Map tiled linalg and loop ops to GPU";
-  let constructor = "mlir::iree_compiler::createSPIRVConvertToGPUPass()";
+def SPIRVDistributeToGlobalID :
+    Pass<"iree-spirv-distribute-to-global-id", "FuncOp"> {
+  let summary = "Distribute Linalg ops with buffer semantics to global "
+                "invocations";
+  let constructor =
+      "mlir::iree_compiler::createSPIRVDistributeToGlobalIDPass()";
 }
 
 // TODO: Rename argument to be fully qualified.
-// TODO: Does not appear used?
-def SPIRVFoldProcessorIDUses : Pass<"iree-spirv-fold-gpu-procid-uses", "FuncOp"> {
+def SPIRVFoldProcessorIDUses :
+    Pass<"iree-spirv-fold-gpu-procid-uses", "FuncOp"> {
   let summary = "Fold GPU processor ID uses where possible";
   let constructor = "mlir::iree_compiler::createSPIRVFoldProcessorIDUsesPass()";
 }
 
 def SPIRVLowerExecutableTarget :
-    Pass<"iree-spirv-lower-executable-target-pass", "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
-  let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
-  let constructor = "mlir::iree_compiler::createSPIRVLowerExecutableTargetPass()";
+    Pass<"iree-spirv-lower-executable-target-pass",
+         "mlir::iree_compiler::IREE::HAL::ExecutableVariantOp"> {
+  let summary = "Lower the executable target to SPIR-V using one of the "
+                "IREE::HAL::DispatchLoweringPassPipeline";
+  let constructor =
+      "mlir::iree_compiler::createSPIRVLowerExecutableTargetPass()";
 }
 
 def SPIRVRemoveOneTripTiledLoop :
     Pass<"iree-spirv-remove-one-trip-tiled-loop", "FuncOp"> {
-  let summary = "Remove one trip tiled loop. ---- Legacy Pass! Do not use ---";
-  let constructor = "mlir::iree_compiler::createSPIRVRemoveOneTripTiledLoopPass()";
-}
-
-// TODO: Rename argument to be fully qualified.
-def SPIRVTileAndVectorize : Pass<"iree-spirv-tile-and-vectorize", "FuncOp"> {
-  let summary =
-      "Tile and vectorize Linalg operations on buffers in one workgroup";
+  let summary = "Remove one-trip tiled loop for convolution "
+                "(LEGACY PASS; DO NOT USE!)";
   let constructor =
-      "mlir::iree_compiler::createSPIRVTileAndVectorizePass()";
+      "mlir::iree_compiler::createSPIRVRemoveOneTripTiledLoopPass()";
 }
 
 // TODO: Rename argument to be fully qualified.
 def SPIRVTileAndDistribute : Pass<"iree-spirv-tile-and-distribute", "FuncOp"> {
-  let summary =
-      "Tile and distribute Linalg operations on buffers in one workgroup";
-  let constructor =
-      "mlir::iree_compiler::createSPIRVTileAndDistributePass()";
+  let summary = "Tile and distribute Linalg ops with buffer semantics to "
+                "subgroups and invocations";
+  let constructor = "mlir::iree_compiler::createSPIRVTileAndDistributePass()";
+}
+
+def SPIRVVectorize : Pass<"iree-spirv-vectorize", "FuncOp"> {
+  let summary = "Vectorize Linalg ops with buffer semantics";
+  let constructor = "mlir::iree_compiler::createSPIRVVectorizePass()";
 }
 
 // TODO: Rename argument to be fully qualified.
@@ -272,7 +276,8 @@
 // TODO: Rename argument to be fully qualified.
 def SPIRVVectorToCooperativeMatrix :
     Pass<"iree-spirv-vector-to-cooperative-matrix", "FuncOp"> {
-  let summary = "Generate cooperative matrix ops when possible";
+  let summary = "Convert vector ops to SPIR-V cooperative matrix ops "
+                "when possible";
   let constructor =
       "mlir::iree_compiler::createSPIRVVectorToCooperativeMatrixPass()";
 }
@@ -280,8 +285,9 @@
 // TODO: Rename argument to be fully qualified.
 def SPIRVCopyToWorkgroupMemory :
     Pass<"iree-spirv-copy-to-workgroup-memory", "FuncOp"> {
-  let summary = "Convert vector dialect to gpu subgroup level GPU instructions";
-  let constructor = "mlir::iree_compiler::createSPIRVCopyToWorkgroupMemoryPass()";
+  let summary = "Lower linalg.copy for copying data to workgroup memory";
+  let constructor =
+      "mlir::iree_compiler::createSPIRVCopyToWorkgroupMemoryPass()";
 }
 
 //------------------------------------------------------------------------------
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index 63662d0..af8ec54 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -14,21 +14,21 @@
     name = "SPIRV",
     srcs = [
         "ConvertToSPIRVPass.cpp",
-        "KernelDispatchUtils.cpp",
+        "KernelConfig.cpp",
         "Passes.cpp",
-        "SPIRVConvertToGPU.cpp",
         "SPIRVCopyToWorkgroupMemory.cpp",
+        "SPIRVDistributeToGlobalID.cpp",
         "SPIRVFoldGPUProcessorIDUses.cpp",
         "SPIRVLowerExecutableTargetPass.cpp",
         "SPIRVRemoveOneTripTiledLoops.cpp",
         "SPIRVTileAndDistribute.cpp",
-        "SPIRVTileAndVectorize.cpp",
         "SPIRVVectorToCooperativeMatrix.cpp",
+        "SPIRVVectorize.cpp",
         "SPIRVVectorizeLoadStore.cpp",
         "Utils.cpp",
     ],
     hdrs = [
-        "KernelDispatchUtils.h",
+        "KernelConfig.h",
         "MemorySpace.h",
         "Utils.h",
     ],
@@ -74,6 +74,7 @@
         "@llvm-project//mlir:TosaDialect",
         "@llvm-project//mlir:TosaToStandard",
         "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:VectorInterfaces",
         "@llvm-project//mlir:VectorOps",
         "@llvm-project//mlir:VectorToSPIRV",
         "@mlir-hlo//:hlo",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index f55dd2a..65f615d 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -14,21 +14,21 @@
   NAME
     SPIRV
   HDRS
-    "KernelDispatchUtils.h"
+    "KernelConfig.h"
     "MemorySpace.h"
     "Utils.h"
   SRCS
     "ConvertToSPIRVPass.cpp"
-    "KernelDispatchUtils.cpp"
+    "KernelConfig.cpp"
     "Passes.cpp"
-    "SPIRVConvertToGPU.cpp"
     "SPIRVCopyToWorkgroupMemory.cpp"
+    "SPIRVDistributeToGlobalID.cpp"
     "SPIRVFoldGPUProcessorIDUses.cpp"
     "SPIRVLowerExecutableTargetPass.cpp"
     "SPIRVRemoveOneTripTiledLoops.cpp"
     "SPIRVTileAndDistribute.cpp"
-    "SPIRVTileAndVectorize.cpp"
     "SPIRVVectorToCooperativeMatrix.cpp"
+    "SPIRVVectorize.cpp"
     "SPIRVVectorizeLoadStore.cpp"
     "Utils.cpp"
   DEPS
@@ -61,6 +61,7 @@
     MLIRTosaToStandard
     MLIRTransforms
     MLIRVector
+    MLIRVectorInterfaces
     MLIRVectorToSPIRV
     iree::compiler::Codegen::Common
     iree::compiler::Codegen::PassHeaders
diff --git a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
similarity index 85%
rename from iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp
rename to iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 7963338..888993a 100644
--- a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -13,7 +13,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
 
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
@@ -148,7 +148,7 @@
         batchTs[2] / pair.workgroupSize[0], batchTs[3]};
     tileSizes.emplace_back(invocationLevelTs);
     return setOpConfigAndEntryPointFnTranslation(
-        entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+        entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
         IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
         pair.workgroupSize);
   }
@@ -185,7 +185,7 @@
   tileSizes.emplace_back();  // subgroup level
   tileSizes.emplace_back(std::move(invocationLevel));
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute, workgroupSize);
 }
 
@@ -269,7 +269,7 @@
       numVecMatmulPerSubgroupX * (*coopMatmulSize)[1]};
   tileSizes.emplace_back(std::move(subgroupTs));
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+      entryPoint, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, workgroupSize);
 }
 
@@ -282,7 +282,7 @@
     // Serialized computation.
     return setOpConfigAndEntryPointFnTranslation(
         entryPoint, op, /*tileSizes =*/TileSizesListType{{}},
-        /*nativeVectorSize=*/ArrayRef<int64_t>{},
+        /*nativeVectorSizes=*/ArrayRef<int64_t>{},
         IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize, {1, 1, 1});
   }
 
@@ -357,7 +357,7 @@
 
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, op, tileSizes,
-      /*nativeVectorSize =*/ArrayRef<int64_t>{}, pipeline, workgroupSize);
+      /*nativeVectorSizes =*/ArrayRef<int64_t>{}, pipeline, workgroupSize);
 }
 
 /// Launch configuration for different known GPU configuration.
@@ -400,7 +400,7 @@
     tileSizes.emplace_back(invocationLevelTs);
     return setOpConfigAndEntryPointFnTranslation(
         entryPoint, op, tileSizes,
-        /*nativeVectorSize =*/ArrayRef<int64_t>{},
+        /*nativeVectorSizes =*/ArrayRef<int64_t>{},
         IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
         pair.workgroupSize);
   }
@@ -445,7 +445,7 @@
   tileSizes.emplace_back();  // subgroup level
   tileSizes.emplace_back(std::move(invocationLevel));
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, /*nativeVectorSize =*/ArrayRef<int64_t>{},
+      entryPoint, op, tileSizes, /*nativeVectorSizes =*/ArrayRef<int64_t>{},
       IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute, workgroupSize);
 }
 
@@ -505,7 +505,7 @@
     tileSizes.emplace_back(fourthLevel);
 
     if (failed(setOpConfigAndEntryPointFnTranslation(
-            entryFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+            entryFn, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
             IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
             workgroupSize)))
       return failure();
@@ -593,7 +593,7 @@
     tileSizes.emplace_back(fourthLevel);
 
     if (failed(setOpConfigAndEntryPointFnTranslation(
-            entryFn, op, tileSizes, /*nativeVectorSize=*/ArrayRef<int64_t>{},
+            entryFn, op, tileSizes, /*nativeVectorSizes=*/ArrayRef<int64_t>{},
             IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize,
             workgroupSize)))
       return failure();
@@ -753,124 +753,5 @@
   return success();
 }
 
-template <typename OpTy>
-static Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize(OpTy op) {
-  return llvm::None;
-}
-
-template <>
-Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::ContractionOp>(
-    vector::ContractionOp op) {
-  auto targetEnvAttr = spirv::lookupTargetEnv(op);
-  auto targetEnv = spirv::TargetEnv(targetEnvAttr);
-  if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
-      targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
-    return getCooperativeMatmulSubgroupSize(
-        targetEnv.getResourceLimits(), op.getLhsType().getElementType(),
-        op.getRhsType().getElementType(),
-        op.getAccType().cast<VectorType>().getElementType(),
-        op.getResultType().cast<VectorType>().getElementType());
-  } else {
-    unsigned lastParalleldim = 0;
-    for (auto it : llvm::enumerate(op.iterator_types())) {
-      if (isParallelIterator(it.value())) lastParalleldim = it.index();
-    }
-    SmallVector<int64_t, 4> nativeSize(op.iterator_types().size(), 1);
-    nativeSize[lastParalleldim] = 4;
-    // Map to vec4 fma operations.
-    return nativeSize;
-  }
-}
-
-template <>
-Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::FMAOp>(
-    vector::FMAOp op) {
-  SmallVector<int64_t, 4> size(op.getType().getRank(), 1);
-  size.back() = 4;
-  return size;
-}
-
-template <>
-Optional<SmallVector<int64_t, 4>> getOpNativeVectorSize<vector::TransferReadOp>(
-    vector::TransferReadOp op) {
-  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
-  if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
-      targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
-    // Unroll cooperative martrix load based on the size of the contract.
-    VectorType dstVec;
-    for (Operation *users : op->getUsers()) {
-      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
-      if (!extract) return llvm::None;
-      auto vecType = extract.getResult().getType().cast<VectorType>();
-      if (dstVec && dstVec != vecType) return llvm::None;
-      dstVec = vecType;
-    }
-    return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
-                                   dstVec.getShape().end());
-  }
-
-  // Map to load4.
-  auto rank = op.getVectorType().getRank();
-  SmallVector<int64_t, 4> nativeSize(rank, 1);
-  // Load 4 elements on the most inner dimension.
-  for (auto dim : llvm::enumerate(op.permutation_map().getResults())) {
-    if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) {
-      if (dimExpr.getPosition() == op.permutation_map().getNumDims() - 1)
-        nativeSize[dim.index()] = 4;
-    }
-  }
-  return nativeSize;
-}
-
-template <>
-Optional<SmallVector<int64_t, 4>>
-getOpNativeVectorSize<vector::TransferWriteOp>(vector::TransferWriteOp op) {
-  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
-  if (targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
-      targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix)) {
-    // Unroll cooperative martrix store based on the size of the contract.
-    auto insert = op.vector().getDefiningOp<vector::InsertStridedSliceOp>();
-    if (!insert) return llvm::None;
-    ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
-    return SmallVector<int64_t, 4>(shape.begin(), shape.end());
-  }
-
-  // Map to store4.
-  auto rank = op.getVectorType().getRank();
-  SmallVector<int64_t, 4> nativeSize(rank, 1);
-  // Store 4 elements on the most inner dimension.
-  for (auto dim : llvm::enumerate(op.permutation_map().getResults())) {
-    if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) {
-      if (dimExpr.getPosition() == op.permutation_map().getNumDims() - 1)
-        nativeSize[dim.index()] = 4;
-    }
-  }
-  return nativeSize;
-}
-
-Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op) {
-#define DISPATCH(opname)                            \
-  if (isa<opname>(op)) {                            \
-    return getOpNativeVectorSize(cast<opname>(op)); \
-  }
-
-  DISPATCH(vector::ContractionOp)
-  DISPATCH(vector::FMAOp)
-  DISPATCH(vector::TransferReadOp)
-  DISPATCH(vector::TransferWriteOp)
-
-#undef DISPATCH
-
-  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
-    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
-      // Map elementwise ops to vec4.
-      SmallVector<int64_t, 4> nativeSize(vecType.getRank() - 1, 1);
-      nativeSize.push_back(4);
-      return nativeSize;
-    }
-  }
-  return llvm::None;
-}
-
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.h b/iree/compiler/Codegen/SPIRV/KernelConfig.h
new file mode 100644
index 0000000..22a1a8e
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.h
@@ -0,0 +1,31 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- KernelConfig.h - Kernel Generation Configurations ------------------===//
+//
+// This file declares utility functions for configuring SPIR-V kernel
+// generation, e.g., tiling schemes and workgroup size for important
+// Linalg named ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_COMPILER_CODEGEN_SPIRV_KERNELCONFIG_H_
+#define IREE_COMPILER_CODEGEN_SPIRV_KERNELCONFIG_H_
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Attaches the `translation.info` attribute to entry points in `moduleOp` and
+/// `lowering.config` attributes to all root ops in `moduleOp`'s region.
+/// These attributes are used to drive the CodeGen pipeline.
+LogicalResult initSPIRVLaunchConfig(ModuleOp moduleOp);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CODEGEN_SPIRV_KERNELCONFIG_H_
diff --git a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h b/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h
deleted file mode 100644
index 3465e7d..0000000
--- a/iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-//===- KernelDispatchUtils.h - Utilities for generating dispatch info -----===//
-//
-// This file declares utility functions that can be used to create information
-// the dispatch on the host side needs to execute an entry point function, like
-// the number of workgroups to use for launch, etc.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef IREE_COMPILER_CODEGEN_SPIRV_KERNELDISPATCHUTILS_H_
-#define IREE_COMPILER_CODEGEN_SPIRV_KERNELDISPATCHUTILS_H_
-
-#include <array>
-
-#include "iree/compiler/Codegen/Passes.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringMap.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Types.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-LogicalResult initSPIRVLaunchConfig(ModuleOp moduleOp);
-
-/// Returns the size of instruction in `vector` dialect that maps directly to
-/// the hardware.
-Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op);
-
-}  // namespace iree_compiler
-}  // namespace mlir
-
-#endif  // IREE_COMPILER_CODEGEN_SPIRV_DISPATCHUTILS_H_
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index 4fa7ee2..af8b525 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -4,9 +4,10 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-//===- Passes.cpp - Pipeline from HLO to Linalg to SPIR-V -----------------===//
+//===- Passes.cpp - Pipelines from Linalg ops to SPIR-V -------------------===//
 //
-// Implementation of conversion from XLA-HLO to Linalg to SPIR-V dialect.
+// This file contains various pipelines to lower IREE HAL executables containing
+// Linalg ops to SPIR-V.
 //
 //===----------------------------------------------------------------------===//
 
@@ -57,74 +58,51 @@
   return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes);
 }
 
-void addSPIRVVectorizationPassPipeline(OpPassManager &pm) {
-  //===--------------------------------------------------------------------===//
-  // Initial clean up.
-  //===--------------------------------------------------------------------===//
+void addSPIRVTileAndVectorizePassPipeline(OpPassManager &pm) {
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
+
   pm.addNestedPass<FuncOp>(createSPIRVRemoveOneTripTiledLoopPass());
   //  Tile and distribute to GPU subgroups/invocations and vectorize.
-  pm.addNestedPass<FuncOp>(createSPIRVTileAndVectorizePass());
+  pm.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass());
+  pm.addNestedPass<FuncOp>(createSPIRVVectorizePass());
   pm.addPass(createCanonicalizerPass());
 
-  // Handle ops that cannot go through the previous tiling, distribution, and
-  // vectorization flow. Only perform one level of distribution to map them to
-  // GPU global invocation IDs for distribution.
-  // TODO(antiagainst): Handle all the cases uniformly and remove this pass.
   pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass());
+  pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
   pm.addPass(createLowerAffinePass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
-  //===--------------------------------------------------------------------===//
-  // Optimizations and cleanups
-  //===--------------------------------------------------------------------===//
-
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
   pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
 }
 
-void addSPIRVDistributePassPipeline(OpPassManager &pm) {
-  //===--------------------------------------------------------------------===//
-  // Initial clean up.
-  //===--------------------------------------------------------------------===//
+void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm) {
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
-  // Tile and distribute to GPU subgroups/invocations and vectorize.
+
+  // Tile and distribute to GPU subgroups/invocations.
   pm.addNestedPass<FuncOp>(createSPIRVTileAndDistributePass());
   pm.addPass(createCanonicalizerPass());
 
-  // Handle ops that cannot go through the previous tiling, distribution, and
-  // vectorization flow. Only perform one level of distribution to map them to
-  // GPU global invocation IDs for distribution.
-  // TODO(antiagainst): Handle all the cases uniformly and remove this pass.
   pm.addNestedPass<FuncOp>(createSPIRVCopyToWorkgroupMemoryPass());
+  pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
   pm.addPass(createLowerAffinePass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
-  //===--------------------------------------------------------------------===//
-  // Optimizations and cleanups
-  //===--------------------------------------------------------------------===//
 
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
   pm.addNestedPass<FuncOp>(createOptimizeVectorTransferPass());
 }
 
-void addSPIRVDistributeToGlobalIDPipeline(OpPassManager &pm) {
-  // Handle ops that cannot go through the previous tiling, distribution, and
-  // vectorization flow. Only perform one level of distribution to map them to
-  // GPU global invocation IDs for distribution.
-  // TODO(antiagainst): Handle all the cases uniformly and remove this pass.
-  pm.addNestedPass<FuncOp>(createSPIRVConvertToGPUPass());
+void addSPIRVDistributeToGlobalIDPassPipeline(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createSPIRVDistributeToGlobalIDPass());
   pm.addPass(createLowerAffinePass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
-  //===--------------------------------------------------------------------===//
-  // Optimizations and cleanups
-  //===--------------------------------------------------------------------===//
 
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
@@ -159,10 +137,6 @@
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
-  //===--------------------------------------------------------------------===//
-  // SPIR-V conversions
-  //===--------------------------------------------------------------------===//
-
   // Finally convert everything to SPIR-V.
   pm.addPass(createConvertToSPIRVPass());
 
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp b/iree/compiler/Codegen/SPIRV/SPIRVDistributeToGlobalID.cpp
similarity index 94%
rename from iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp
rename to iree/compiler/Codegen/SPIRV/SPIRVDistributeToGlobalID.cpp
index 5613f85..b8df3b4 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVConvertToGPU.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVDistributeToGlobalID.cpp
@@ -4,9 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-//===- SPIRVConvertToGPUPass.cpp ------------------------------------------===//
+//===- SPIRVDistributeToGlobalIDPass.cpp ----------------------------------===//
 //
-// Partition computation within dispatch function to workgroups/workitems.
+// This pass distributes Linalg ops with buffer semantics to global invocations.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,7 +15,6 @@
 
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
@@ -121,8 +120,8 @@
 
 namespace {
 /// Pass to convert from tiled and fused linalg ops into gpu.func.
-struct SPIRVConvertToGPUPass
-    : public SPIRVConvertToGPUBase<SPIRVConvertToGPUPass> {
+struct SPIRVDistributeToGlobalIDPass
+    : public SPIRVDistributeToGlobalIDBase<SPIRVDistributeToGlobalIDPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, gpu::GPUDialect, memref::MemRefDialect,
                     scf::SCFDialect, ShapeDialect>();
@@ -186,9 +185,10 @@
 
 }  // namespace
 
-void SPIRVConvertToGPUPass::runOnOperation() {
+void SPIRVDistributeToGlobalIDPass::runOnOperation() {
   FuncOp funcOp = getOperation();
   if (!isEntryPoint(funcOp)) return;
+
   MLIRContext *context = &getContext();
   ConversionTarget target(*context);
   // After this pass Linalg and scf.parallel ops should be gone.
@@ -217,8 +217,8 @@
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<FuncOp>> createSPIRVConvertToGPUPass() {
-  return std::make_unique<SPIRVConvertToGPUPass>();
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVDistributeToGlobalIDPass() {
+  return std::make_unique<SPIRVDistributeToGlobalIDPass>();
 }
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index 91e499d..5413c48 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -6,7 +6,7 @@
 
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
+#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
@@ -25,13 +25,16 @@
 namespace iree_compiler {
 
 namespace {
-/// Lowers an hal.executable.variant operation to scalar/native-vector
+/// Lowers a hal.executable.variant inner module to SPIR-V scalar/native-vector
 /// code. Invokes different compilation pipeline to
-/// - first lower to scalar/native-vector code
+/// - first lower to scalar/native-vector code,
 /// - then convert to SPIRV dialect.
 class SPIRVLowerExecutableTargetPass
     : public SPIRVLowerExecutableTargetBase<SPIRVLowerExecutableTargetPass> {
  public:
+  SPIRVLowerExecutableTargetPass() = default;
+  SPIRVLowerExecutableTargetPass(const SPIRVLowerExecutableTargetPass &pass) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, gpu::GPUDialect, IREE::HAL::HALDialect,
                     linalg::LinalgDialect, linalg_ext::LinalgExtDialect,
@@ -39,19 +42,14 @@
                     spirv::SPIRVDialect, vector::VectorDialect>();
   }
 
-  SPIRVLowerExecutableTargetPass() = default;
-  SPIRVLowerExecutableTargetPass(const SPIRVLowerExecutableTargetPass &pass){};
-
   void runOnOperation() override;
 
  private:
   Option<bool> testLoweringConfiguration{
       *this, "test-lowering-configuration",
-      llvm::cl::desc(
-          "Flag used for lit-testing the default configuration set for root "
-          "ops in hal.executable.variants. Defaults to false and is set to "
-          "true "
-          "for lit tests. Not for general usage"),
+      llvm::cl::desc("Flag used for lit-testing the configuration set for root "
+                     "ops in hal.executable.variants. Defaults to false. Set "
+                     "to true for lit tests; not for general usage"),
       llvm::cl::init(false)};
 };
 }  // namespace
@@ -99,13 +97,13 @@
     OpPassManager &nestedModulePM = executableLoweringPipeline.nest<ModuleOp>();
     switch (*passPipeline) {
       case IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistribute:
-        addSPIRVDistributePassPipeline(nestedModulePM);
+        addSPIRVTileAndDistributePassPipeline(nestedModulePM);
         break;
       case IREE::HAL::DispatchLoweringPassPipeline::SPIRVDistributeToGlobalID:
-        addSPIRVDistributeToGlobalIDPipeline(nestedModulePM);
+        addSPIRVDistributeToGlobalIDPassPipeline(nestedModulePM);
         break;
       case IREE::HAL::DispatchLoweringPassPipeline::SPIRVVectorize:
-        addSPIRVVectorizationPassPipeline(nestedModulePM);
+        addSPIRVTileAndVectorizePassPipeline(nestedModulePM);
         break;
       default:
         llvm_unreachable("Unsupported pipeline on GPU target.");
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
index 8fcbea0..0befc77 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -4,38 +4,31 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-//===- SPIRVTileAndDistribute.cpp
-//------------------------------------------===//
+//===- SPIRVTileAndDistribute.cpp -----------------------------------------===//
 //
-// This pass tiles and vectorizes Linalg ops on buffers within in a single
-// workgroup.
+// This pass tiles and distributes Linalg ops with buffer semantics to subgroups
+// and invocations.
 //
 //===----------------------------------------------------------------------===//
 
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
-#include "iree/compiler/Codegen/SPIRV/MemorySpace.h"
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Identifier.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -43,7 +36,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 
-#define DEBUG_TYPE "iree-spirv-tile-and-vectorize"
+#define DEBUG_TYPE "iree-spirv-tile-and-distribute"
 
 namespace mlir {
 namespace iree_compiler {
@@ -72,29 +65,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Main pass
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Function pass that implements tiling and fusion in Linalg on buffers.
-class SPIRVTileAndDistributePass
-    : public SPIRVTileAndDistributeBase<SPIRVTileAndDistributePass> {
- public:
-  SPIRVTileAndDistributePass() = default;
-  SPIRVTileAndDistributePass(const SPIRVTileAndDistributePass &pass) = default;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect,
-                    linalg::LinalgDialect, memref::MemRefDialect,
-                    scf::SCFDialect, ShapeDialect, vector::VectorDialect>();
-  }
-
-  void runOnOperation() override;
-};
-}  // namespace
-
-//===----------------------------------------------------------------------===//
-// Patterns to tile computation to map to subgroups
+// Subgroup tiling patterns
 //===----------------------------------------------------------------------===//
 
 /// Computes the Value for subgroupID along each dimension given number of
@@ -169,7 +140,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Patterns and methods for thread tiling.
+// Invocation tiling patterns
 //===----------------------------------------------------------------------===//
 
 /// Patterns for third level tiling to target invocations.
@@ -184,8 +155,8 @@
             }));
       };
 
-  auto getThreadProcInfoFn = [&](OpBuilder &builder, Location loc,
-                                 ArrayRef<Range> parallelLoopRanges) {
+  auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
+                                ArrayRef<Range> parallelLoopRanges) {
     return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
         builder, loc, parallelLoopRanges.size());
   };
@@ -205,7 +176,7 @@
                   linalg::LinalgTilingPattern<linalg::FillOp>,
                   linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
                   linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>,
-                  linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
+                  linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>,
                   linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
                   linalg::LinalgTilingPattern<linalg::GenericOp>,
                   linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
@@ -246,7 +217,7 @@
 }
 
 //====---------------------------------------------------------------------===//
-// Patterns to tile convolution window dimensions
+// Convolution filter tiling patterns
 //====---------------------------------------------------------------------===//
 
 static void populateTilingConvFilterPatterns(
@@ -274,29 +245,28 @@
       context, tilingOptions, marker);
 }
 
-//====---------------------------------------------------------------------===//
-// Patterns to lower linalg ops to loops
-//====---------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Main pass
+//===----------------------------------------------------------------------===//
 
-template <typename OpTy>
-struct LowerToLoops final : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+namespace {
+/// Function pass that implements tiling and distributing Linalg ops with
+/// buffer semantics.
+class SPIRVTileAndDistributePass
+    : public SPIRVTileAndDistributeBase<SPIRVTileAndDistributePass> {
+ public:
+  SPIRVTileAndDistributePass() = default;
+  SPIRVTileAndDistributePass(const SPIRVTileAndDistributePass &pass) = default;
 
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Only handle the cases where tiling to invocations was done, where tiling
-    // convolution filters or vectorization is expected.
-    if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()}))
-      return failure();
-
-    if (linalg::linalgOpToLoops(rewriter, op)) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-
-    return failure();
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, scf::SCFDialect,
+                    vector::VectorDialect>();
   }
+
+  void runOnOperation() override;
 };
+}  // namespace
 
 //====---------------------------------------------------------------------===//
 // Main pass implementation
@@ -309,16 +279,40 @@
   if (!entryPointOp) return;
 
   {
-    RewritePatternSet thirdLevelTilingPatterns(&getContext());
-    populateTilingToInvocationPatterns(context, thirdLevelTilingPatterns);
+    RewritePatternSet subgroupTilingPatterns(&getContext());
+    populateTilingToSubgroupPatterns(context, subgroupTilingPatterns);
     (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(thirdLevelTilingPatterns));
+                                       std::move(subgroupTilingPatterns));
+
+    RewritePatternSet canonicalizationPatterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    populateAffineMinCanonicalizationPattern(canonicalizationPatterns);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(canonicalizationPatterns));
+    promoteSingleIterationLoops(funcOp);
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After tiling to subgroups ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+
+  {
+    RewritePatternSet invocationTilingPatterns(&getContext());
+    populateTilingToInvocationPatterns(context, invocationTilingPatterns);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(invocationTilingPatterns));
 
     // Remove trip-one loops created during cyclic loop distribution if we can
     // prove the tiling was perfect.
     RewritePatternSet canoncalizationPatterns(context);
     populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns);
-    auto workgroupSize = getWorkgroupSize(entryPointOp);
+    SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp);
+    if (workgroupSize.empty()) {
+      entryPointOp.emitError("expected to have workgroup_size attribute");
+      return signalPassFailure();
+    }
     auto getThreadRangeFn = [workgroupSize](Value processorValue,
                                             SmallVectorImpl<Value> &dims,
                                             SmallVectorImpl<Value> &symbols) {
@@ -345,20 +339,21 @@
   }
 
   {
-    RewritePatternSet tilingPatterns(&getContext());
+    RewritePatternSet convFilterTilingPatterns(&getContext());
     auto marker = getLinalgMatchAndReplaceMarker(getConvFilterTileMarker(),
                                                  getVectorizeMarker(), context);
-    populateTilingConvFilterPatterns(context, tilingPatterns, marker);
-    populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
-    tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
+    populateTilingConvFilterPatterns(context, convFilterTilingPatterns, marker);
+    populateFoldGPUProcessorIDUsesPatterns(context, convFilterTilingPatterns);
+    convFilterTilingPatterns
+        .insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(convFilterTilingPatterns));
 
-    RewritePatternSet convTilingCanonicalizationPatterns =
+    RewritePatternSet canonicalizationPatterns =
         linalg::getLinalgTilingCanonicalizationPatterns(context);
-    populateAffineMinCanonicalizationPattern(
-        convTilingCanonicalizationPatterns);
-    (void)applyPatternsAndFoldGreedily(
-        funcOp, std::move(convTilingCanonicalizationPatterns));
+    populateAffineMinCanonicalizationPattern(canonicalizationPatterns);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(canonicalizationPatterns));
 
     LLVM_DEBUG({
       llvm::dbgs() << "--- After tiling convolution filter  ---\n";
@@ -366,27 +361,6 @@
       llvm::dbgs() << "\n\n";
     });
   }
-
-  // Lower ops that were tiled to invocations but not vectorized to loops.
-  // TODO(antiagainst): This is here now to simplify the interaction with
-  // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
-  // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
-  // directly.
-  {
-    RewritePatternSet patterns(context);
-    patterns.add<LowerToLoops<linalg::BatchMatmulOp>,
-                 LowerToLoops<linalg::Conv1DNwcWcfOp>,
-                 LowerToLoops<linalg::Conv2DNhwcHwcfOp>,
-                 LowerToLoops<linalg::Conv3DNdhwcDhwcfOp>,
-                 LowerToLoops<linalg::DepthwiseConv2DNhwOp>,
-                 LowerToLoops<linalg::DepthwiseConv2DNhwcOp>,
-                 LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
-                 LowerToLoops<linalg::MatmulOp>,
-                 LowerToLoops<linalg::PoolingNhwcMaxOp>,
-                 LowerToLoops<linalg::PoolingNhwcMinOp>,
-                 LowerToLoops<linalg::PoolingNhwcSumOp>>(context);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-  }
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp
deleted file mode 100644
index 2d62d27..0000000
--- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorize.cpp
+++ /dev/null
@@ -1,649 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-//===- SPIRVTileAndVectorize.cpp ------------------------------------------===//
-//
-// This pass tiles and vectorizes Linalg ops on buffers within in a single
-// workgroup.
-//
-//===----------------------------------------------------------------------===//
-
-#include "iree/compiler/Codegen/PassDetail.h"
-#include "iree/compiler/Codegen/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/KernelDispatchUtils.h"
-#include "iree/compiler/Codegen/SPIRV/MemorySpace.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/LoopUtils.h"
-
-#define DEBUG_TYPE "iree-spirv-tile-and-vectorize"
-
-namespace mlir {
-namespace iree_compiler {
-
-//===----------------------------------------------------------------------===//
-// Utility functions
-//===----------------------------------------------------------------------===//
-
-/// Returns a Linalg marker that matches any of the `matchMarkers` and replaces
-/// it with `replaceMarker`.
-static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
-    ArrayRef<StringRef> matchMarkers, StringRef replaceMarker,
-    MLIRContext *context) {
-  SmallVector<Identifier, 2> markers;
-  markers.reserve(matchMarkers.size());
-  for (StringRef marker : matchMarkers) {
-    markers.emplace_back(Identifier::get(marker, context));
-  }
-  return linalg::LinalgTransformationFilter(
-      markers, Identifier::get(replaceMarker, context));
-}
-
-/// Converts a symbolic GPU processor dimension to its numeric one.
-static unsigned dimToIndex(StringRef dim) {
-  return StringSwitch<unsigned>(dim).Case("x", 0).Case("y", 1).Case("z", 2);
-}
-
-//===----------------------------------------------------------------------===//
-// Main pass
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Function pass that implements tiling and fusion in Linalg on buffers.
-class SPIRVTileAndVectorizePass
-    : public SPIRVTileAndVectorizeBase<SPIRVTileAndVectorizePass> {
- public:
-  SPIRVTileAndVectorizePass() = default;
-  SPIRVTileAndVectorizePass(const SPIRVTileAndVectorizePass &pass) = default;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect,
-                    linalg::LinalgDialect, memref::MemRefDialect,
-                    scf::SCFDialect, ShapeDialect, vector::VectorDialect>();
-  }
-
-  void runOnOperation() override;
-};
-}  // namespace
-
-//===----------------------------------------------------------------------===//
-// Patterns to promote subviews to workgroup memory
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Pattern to promote matmul operands to workgroup memory.
-struct PromoteMatmulSubviewsPattern
-    : public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
-  PromoteMatmulSubviewsPattern(MLIRContext *context,
-                               linalg::LinalgPromotionOptions options,
-                               linalg::LinalgTransformationFilter marker,
-                               PatternBenefit benefit = 1)
-      : linalg::LinalgPromotionPattern<linalg::MatmulOp>(
-            context,
-            options.setOperandsToPromote({0, 1}).setUseFullTileBuffers(
-                {false, false}),
-            marker, benefit) {}
-};
-
-/// Patterns to promote convolution operands to workgroup memory.
-// TODO(ravishankarm): This pattern is only promoting the image subview to
-// workgroup memory. In reality we should also be able to promote the filter
-// subview to workgroup memory as well. Since none of the loops used to access
-// the filter are tiled, this would mean the entire filter is moved to workgroup
-// memory. Two reasons this is not done right now:
-// 1) Linalg when tiling doesnt create a subview for the filter (since none of
-//    its dimensions are tiled. This needs to be relaxed (maybe by using an
-//    option).
-// 2) Maybe there are better alternatives for handling filter like using
-//    different storage classes, since for inference workloads these are model
-//    constants. This is TBD.
-template <typename ConvOpTy>
-struct PromoteConvSubviewsPattern
-    : public linalg::LinalgPromotionPattern<ConvOpTy> {
-  PromoteConvSubviewsPattern(MLIRContext *context,
-                             linalg::LinalgPromotionOptions options,
-                             linalg::LinalgTransformationFilter marker,
-                             PatternBenefit benefit = 1)
-      : linalg::LinalgPromotionPattern<ConvOpTy>(
-            context,
-            options.setOperandsToPromote({0}).setUseFullTileBuffers(
-                {false, false}),
-            marker, benefit) {}
-};
-}  // namespace
-
-static void populatePromotionPatterns(MLIRContext *context,
-                                      RewritePatternSet &patterns) {
-  patterns.insert<PromoteMatmulSubviewsPattern,
-                  PromoteConvSubviewsPattern<linalg::Conv2DNhwcHwcfOp>>(
-      context,
-      linalg::LinalgPromotionOptions()
-          .setAllocationDeallocationFns(allocateWorkgroupMemory,
-                                        deallocateWorkgroupMemory)
-          .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory),
-      getLinalgMatchAndReplaceMarker(getWorkgroupMarker(),
-                                     getWorkgroupMemoryMarker(), context));
-}
-
-//===----------------------------------------------------------------------===//
-// Patterns to tile computation to map to subgroups
-//===----------------------------------------------------------------------===//
-
-/// Computes the Value for subgroupID along each dimension given number of
-/// subgroups `numSubGroups` along each dimension (x-first, y-second, z-third).
-static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
-    OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) {
-  Type indexType = builder.getIndexType();
-  Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType);
-  SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size());
-
-  // subgroupID
-  //   = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x
-  for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) {
-    Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]);
-    AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
-    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
-    Value procId =
-        makeComposedAffineApply(builder, loc, d0 % s0, {subgroupId, nprocs});
-    procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs};
-    subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs);
-  }
-  return procInfo;
-}
-
-namespace {
-/// Pattern to tile linalg.matmul for subgroups.
-struct TileMatmulSubgroupPattern
-    : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
-  using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
-  TileMatmulSubgroupPattern(MLIRContext *context,
-                            linalg::LinalgTilingOptions options,
-                            linalg::LinalgTransformationFilter marker,
-                            PatternBenefit benefit = 1)
-      : Base(context, options, marker, benefit) {}
-};
-}  // namespace
-
-/// Patterns for second level tiling to target subgroups.
-static void populateTilingToSubgroupPatterns(MLIRContext *context,
-                                             RewritePatternSet &patterns) {
-  auto getInnerTileSizeFn = [&](OpBuilder &builder,
-                                Operation *operation) -> SmallVector<Value, 4> {
-    SmallVector<int64_t> tileSizes = getTileSizes(operation, 1);
-    return llvm::to_vector<4>(
-        llvm::map_range(tileSizes, [&](int64_t v) -> Value {
-          return builder.create<ConstantIndexOp>(operation->getLoc(), v);
-        }));
-  };
-
-  auto getSubgroupProcInfoFn = [&](OpBuilder &builder, Location loc,
-                                   ArrayRef<Range> parallelLoopRanges) {
-    // TODO(ravishankarm): For now assume that there is always a single subgroup
-    std::array<int64_t, 3> numSubgroups = {1, 1, 1};
-    return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
-  };
-
-  linalg::LinalgLoopDistributionOptions subgroupDistributionOptions;
-  subgroupDistributionOptions.procInfo = getSubgroupProcInfoFn;
-  subgroupDistributionOptions.distributionMethod = {
-      {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
-       linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
-
-  patterns.insert<TileMatmulSubgroupPattern>(
-      context,
-      linalg::LinalgTilingOptions()
-          .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
-          .setTileSizeComputationFunction(getInnerTileSizeFn)
-          .setDistributionOptions(subgroupDistributionOptions),
-      getLinalgMatchAndReplaceMarker(
-          {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
-          getVectorizeMarker(), context));
-}
-
-//===----------------------------------------------------------------------===//
-// Patterns and methods for thread tiling.
-//===----------------------------------------------------------------------===//
-
-/// Patterns for third level tiling to target invocations.
-static void populateTilingToInvocationPatterns(MLIRContext *context,
-                                               RewritePatternSet &patterns) {
-  linalg::TileSizeComputationFunction getInnerTileSizeFn =
-      [&](OpBuilder &builder, Operation *operation) {
-        SmallVector<int64_t> tileSizes = getTileSizes(operation, 2);
-        return llvm::to_vector<4>(
-            llvm::map_range(tileSizes, [&](int64_t v) -> Value {
-              return builder.create<ConstantIndexOp>(operation->getLoc(), v);
-            }));
-      };
-
-  auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
-                                ArrayRef<Range> parallelLoopRanges) {
-    return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
-        builder, loc, parallelLoopRanges.size());
-  };
-  linalg::LinalgLoopDistributionOptions invocationDistributionOptions;
-  invocationDistributionOptions.procInfo = getThreadProcInfoFn;
-  invocationDistributionOptions.distributionMethod = {
-      {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
-       linalg::DistributionMethod::Cyclic}};
-
-  auto tilingOptions =
-      linalg::LinalgTilingOptions()
-          .setLoopType(linalg::LinalgTilingLoopType::Loops)
-          .setTileSizeComputationFunction(getInnerTileSizeFn)
-          .setDistributionOptions(invocationDistributionOptions);
-
-  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::FillOp>,
-                  linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>,
-                  linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
-                  linalg::LinalgTilingPattern<linalg::GenericOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>>(
-      context, tilingOptions,
-      getLinalgMatchAndReplaceMarker(
-          {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
-          getVectorizeMarker(), context));
-
-  patterns.insert<linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>>(
-      context, tilingOptions,
-      getLinalgMatchAndReplaceMarker(
-          {getWorkgroupMemoryMarker(), getWorkgroupMarker()},
-          getConvFilterTileMarker(), context));
-}
-
-/// Returns the corresponding range for the given `processorValue` is a GPU
-/// thread id or block dim.
-static Optional<std::pair<AffineExpr, AffineExpr>> getThreadRange(
-    Value processorValue, SmallVectorImpl<Value> & /*dims*/,
-    SmallVectorImpl<Value> & /*symbols*/, ArrayRef<int64_t> workgroupSize) {
-  if (auto idOp = processorValue.getDefiningOp<gpu::ThreadIdOp>()) {
-    OpBuilder builder(processorValue.getContext());
-    unsigned index = dimToIndex(idOp.dimension());
-    AffineExpr zero = builder.getAffineConstantExpr(0);
-    AffineExpr ubExpr = builder.getAffineConstantExpr(workgroupSize[index]);
-    return std::make_pair(zero, ubExpr - 1);
-  }
-  if (auto dimOp = processorValue.getDefiningOp<gpu::BlockDimOp>()) {
-    OpBuilder builder(processorValue.getContext());
-    unsigned index = dimToIndex(dimOp.dimension());
-    AffineExpr bound = builder.getAffineConstantExpr(workgroupSize[index]);
-    return std::make_pair(bound, bound);
-  }
-  return llvm::None;
-}
-
-//====---------------------------------------------------------------------===//
-// Patterns for vectorization
-//====---------------------------------------------------------------------===//
-
-static void populateVectorizationPatterns(MLIRContext *context,
-                                          RewritePatternSet &patterns) {
-  linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp,
-                                      linalg::ContractionOpInterface>(
-      patterns, linalg::LinalgVectorizationOptions(),
-      linalg::LinalgTransformationFilter(
-          Identifier::get(getVectorizeMarker(), context)));
-}
-
-//====---------------------------------------------------------------------===//
-// Patterns for unrolling vectors
-//====---------------------------------------------------------------------===//
-
-static void populateVectorUnrollPatterns(MLIRContext *context,
-                                         RewritePatternSet &patterns) {
-  vector::populateVectorUnrollPatterns(
-      patterns,
-      vector::UnrollVectorOptions().setNativeShapeFn(getSPIRVNativeVectorSize));
-}
-
-namespace {
-
-/// Workaround SPIR-V backend limitations. SPIR-V vetorization pass relies on
-/// unrolling to reduce instructions to a vector size we can convert to SPIR-V.
-/// When vectorization creates transpose those block unrolling and result in
-/// large vector we currently cannot lower. For now we always merge the
-/// transpose into the contract op so that it can be unrolled.
-// TODO(thomasraoux): Make transpose work with the current unrolling mechanism
-// or replace unrolling.
-class CombineContractTranspose final
-    : public OpRewritePattern<vector::ContractionOp> {
- public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    // Perform lhs + rhs transpositions to conform to matmul row-major
-    // semantics. Bail out if the contraction cannot be put in this form.
-    MLIRContext *ctx = op.getContext();
-    Location loc = op.getLoc();
-    bool foundTranspose = false;
-    std::array<Value, 3> sources = {op.lhs(), op.rhs(), op.acc()};
-    SmallVector<AffineMap> newMaps;
-    SmallVector<Value> newSources;
-    for (auto source : llvm::enumerate(sources)) {
-      auto map =
-          op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue();
-      auto tranposeOp = source.value().getDefiningOp<vector::TransposeOp>();
-      if (!tranposeOp) {
-        newSources.push_back(source.value());
-        newMaps.push_back(map);
-        continue;
-      }
-      SmallVector<int64_t, 3> perm;
-      tranposeOp.getTransp(perm);
-      SmallVector<AffineExpr> exprs(perm.size());
-      for (auto remap : llvm::enumerate(perm)) {
-        exprs[remap.value()] = map.getResult(remap.index());
-      }
-      newMaps.push_back(
-          AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx));
-      newSources.push_back(tranposeOp.vector());
-      foundTranspose = true;
-    }
-    if (!foundTranspose) return failure();
-
-    Value res = rewriter.create<vector::ContractionOp>(
-        loc, newSources[0], newSources[1], newSources[2],
-        rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types());
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-};
-
-}  // namespace
-
-//====---------------------------------------------------------------------===//
-// Vector patterns
-//====---------------------------------------------------------------------===//
-
-static void applyVectorTransformation(FuncOp funcOp) {
-  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(funcOp));
-  bool useCooperativeMatrix =
-      targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
-      targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix);
-  {
-    {
-      RewritePatternSet vectorUnrollPatterns(funcOp.getContext());
-      populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
-      (void)applyPatternsAndFoldGreedily(funcOp,
-                                         std::move(vectorUnrollPatterns));
-    }
-    {
-      linalg::hoistRedundantVectorTransfers(funcOp);
-
-      LLVM_DEBUG({
-        llvm::dbgs() << "--- After hoisting vector transfers ---\n";
-        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-        llvm::dbgs() << "\n\n";
-      });
-    }
-    {
-      RewritePatternSet canonicalizationPatterns2(funcOp.getContext());
-      vector::populateVectorTransferPermutationMapLoweringPatterns(
-          canonicalizationPatterns2);
-      (void)applyPatternsAndFoldGreedily(funcOp,
-                                         std::move(canonicalizationPatterns2));
-
-      if (useCooperativeMatrix) {
-        // When using cooperative matrix we don't want to lower the contract,
-        // instead we want to merge contract and transpose so that they can be
-        // converted to cooperative matrix matmul op.
-        // TODO(thomasraoux): remove that once we support cooperative matrix
-        // lowering in MLIR core.
-        RewritePatternSet combineTransposePatterns(funcOp.getContext());
-        combineTransposePatterns.add<CombineContractTranspose>(
-            funcOp.getContext());
-        (void)applyPatternsAndFoldGreedily(funcOp,
-                                           std::move(combineTransposePatterns));
-      } else {
-        RewritePatternSet contractLoweringPatterns(funcOp.getContext());
-        vector::populateVectorContractLoweringPatterns(
-            contractLoweringPatterns,
-            vector::VectorTransformsOptions().setVectorTransformsOptions(
-                vector::VectorContractLowering::OuterProduct));
-        (void)applyPatternsAndFoldGreedily(funcOp,
-                                           std::move(contractLoweringPatterns));
-      }
-    }
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After unrolling vector ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-  }
-}
-
-//====---------------------------------------------------------------------===//
-// Patterns to tile convolution window dimensions
-//====---------------------------------------------------------------------===//
-
-static void populateTilingConvFilterPatterns(
-    MLIRContext *context, RewritePatternSet &patterns,
-    linalg::LinalgTransformationFilter marker) {
-  auto getTileSizeFn = [&](OpBuilder &builder, Operation *op) {
-    SmallVector<Value, 4> tileSizes;
-    SmallVector<int64_t, 4> fourthLevel = getTileSizes(op, 3);
-    tileSizes.reserve(fourthLevel.size());
-
-    Location loc = op->getLoc();
-    for (int64_t size : fourthLevel) {
-      tileSizes.push_back(builder.create<ConstantIndexOp>(loc, size));
-    }
-    return tileSizes;
-  };
-
-  auto tilingOptions = linalg::LinalgTilingOptions()
-                           .setLoopType(linalg::LinalgTilingLoopType::Loops)
-                           .setTileSizeComputationFunction(getTileSizeFn);
-
-  patterns.insert<linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>>(
-      context, tilingOptions, marker);
-}
-
-//====---------------------------------------------------------------------===//
-// Patterns to lower linalg ops to loops
-//====---------------------------------------------------------------------===//
-
-template <typename OpTy>
-struct LowerToLoops final : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Only handle the cases where tiling to invocations was done, where tiling
-    // convolution filters or vectorization is expected.
-    if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()}))
-      return failure();
-
-    if (linalg::linalgOpToLoops(rewriter, op)) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-//====---------------------------------------------------------------------===//
-// Main pass implementation
-//====---------------------------------------------------------------------===//
-
-void SPIRVTileAndVectorizePass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  FuncOp funcOp = getOperation();
-  auto entryPointOp = getEntryPoint(funcOp);
-  if (!entryPointOp) return;
-
-  // TODO(thomasraoux, antiagainst): Tiling to subgroups shouldn't be
-  // controlled by vectorization. This is needed due to historical reasons.
-  // Change the second level tiling to cyclic to loops and remove this.
-  RewritePatternSet secondLevelTilingPatterns(&getContext());
-  populateTilingToSubgroupPatterns(context, secondLevelTilingPatterns);
-  (void)applyPatternsAndFoldGreedily(funcOp,
-                                     std::move(secondLevelTilingPatterns));
-
-  RewritePatternSet secondLevelTilingCanonicalizationPatterns =
-      linalg::getLinalgTilingCanonicalizationPatterns(context);
-  populateAffineMinCanonicalizationPattern(
-      secondLevelTilingCanonicalizationPatterns);
-  (void)applyPatternsAndFoldGreedily(
-      funcOp, std::move(secondLevelTilingCanonicalizationPatterns));
-  promoteSingleIterationLoops(funcOp);
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "--- After tiling to subgroups ---\n";
-    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-    llvm::dbgs() << "\n\n";
-  });
-
-  {
-    RewritePatternSet thirdLevelTilingPatterns(&getContext());
-    populateTilingToInvocationPatterns(context, thirdLevelTilingPatterns);
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(thirdLevelTilingPatterns));
-
-    // Remove trip-one loops created during cyclic loop distribution if we can
-    // prove the tiling was perfect.
-    RewritePatternSet canoncalizationPatterns(context);
-    populateAffineMinSCFCanonicalizationPattern(canoncalizationPatterns);
-    SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp);
-    if (workgroupSize.empty()) {
-      entryPointOp.emitError("expected to have workgroup_size attribute");
-      return signalPassFailure();
-    }
-    auto getThreadRangeFn = [workgroupSize](Value processorValue,
-                                            SmallVectorImpl<Value> &dims,
-                                            SmallVectorImpl<Value> &symbols) {
-      return getThreadRange(processorValue, dims, symbols, workgroupSize);
-    };
-    populateRemoveSingleIterationLoopPattern(canoncalizationPatterns,
-                                             getThreadRangeFn);
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(canoncalizationPatterns));
-
-    // Perform generic canonicalization.
-    RewritePatternSet threadLevelTilingCanonicalizationPatterns =
-        linalg::getLinalgTilingCanonicalizationPatterns(context);
-    populateAffineMinCanonicalizationPattern(
-        threadLevelTilingCanonicalizationPatterns);
-    (void)applyPatternsAndFoldGreedily(
-        funcOp, std::move(threadLevelTilingCanonicalizationPatterns));
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After tiling to invocations ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-  }
-
-  {
-    RewritePatternSet tilingPatterns(&getContext());
-    auto marker = getLinalgMatchAndReplaceMarker(getConvFilterTileMarker(),
-                                                 getVectorizeMarker(), context);
-    populateTilingConvFilterPatterns(context, tilingPatterns, marker);
-    populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
-    tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(context);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
-
-    RewritePatternSet convTilingCanonicalizationPatterns =
-        linalg::getLinalgTilingCanonicalizationPatterns(context);
-    populateAffineMinCanonicalizationPattern(
-        convTilingCanonicalizationPatterns);
-    (void)applyPatternsAndFoldGreedily(
-        funcOp, std::move(convTilingCanonicalizationPatterns));
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After tiling convolution filter  ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-  }
-
-  {
-    RewritePatternSet vectorizationPatterns(&getContext());
-    populateVectorizationPatterns(context, vectorizationPatterns);
-    populateLinalgToVectorVectorizeConvPatterns(context, vectorizationPatterns);
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(vectorizationPatterns));
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After vectorization ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-  }
-
-  // TODO: This should be a folding of Add into Contract in core but while
-  // they live in different dialects, it is not possible without unnatural
-  // dependencies.
-  funcOp.walk([&](Operation *op) {
-    if (auto contract = canonicalizeContractionAdd(op))
-      op->replaceAllUsesWith(contract);
-  });
-
-  applyVectorTransformation(funcOp);
-
-  // Lower ops that were tiled to invocations but not vectorized to loops.
-  // TODO(antiagainst): This is here now to simplify the interaction with
-  // ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
-  // pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
-  // directly.
-  {
-    RewritePatternSet patterns(context);
-    patterns.add<LowerToLoops<linalg::BatchMatmulOp>,
-                 LowerToLoops<linalg::Conv1DNwcWcfOp>,
-                 LowerToLoops<linalg::Conv2DNhwcHwcfOp>,
-                 LowerToLoops<linalg::Conv3DNdhwcDhwcfOp>,
-                 LowerToLoops<linalg::DepthwiseConv2DNhwcOp>,
-                 LowerToLoops<linalg::DepthwiseConv2DNhwOp>,
-                 LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
-                 LowerToLoops<linalg::MatmulOp>,
-                 LowerToLoops<linalg::PoolingNhwcMaxOp>,
-                 LowerToLoops<linalg::PoolingNhwcMinOp>,
-                 LowerToLoops<linalg::PoolingNhwcSumOp>>(context);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Pass entry point and registration
-//===----------------------------------------------------------------------===//
-
-std::unique_ptr<OperationPass<FuncOp>> createSPIRVTileAndVectorizePass() {
-  return std::make_unique<SPIRVTileAndVectorizePass>();
-}
-
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 108bf96..3ea5b03 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -6,8 +6,297 @@
 
 //===- SPIRVVectorize.cpp -------------------------------------------------===//
 //
-// This pass vectorizes Linalg ops on buffers within in a single workgroup.
+// This pass vectorizes Linalg ops with buffer semantics.
 //
 //===----------------------------------------------------------------------===//
 
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
 #define DEBUG_TYPE "iree-spirv-vectorize"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+/// Returns the cooperative matrix (M, N, K) sizes that are supported by the
+/// target environment and match the given parameters.
+static Optional<SmallVector<int64_t, 4>> getCooperativeMatmulSubgroupSize(
+    spirv::ResourceLimitsAttr resourceLimits, Type lhsType, Type rhsType,
+    Type initType, Type resultType) {
+  auto range = resourceLimits.cooperative_matrix_properties_nv()
+                   .getAsRange<spirv::CooperativeMatrixPropertiesNVAttr>();
+  for (auto coopMatmulProperties : range) {
+    if (coopMatmulProperties.a_type().getValue() == lhsType &&
+        coopMatmulProperties.b_type().getValue() == rhsType &&
+        coopMatmulProperties.c_type().getValue() == initType &&
+        coopMatmulProperties.result_type().getValue() == resultType &&
+        coopMatmulProperties.scope().getValue() == spirv::Scope::Subgroup) {
+      return SmallVector<int64_t, 4>{
+          coopMatmulProperties.m_size().getValue().getSExtValue(),
+          coopMatmulProperties.n_size().getValue().getSExtValue(),
+          coopMatmulProperties.k_size().getValue().getSExtValue()};
+    }
+  }
+  return llvm::None;
+}
+
+/// Returns true if the target environment attached to `op`'s ancestor op
+/// supports cooperative matrix.
+bool useCooperativeMatrix(Operation *op) {
+  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+  return targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
+         targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix);
+}
+
+Optional<SmallVector<int64_t, 4>> getSPIRVNativeVectorSize(Operation *op) {
+  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+  bool useCoopMatrix = useCooperativeMatrix(op);
+
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+      // Use 4-element vectors for elementwise ops.
+      SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() = 4;
+      return nativeSize;
+    }
+  } else if (auto vtOp = dyn_cast<VectorTransferOpInterface>(op)) {
+    if (useCoopMatrix) {
+      if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+        // Unroll cooperative matrix store based on the size of the contract.
+        auto insert =
+            writeOp.vector().getDefiningOp<vector::InsertStridedSliceOp>();
+        if (!insert) return llvm::None;
+        ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
+        return SmallVector<int64_t, 4>(shape.begin(), shape.end());
+      } else if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
+        // Unroll cooperative matrix load based on the size of the contract.
+        VectorType dstVec;
+        for (Operation *users : op->getUsers()) {
+          auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
+          if (!extract) return llvm::None;
+          auto vecType = extract.getResult().getType().cast<VectorType>();
+          if (dstVec && dstVec != vecType) return llvm::None;
+          dstVec = vecType;
+        }
+        return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
+                                       dstVec.getShape().end());
+      }
+    } else {
+      auto rank = vtOp.getVectorType().getRank();
+      SmallVector<int64_t, 4> nativeSize(rank, 1);
+      for (auto dim : llvm::enumerate(vtOp.permutation_map().getResults())) {
+        if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) {
+          if (dimExpr.getPosition() == vtOp.permutation_map().getNumDims() - 1)
+            nativeSize[dim.index()] = 4;
+        }
+      }
+      return nativeSize;
+    }
+  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
+    if (useCoopMatrix) {
+      return getCooperativeMatmulSubgroupSize(
+          targetEnv.getResourceLimits(),
+          contractOp.getLhsType().getElementType(),
+          contractOp.getRhsType().getElementType(),
+          contractOp.getAccType().cast<VectorType>().getElementType(),
+          contractOp.getResultType().cast<VectorType>().getElementType());
+    } else {
+      unsigned lastParalleldim = 0;
+      for (auto it : llvm::enumerate(contractOp.iterator_types())) {
+        if (isParallelIterator(it.value())) lastParalleldim = it.index();
+      }
+      SmallVector<int64_t, 4> nativeSize(contractOp.iterator_types().size(), 1);
+      nativeSize[lastParalleldim] = 4;
+      // Map to vec4 fma operations.
+      return nativeSize;
+    }
+  }
+  return llvm::None;
+}
+
+/// Add patterns to vectorize Linalg ops with vectorization marker.
+void populateVectorizationPatterns(MLIRContext *context,
+                                   RewritePatternSet &patterns) {
+  linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp,
+                                      linalg::ContractionOpInterface>(
+      patterns, linalg::LinalgVectorizationOptions(),
+      linalg::LinalgTransformationFilter(
+          Identifier::get(getVectorizeMarker(), context)));
+}
+
+/// Adds patterns to unroll vector ops to SPIR-V native vector size.
+void populateVectorUnrollPatterns(MLIRContext *context,
+                                  RewritePatternSet &patterns) {
+  vector::populateVectorUnrollPatterns(
+      patterns,
+      vector::UnrollVectorOptions().setNativeShapeFn(getSPIRVNativeVectorSize));
+}
+
+/// Workaround SPIR-V backend limitations. SPIR-V vectorization pass relies on
+/// unrolling to reduce instructions to a vector size we can convert to SPIR-V.
+/// When vectorization creates transposes, they block unrolling and result in
+/// large vectors we currently cannot lower. For now we always merge the
+/// transpose into the contract op so that it can be unrolled.
+// TODO(thomasraoux): Make transpose work with the current unrolling mechanism
+// or replace unrolling.
+class CombineContractTranspose final
+    : public OpRewritePattern<vector::ContractionOp> {
+ public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    // Perform lhs + rhs transpositions to conform to matmul row-major
+    // semantics. Bail out if the contraction cannot be put in this form.
+    MLIRContext *ctx = op.getContext();
+    Location loc = op.getLoc();
+    bool foundTranspose = false;
+    std::array<Value, 3> sources = {op.lhs(), op.rhs(), op.acc()};
+    SmallVector<AffineMap> newMaps;
+    SmallVector<Value> newSources;
+    for (auto source : llvm::enumerate(sources)) {
+      auto map =
+          op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue();
+      auto tranposeOp = source.value().getDefiningOp<vector::TransposeOp>();
+      if (!tranposeOp) {
+        newSources.push_back(source.value());
+        newMaps.push_back(map);
+        continue;
+      }
+      SmallVector<int64_t, 3> perm;
+      tranposeOp.getTransp(perm);
+      SmallVector<AffineExpr> exprs(perm.size());
+      for (auto remap : llvm::enumerate(perm)) {
+        exprs[remap.value()] = map.getResult(remap.index());
+      }
+      newMaps.push_back(
+          AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx));
+      newSources.push_back(tranposeOp.vector());
+      foundTranspose = true;
+    }
+    if (!foundTranspose) return failure();
+
+    Value res = rewriter.create<vector::ContractionOp>(
+        loc, newSources[0], newSources[1], newSources[2],
+        rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types());
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
+/// Vectorizes Linalg ops with buffer semantics.
+class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
+ public:
+  SPIRVVectorizePass() = default;
+  SPIRVVectorizePass(const SPIRVVectorizePass &pass) = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    FuncOp funcOp = getOperation();
+
+    auto entryPointOp = getEntryPoint(funcOp);
+    if (!entryPointOp) return;
+
+    {
+      RewritePatternSet vectorizationPatterns(&getContext());
+      populateVectorizationPatterns(context, vectorizationPatterns);
+      populateLinalgToVectorVectorizeConvPatterns(context,
+                                                  vectorizationPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorizationPatterns));
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "--- After vectorization ---\n";
+        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+        llvm::dbgs() << "\n\n";
+      });
+    }
+
+    // TODO: This should be a folding of Add into Contract in core but while
+    // they live in different dialects, it is not possible without unnatural
+    // dependencies.
+    funcOp.walk([&](Operation *op) {
+      if (auto contract = canonicalizeContractionAdd(op))
+        op->replaceAllUsesWith(contract);
+    });
+
+    {
+      RewritePatternSet vectorUnrollPatterns(funcOp.getContext());
+      populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(vectorUnrollPatterns));
+    }
+
+    {
+      linalg::hoistRedundantVectorTransfers(funcOp);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "--- After hoisting vector transfers ---\n";
+        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+        llvm::dbgs() << "\n\n";
+      });
+    }
+
+    {
+      RewritePatternSet canonicalizationPatterns(funcOp.getContext());
+      vector::populateVectorTransferPermutationMapLoweringPatterns(
+          canonicalizationPatterns);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns));
+
+      if (useCooperativeMatrix(funcOp)) {
+        // When using cooperative matrix we don't want to lower the contract,
+        // instead we want to merge contract and transpose so that they can be
+        // converted to cooperative matrix matmul op.
+        // TODO(thomasraoux): remove that once we support cooperative matrix
+        // lowering in MLIR core.
+        RewritePatternSet combineTransposePatterns(funcOp.getContext());
+        combineTransposePatterns.add<CombineContractTranspose>(
+            funcOp.getContext());
+        (void)applyPatternsAndFoldGreedily(funcOp,
+                                           std::move(combineTransposePatterns));
+      } else {
+        RewritePatternSet contractLoweringPatterns(funcOp.getContext());
+        vector::populateVectorContractLoweringPatterns(
+            contractLoweringPatterns,
+            vector::VectorTransformsOptions().setVectorTransformsOptions(
+                vector::VectorContractLowering::OuterProduct));
+        (void)applyPatternsAndFoldGreedily(funcOp,
+                                           std::move(contractLoweringPatterns));
+      }
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "--- After unrolling vector ---\n";
+      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+      llvm::dbgs() << "\n\n";
+    });
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createSPIRVVectorizePass() {
+  return std::make_unique<SPIRVVectorizePass>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Codegen/SPIRV/Utils.h b/iree/compiler/Codegen/SPIRV/Utils.h
index b286642..a7f899b 100644
--- a/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/iree/compiler/Codegen/SPIRV/Utils.h
@@ -4,11 +4,12 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-//===- Utils.h - Utility functions used in Linalg to SPIR-V lowering ------===//
+//===- Utils.h - Utility functions for lowering Linalg to SPIR-V ----------===//
 //
-// Utility functions used while lowering from Linalg to SPIRV.
+// Utility functions used while lowering from Linalg to SPIR-V.
 //
 //===----------------------------------------------------------------------===//
+
 #ifndef IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_
 #define IREE_COMPILER_CODEGEN_SPIRV_UTILS_H_
 
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD
index 0f3bc85..2624a6a 100644
--- a/iree/compiler/Codegen/SPIRV/test/BUILD
+++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -19,8 +19,8 @@
     name = "lit",
     srcs = enforce_glob(
         [
-            "convert_to_gpu.mlir",
             "convert_to_spirv.mlir",
+            "distribute_to_global_id.mlir",
             "fold_gpu_procid_uses.mlir",
             "pipeline_matmul_cooperative_matrix.mlir",
             "pipeline_matmul_vectorization.mlir",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index cc4083a..9f99d11 100644
--- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -14,8 +14,8 @@
   NAME
     lit
   SRCS
-    "convert_to_gpu.mlir"
     "convert_to_spirv.mlir"
+    "distribute_to_global_id.mlir"
     "fold_gpu_procid_uses.mlir"
     "pipeline_matmul_cooperative_matrix.mlir"
     "pipeline_matmul_vectorization.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir b/iree/compiler/Codegen/SPIRV/test/distribute_to_global_id.mlir
similarity index 99%
rename from iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir
rename to iree/compiler/Codegen/SPIRV/test/distribute_to_global_id.mlir
index 729557e..8ba2fd0 100644
--- a/iree/compiler/Codegen/SPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/distribute_to_global_id.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-convert-to-gpu))))' -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-distribute-to-global-id))))' -canonicalize -cse %s | IreeFileCheck %s
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 hal.executable @parallel_4D attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
index b257355..adda0d2 100644
--- a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))'
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))'
 // TODO(antiagainst): Fix promotion to workgroup and enable the test.
 // | IreeFileCheck %s
 
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
index 78d41ab..0ab91cb 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s
 
 #map0 = affine_map<()[s0] -> (s0 * 8)>
 #map1 = affine_map<()[s0, s1] -> (8, s1 - s0 * 8)>
@@ -56,7 +56,7 @@
           %18 = memref.subview %arg2[%3, %10] [%15, %17] [1, 1]  : memref<?x?xf32> to memref<?x?xf32, #map3>
           linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering.config = #config}
             ins(%7, %13 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
-           outs(%18 : memref<?x?xf32, #map3>)
+            outs(%18 : memref<?x?xf32, #map3>)
         }
         return
       }
@@ -98,7 +98,7 @@
       translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
     }
     module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>}  {
-      func @conv_1d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+      func @conv_1d() {
         %cst = constant 0.000000e+00 : f32
         %c0 = constant 0 : index
         %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<3x6x1xf32>
@@ -132,9 +132,7 @@
 }
 
 // CHECK-LABEL: func @conv_1d
-//   CHECK-DAG: %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG: %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG: %[[C3:.+]] = constant 3 : index
+//       CHECK: %[[C0:.+]] = constant 0 : index
 //       CHECK: %[[RET:.+]] = hal.interface.binding.subspan @io::@ret0
 //       CHECK: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
 //       CHECK: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1
@@ -151,11 +149,10 @@
 //       CHECK:     %[[ARG0SV2:.+]] = memref.subview %[[ARG0SV1]][%[[TIDZ]], %[[IV0]], 0] [1, %{{.+}}, 1]
 //       CHECK:     %[[ARG1SV2:.+]] = memref.subview %[[ARG1SV1]][0, 0, %[[IV1]]] [3, 1, 1]
 //       CHECK:     %[[RETSV2:.+]] = memref.subview %[[RETSV1]][%[[TIDZ]], %[[IV0]], %[[IV1]]] [1, 1, 1]
-//       CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[C3]] step %[[C1]]
-//       CHECK:       memref.load %[[ARG0SV2]][%[[C0]], %[[IV2]], %[[C0]]]
-//       CHECK:       memref.load %[[ARG1SV2]][%[[IV2]], %[[C0]], %[[C0]]]
-//       CHECK:       memref.load %[[RETSV2]][%[[C0]], %[[C0]], %[[C0]]]
-//       CHECK:       memref.store %{{.+}}, %[[RETSV2]][%[[C0]], %[[C0]], %[[C0]]]
+//       CHECK:     linalg.conv_1d_nwc_wcf
+//  CHECK-SAME:       __internal_linalg_transform__ = "vectorize"
+//  CHECK-SAME:       ins(%[[ARG0SV2]], %[[ARG1SV2]]
+//  CHECK-SAME:       outs(%[[RETSV2]]
 
 
 // -----
@@ -270,22 +267,24 @@
 //         CHECK:   %[[BSTEPY:.+]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
 //         CHECK:   %[[BOFFSETX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
 //         CHECK:   %[[BSTEPX:.+]] = affine.apply #[[MAP1]]()[%[[NBLOCKSX]]]
-//         CHECK:   scf.for %[[IV3:.+]] = %[[BIDZ]] to %[[N]] step %[[NBLOCKSZ]]
-//         CHECK:     scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]]
-//         CHECK:       scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]]
-//         CHECK:         %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
-//         CHECK:         %[[SV2:.+]] = memref.subview %[[RET0]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+//         CHECK:   scf.for %[[IV0:.+]] = %[[BIDZ]] to %[[N]] step %[[NBLOCKSZ]]
+//         CHECK:     scf.for %[[IV1:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]]
+//         CHECK:       scf.for %[[IV2:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]]
+//         CHECK:         %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV0]], %[[IV1]], %[[IV2]], 0]
+//         CHECK:         %[[SV2:.+]] = memref.subview %[[RET0]][%[[IV0]], %[[IV1]], %[[IV2]], 0]
 //     CHECK-DAG:         %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
 //     CHECK-DAG:         %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
 //     CHECK-DAG:         %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"}
 //     CHECK-DAG:         %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
 //     CHECK-DAG:         %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
 //     CHECK-DAG:         %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
-//         CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
-//         CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//         CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-// CHECK-COUNT-3:               scf.for
-//     CHECK-NOT:               linalg.conv_2d_nhwc_hwcf
+//         CHECK:         scf.for %[[IV3:.+]] = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
+//         CHECK:           scf.for %[[IV4:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//         CHECK:             scf.for %[[IV5:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+//         CHECK:               %[[OUT:.+]] = memref.subview %[[SV2]][0, %[[IV3]], %[[IV4]], %[[IV5]]]
+//         CHECK:               linalg.conv_2d_nhwc_hwcf
+//    CHECK-SAME:                 __internal_linalg_transform__ = "tile_conv_filter"
+//    CHECK-SAME:                 outs(%[[OUT]]
 
 // -----
 
@@ -304,7 +303,7 @@
       translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
     }
     module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>}  {
-      func @conv_3d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+      func @conv_3d() {
         %cst = constant 0.000000e+00 : f32
         %c0 = constant 0 : index
         %0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<2x7x7x7x2xf32>
@@ -343,11 +342,13 @@
 //     CHECK-DAG:         %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
 //     CHECK-DAG:         %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
 //     CHECK-DAG:         %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
-//         CHECK:         scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
-//         CHECK:           scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//         CHECK:             scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-// CHECK-COUNT-5:               scf.for
-//     CHECK-NOT:               linalg.conv_3d_ndhwc_dhwcf
+//         CHECK:         scf.for %[[IV0:.+]] = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
+//         CHECK:           scf.for %[[IV1:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//         CHECK:             scf.for %[[IV2:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+//         CHECK:               %[[OUT:.+]] = memref.subview %{{.+}}[0, 0, %[[IV0]], %[[IV1]], %[[IV2]]]
+//         CHECK:               linalg.conv_3d_ndhwc_dhwcf
+//    CHECK-SAME:                 __internal_linalg_transform__ = "vectorize"
+//    CHECK-SAME:                 outs(%[[OUT]]
 
 // -----
 
@@ -376,7 +377,7 @@
         translation.info = {passPipeline = 6 : i32, workloadPerWorkgroup = [32, 4, 1]}
       }
       module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>}  {
-        func @pooling_nhwc_max() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} {
+        func @pooling_nhwc_max() {
           %c0 = constant 0 : index
           %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<2x16x16x6xf32>
           %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x4xf32>
@@ -422,8 +423,12 @@
 //     CHECK-DAG:   %[[BDIMX:.+]] = "gpu.block_dim"() {dimension = "x"}
 //     CHECK-DAG:   %[[BDIMY:.+]] = "gpu.block_dim"() {dimension = "y"}
 //     CHECK-DAG:   %[[BDIMZ:.+]] = "gpu.block_dim"() {dimension = "z"}
-//         CHECK:   scf.for %{{.+}} = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
-//         CHECK:     scf.for %{{.+}} = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
-//         CHECK:       scf.for %{{.+}} = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
-// CHECK-COUNT-3:         scf.for
-//     CHECK-NOT:           linalg.pooling_nhwc_max
+//         CHECK:   scf.for %[[IV0:.+]] = %[[TIDZ]] to %{{.*}} step %[[BDIMZ]]
+//         CHECK:     scf.for %[[IV1:.+]] = %[[TIDY]] to %{{.*}} step %[[BDIMY]]
+//         CHECK:       scf.for %[[IV2:.+]] = %[[TIDX]] to %{{.*}} step %[[BDIMX]]
+//         CHECK:         %[[IN:.+]] = memref.subview %[[SV1]][%[[IV0]], %[[IV1]], %[[IV2]], 0] [1, %{{.+}}, %{{.+}}, 6]
+//         CHECK:         %[[OUT:.+]] = memref.subview %[[SV2]][%[[IV0]], %[[IV1]], %[[IV2]], 0] [1, 1, 1, 6]
+//         CHECK:         linalg.pooling_nhwc_max
+//    CHECK-SAME:           __internal_linalg_transform__ = "vectorize"
+//    CHECK-SAME:           ins(%[[IN]], %[[ARG1]]
+//    CHECK-SAME:           outs(%[[OUT]]
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
index 1cbb881..5e10ef7 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
 
 #config = {tileSizes = [[1, 8, 64, 4], [], [1, 8, 4, 4]]}
 
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index cf389cc..dbb19ad 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(canonicalize,iree-spirv-remove-one-trip-tiled-loop,iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(canonicalize,iree-spirv-remove-one-trip-tiled-loop,iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
 
 #config = {tileSizes = [[0, 4, 4, 16], [], [0, 4, 1, 4], [0, 0, 0, 0, 1, 1, 4]]}
 
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
index e3c3076..06227f3 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-set-num-workgroups,builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' -canonicalize -cse %s | IreeFileCheck %s
 
 #config = {tileSizes = [[8, 64, 4], [], [8, 4, 4]]}
 
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
index d84bb33..bb5d1f1 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize))))' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize))))' %s | IreeFileCheck %s
 
 // CHECK-LABEL: func @elementwise_static_shape
 //       CHECK:   vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index f343f6b..ffa8a7c 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(builtin.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' %s | IreeFileCheck %s
 // TODO(antiagainst): Fix promotion to workgroup and enable the test.
 // | IreeFileCheck %s -check-prefix=PROMOTE
 
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index 7fec29c..28a28b2 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -82,11 +82,11 @@
 
 LogicalResult setOpConfigAndEntryPointFnTranslation(
     FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
-    ArrayRef<int64_t> nativeVectorSize,
+    ArrayRef<int64_t> nativeVectorSizes,
     IREE::HAL::DispatchLoweringPassPipeline passPipeline,
     ArrayRef<int64_t> workgroupSize) {
   IREE::HAL::LoweringConfig config =
-      buildConfigAttr(tileSizes, nativeVectorSize, op->getContext());
+      buildConfigAttr(tileSizes, nativeVectorSizes, op->getContext());
   setLoweringConfig(op, config);
   auto partitionedLoops = getPartitionedLoops(op);
   SmallVector<int64_t, 3> workloadPerWorkgroup;
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
index dbd1893..4cda45e 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
@@ -11,8 +11,13 @@
   hal.constant_storage @_storage1 = dense<[6, 7, 8, 0]> : vector<4xi8>
 }
 
-// CHECK: vm.global.ref private @pool_storage0_buffer initializer(@pool_storage0_buffer_initializer) : !vm.ref<!hal.buffer>
+// CHECK: vm.global.ref private @pool_storage0_buffer : !vm.ref<!hal.buffer>
 util.global private @pool_storage0_buffer initializer(@pool_storage0_buffer_initializer) : !hal.buffer
+// CHECK-NEXT: vm.initializer {
+// CHECK-NEXT:   %[[REF:.+]] = vm.call @pool_storage0_buffer_initializer() : () -> !vm.ref<!hal.buffer>
+// CHECK-NEXT:   vm.global.store.ref %[[REF]], @pool_storage0_buffer : !vm.ref<!hal.buffer>
+// CHECK-NEXT:   vm.return
+// CHECK-NEXT: }
 // CHECK: vm.func private @pool_storage0_buffer_initializer() -> !vm.ref<!hal.buffer>
 func private @pool_storage0_buffer_initializer() -> !hal.buffer {
   %c0 = constant 0 : index
@@ -29,11 +34,11 @@
   return %mapped : !hal.buffer
 }
 
-// CHECK: vm.global.ref private @pool_storage1_buffer initializer(@pool_storage1_buffer_initializer) : !vm.ref<!hal.buffer>
+// CHECK: vm.global.ref private @pool_storage1_buffer : !vm.ref<!hal.buffer>
 util.global private @pool_storage1_buffer initializer(@pool_storage1_buffer_initializer) : !hal.buffer
 func private @pool_storage1_buffer_initializer() -> !hal.buffer
 
-// CHECK: vm.global.ref private @pool_splats initializer(@pool_splats_initializer) : !vm.ref<!hal.buffer>
+// CHECK: vm.global.ref private @pool_splats : !vm.ref<!hal.buffer>
 util.global private @pool_splats initializer(@pool_splats_initializer) : !hal.buffer
 // CHECK: vm.func private @pool_splats_initializer() -> !vm.ref<!hal.buffer>
 func private @pool_splats_initializer() -> !hal.buffer {
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 11a6c11..f847694 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -90,12 +90,15 @@
 
 void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                      Attribute attr) {
+  bool needsSpace = false;
   if (!attr || attr.getType() != type.getValue()) {
-    p << " : ";
+    p << ": ";
     p.printAttribute(type);
+    needsSpace = true;  // subsequent attr value needs a space separator
   }
   if (attr) {
-    p << " = ";
+    if (needsSpace) p << ' ';
+    p << "= ";
     p.printAttribute(attr);
   }
 }
diff --git a/iree/compiler/Dialect/Util/IR/UtilOps.td b/iree/compiler/Dialect/Util/IR/UtilOps.td
index 1b8eb82..5a733fe 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -178,7 +178,7 @@
     (`mutable` $is_mutable^)?
     $sym_name
     attr-dict
-    (`initializer` `(` $initializer^ `)`):(``)?
+    (`initializer` `(` $initializer^ `)`)?
     custom<TypeOrAttr>($type, $initial_value)
   }];
 
diff --git a/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp b/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
index c716d57..9ffe386 100644
--- a/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
+++ b/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
@@ -47,7 +47,8 @@
   // Block names are their order in the function.
   DenseMap<Block *, int> blockOrdinals;
   for (auto &block : funcOp.getBlocks()) {
-    blockOrdinals[&block] = blockOrdinals.size();
+    int ordinal = blockOrdinals.size();
+    blockOrdinals[std::addressof(block)] = ordinal;
   }
 
   // Keep asm state to make getting the SSA value names fast.
diff --git a/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp b/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
index d967703..f1411db 100644
--- a/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
@@ -22,67 +22,103 @@
   LogicalResult matchAndRewrite(
       IREE::Util::GlobalOp op, llvm::ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
+    Operation *newOp = nullptr;
     auto convertedType = typeConverter.convertType(op.type());
     if (convertedType.isa<IREE::VM::RefType>() ||
         IREE::VM::RefType::isCompatible(convertedType)) {
-      auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalRefOp>(
-          op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
-      newOp.setVisibility(op.getVisibility());
-      return success();
+      newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalRefOp>(
+          op, op.sym_name(), op.is_mutable(), convertedType, op.initial_value(),
+          llvm::to_vector<4>(op->getDialectAttrs()));
     } else if (convertedType.isInteger(32)) {
-      auto convertedValue =
-          op.initial_value().hasValue()
-              ? rewriter.getI32IntegerAttr(static_cast<int32_t>(
-                    op.initial_value().getValue().cast<IntegerAttr>().getInt()))
-              : Attribute{};
-      auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>(
-          op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
-      newOp.setVisibility(op.getVisibility());
-      return success();
+      llvm::Optional<Attribute> convertedValue = llvm::None;
+      if (op.initial_value().hasValue()) {
+        convertedValue = rewriter.getI32IntegerAttr(static_cast<int32_t>(
+            op.initial_value().getValue().cast<IntegerAttr>().getInt()));
+      }
+      newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>(
+          op, op.sym_name(), op.is_mutable(), convertedType, convertedValue,
+          llvm::to_vector<4>(op->getDialectAttrs()));
     } else if (convertedType.isInteger(64)) {
-      auto convertedValue =
-          op.initial_value().hasValue()
-              ? rewriter.getI64IntegerAttr(
-                    op.initial_value().getValue().cast<IntegerAttr>().getInt())
-              : Attribute{};
-      auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>(
-          op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
-      newOp.setVisibility(op.getVisibility());
-      return success();
+      llvm::Optional<Attribute> convertedValue = llvm::None;
+      if (op.initial_value().hasValue()) {
+        convertedValue = rewriter.getI64IntegerAttr(
+            op.initial_value().getValue().cast<IntegerAttr>().getInt());
+      }
+      newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>(
+          op, op.sym_name(), op.is_mutable(), convertedType, convertedValue,
+          llvm::to_vector<4>(op->getDialectAttrs()));
     } else if (convertedType.isF32()) {
-      auto convertedValue = op.initial_value().hasValue()
-                                ? rewriter.getF32FloatAttr(static_cast<float>(
-                                      op.initial_value()
-                                          .getValue()
-                                          .cast<FloatAttr>()
-                                          .getValueAsDouble()))
-                                : Attribute{};
-      auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF32Op>(
-          op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
-      newOp.setVisibility(op.getVisibility());
-      return success();
+      llvm::Optional<Attribute> convertedValue = llvm::None;
+      if (op.initial_value().hasValue()) {
+        convertedValue = rewriter.getF32FloatAttr(
+            static_cast<float>(op.initial_value()
+                                   .getValue()
+                                   .cast<FloatAttr>()
+                                   .getValueAsDouble()));
+      }
+      newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF32Op>(
+          op, op.sym_name(), op.is_mutable(), convertedType, convertedValue,
+          llvm::to_vector<4>(op->getDialectAttrs()));
     } else if (convertedType.isF64()) {
-      auto convertedValue =
-          op.initial_value().hasValue()
-              ? rewriter.getF64FloatAttr(op.initial_value()
-                                             .getValue()
-                                             .cast<FloatAttr>()
-                                             .getValueAsDouble())
-              : Attribute{};
-      auto newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF64Op>(
-          op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
-      newOp.setVisibility(op.getVisibility());
-      return success();
+      llvm::Optional<Attribute> convertedValue = llvm::None;
+      if (op.initial_value().hasValue()) {
+        convertedValue = rewriter.getF64FloatAttr(
+            op.initial_value().getValue().cast<FloatAttr>().getValueAsDouble());
+      }
+      newOp = rewriter.replaceOpWithNewOp<IREE::VM::GlobalF64Op>(
+          op, op.sym_name(), op.is_mutable(), convertedType, convertedValue,
+          llvm::to_vector<4>(op->getDialectAttrs()));
+    } else {
+      return op.emitOpError("unsupported global type");
     }
-    return op.emitOpError("unsupported global type");
+
+    // New global carries the same visibility as the original.
+    cast<SymbolOpInterface>(newOp).setVisibility(op.getVisibility());
+
+    // If there was an initializer function specified we turn that into a
+    // vm.initializer now.
+    if (op.initializer()) {
+      auto initializerOp =
+          rewriter.create<IREE::VM::InitializerOp>(op.getLoc());
+      auto ip = rewriter.saveInsertionPoint();
+      rewriter.setInsertionPointToStart(initializerOp.addEntryBlock());
+      SmallVector<Type> resultTypes;
+      resultTypes.push_back(convertedType);
+      auto callOp = rewriter.create<IREE::VM::CallOp>(
+          op.getLoc(), op.initializer().getValue(), resultTypes,
+          /*operands=*/ValueRange{});
+      storeToGlobal(callOp.getResult(0), newOp, rewriter);
+      rewriter.create<IREE::VM::ReturnOp>(op.getLoc());
+      rewriter.restoreInsertionPoint(ip);
+    }
+
+    return success();
   }
 
  private:
+  void storeToGlobal(Value value, Operation *globalOp,
+                     ConversionPatternRewriter &rewriter) const {
+    auto globalName = cast<SymbolOpInterface>(globalOp).getName();
+    if (value.getType().isa<IREE::VM::RefType>()) {
+      rewriter.create<IREE::VM::GlobalStoreRefOp>(globalOp->getLoc(), value,
+                                                  globalName);
+    } else if (value.getType().isInteger(32)) {
+      rewriter.create<IREE::VM::GlobalStoreI32Op>(globalOp->getLoc(), value,
+                                                  globalName);
+    } else if (value.getType().isInteger(64)) {
+      rewriter.create<IREE::VM::GlobalStoreI64Op>(globalOp->getLoc(), value,
+                                                  globalName);
+    } else if (value.getType().isF32()) {
+      rewriter.create<IREE::VM::GlobalStoreF32Op>(globalOp->getLoc(), value,
+                                                  globalName);
+    } else if (value.getType().isF64()) {
+      rewriter.create<IREE::VM::GlobalStoreF64Op>(globalOp->getLoc(), value,
+                                                  globalName);
+    } else {
+      llvm_unreachable("unhandled vm type");
+    }
+  }
+
   TypeConverter &typeConverter;
 };
 
diff --git a/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir b/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir
index fdc41b6..57fcd33 100644
--- a/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir
+++ b/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/global_ops.mlir
@@ -8,8 +8,14 @@
 
 // -----
 
-// CHECK: vm.global.ref public @v_initialized initializer(@initializer) : !vm.ref<!hal.buffer>
+// CHECK: vm.global.ref public @v_initialized : !vm.ref<!hal.buffer>
 util.global public @v_initialized initializer(@initializer) : !hal.buffer
+// CHECK-NEXT: vm.initializer {
+// CHECK-NEXT:   %[[REF:.+]] = vm.call @initializer() : () -> !vm.ref<!hal.buffer>
+// CHECK-NEXT:   vm.global.store.ref %[[REF]], @v_initialized : !vm.ref<!hal.buffer>
+// CHECK-NEXT:   vm.return
+// CHECK-NEXT: }
+// CHECK-NEXT: vm.func private @initializer() -> !vm.ref<!hal.buffer>
 func private @initializer() -> !hal.buffer
 
 // -----
diff --git a/iree/compiler/Dialect/VM/IR/VMBase.td b/iree/compiler/Dialect/VM/IR/VMBase.td
index 2aba56d..b58482a 100644
--- a/iree/compiler/Dialect/VM/IR/VMBase.td
+++ b/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -218,14 +218,13 @@
     }]>,
     InterfaceMethod<[{}], "StringRef", "getSymbolName", (ins)>,
     InterfaceMethod<[{}], "bool", "isMutable", (ins)>,
-    InterfaceMethod<[{}], "Optional<StringRef>", "getInitializerAttr", (ins)>,
     InterfaceMethod<[{}], "Optional<Attribute>", "getInitialValueAttr", (ins)>,
+    InterfaceMethod<[{}], "void", "setInitialValue", (ins "Attribute":$value)>,
     InterfaceMethod<[{}], "Optional<IntegerAttr>", "getOrdinalAttr", (ins)>,
     InterfaceMethod<[{}], "int", "getOrdinal", (ins), [{
       return $_self.getOrdinalAttr().getValue().template cast<IntegerAttr>().getInt();
     }]>,
     InterfaceMethod<[{}], "void", "makeMutable", (ins)>,
-    InterfaceMethod<[{}], "void", "clearInitializer", (ins)>,
     InterfaceMethod<[{}], "void", "clearInitialValue", (ins)>,
   ];
 }
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index df29a52..7cd8a6f 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -58,39 +58,78 @@
 // Structural ops
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+// Deletes empty vm.initializer ops.
+struct DropEmptyInitializerOp : public OpRewritePattern<InitializerOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitializerOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.body().getBlocks().size() != 1) return failure();
+    auto &block = op.body().front();
+    if (block.empty() || isa<ReturnOp>(block.front())) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    return failure();
+  }
+};
+
+// Inlines constant stores from initializers into the global initializer.
+// This is not strictly required but can help our initialization code perform
+// more efficient initialization of large numbers of primitive values.
+struct InlineConstGlobalInitializer : public OpRewritePattern<InitializerOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitializerOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Operation *> deadOps;
+    op.walk([&](Operation *op) {
+      if (!isGlobalStoreOp(op)) return;
+      auto value = op->getOperand(0);
+      Attribute valueAttr;
+      if (!matchPattern(value, m_Constant(&valueAttr))) return;
+      auto globalRefAttr = op->getAttrOfType<SymbolRefAttr>("global");
+      assert(globalRefAttr);
+      auto globalOp =
+          SymbolTable::lookupNearestSymbolFrom<IREE::VM::VMGlobalOp>(
+              op, globalRefAttr);
+      if (valueAttr && !valueAttr.isa<UnitAttr>()) {
+        globalOp.setInitialValue(valueAttr);
+      } else {
+        globalOp.clearInitialValue();
+      }
+      deadOps.push_back(op);
+    });
+    if (deadOps.empty()) return failure();
+    for (auto deadOp : deadOps) rewriter.eraseOp(deadOp);
+    return success();
+  }
+
+  bool isGlobalStoreOp(Operation *op) const {
+    // TODO(benvanik): trait/interface to make this more generic?
+    return isa<IREE::VM::GlobalStoreI32Op>(op) ||
+           isa<IREE::VM::GlobalStoreI64Op>(op) ||
+           isa<IREE::VM::GlobalStoreF32Op>(op) ||
+           isa<IREE::VM::GlobalStoreF64Op>(op) ||
+           isa<IREE::VM::GlobalStoreRefOp>(op);
+  }
+};
+
+}  // namespace
+
+void InitializerOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<DropEmptyInitializerOp, InlineConstGlobalInitializer>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Globals
 //===----------------------------------------------------------------------===//
 
 namespace {
 
-/// Converts global initializer functions that evaluate to a constant to a
-/// specified initial value.
-template <typename T>
-struct InlineConstGlobalOpInitializer : public OpRewritePattern<T> {
-  using OpRewritePattern<T>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(T op,
-                                PatternRewriter &rewriter) const override {
-    if (!op.initializer()) return failure();
-    auto initializer = dyn_cast_or_null<FuncOp>(
-        SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue()));
-    if (!initializer) return failure();
-    if (initializer.getBlocks().size() == 1 &&
-        initializer.getBlocks().front().getOperations().size() == 2 &&
-        isa<ReturnOp>(initializer.getBlocks().front().getOperations().back())) {
-      auto &primaryOp = initializer.getBlocks().front().getOperations().front();
-      Attribute constResult;
-      if (matchPattern(primaryOp.getResult(0), m_Constant(&constResult))) {
-        rewriter.replaceOpWithNewOp<T>(op, op.sym_name(), op.is_mutable(),
-                                       op.type(), constResult);
-        return success();
-      }
-    }
-    return failure();
-  }
-};
-
 /// Drops initial_values from globals where the value is 0, as by default all
 /// globals are zero-initialized upon module load.
 template <typename T>
@@ -116,32 +155,26 @@
 
 void GlobalI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<InlineConstGlobalOpInitializer<GlobalI32Op>,
-                 DropDefaultConstGlobalOpInitializer<GlobalI32Op>>(context);
+  results.insert<DropDefaultConstGlobalOpInitializer<GlobalI32Op>>(context);
 }
 
 void GlobalI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<InlineConstGlobalOpInitializer<GlobalI64Op>,
-                 DropDefaultConstGlobalOpInitializer<GlobalI64Op>>(context);
+  results.insert<DropDefaultConstGlobalOpInitializer<GlobalI64Op>>(context);
 }
 
 void GlobalF32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<InlineConstGlobalOpInitializer<GlobalF32Op>,
-                 DropDefaultConstGlobalOpInitializer<GlobalF32Op>>(context);
+  results.insert<DropDefaultConstGlobalOpInitializer<GlobalF32Op>>(context);
 }
 
 void GlobalF64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<InlineConstGlobalOpInitializer<GlobalF64Op>,
-                 DropDefaultConstGlobalOpInitializer<GlobalF64Op>>(context);
+  results.insert<DropDefaultConstGlobalOpInitializer<GlobalF64Op>>(context);
 }
 
 void GlobalRefOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
-                                              MLIRContext *context) {
-  results.insert<InlineConstGlobalOpInitializer<GlobalRefOp>>(context);
-}
+                                              MLIRContext *context) {}
 
 namespace {
 
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 92c5ae6..9dd7c86 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -281,6 +281,47 @@
   return success();
 }
 
+void InitializerOp::build(OpBuilder &builder, OperationState &result,
+                          ArrayRef<NamedAttribute> attrs) {
+  result.addAttribute(
+      "type", TypeAttr::get(FunctionType::get(builder.getContext(), {}, {})));
+  result.addRegion();
+  result.attributes.append(attrs.begin(), attrs.end());
+}
+
+static ParseResult parseInitializerOp(OpAsmParser &parser,
+                                      OperationState *result) {
+  result->addAttribute(
+      "type", TypeAttr::get(FunctionType::get(result->getContext(), {}, {})));
+  if (parser.parseOptionalAttrDictWithKeyword(result->attributes)) {
+    return failure();
+  }
+  auto &body = *result->addRegion();
+  if (failed(parser.parseRegion(body))) {
+    return failure();
+  }
+  return success();
+}
+
+static void printInitializerOp(OpAsmPrinter &p, InitializerOp &op) {
+  p << "vm.initializer";
+  p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{"type"});
+  p.printRegion(op.body());
+}
+
+Block *InitializerOp::addEntryBlock() {
+  assert(empty() && "function already has an entry block");
+  auto *entry = new Block();
+  push_back(entry);
+  return entry;
+}
+
+Block *InitializerOp::addBlock() {
+  assert(!empty() && "function should at least have an entry block");
+  push_back(new Block());
+  return &back();
+}
+
 //===----------------------------------------------------------------------===//
 // Globals
 //===----------------------------------------------------------------------===//
@@ -375,13 +416,13 @@
   auto *globalOp =
       op->getParentOfType<VM::ModuleOp>().lookupSymbol(globalAttr.getValue());
   if (!globalOp) {
-    return op->emitOpError() << "Undefined global: " << globalAttr;
+    return op->emitOpError() << "undefined global: " << globalAttr;
   }
   auto globalType = globalOp->getAttrOfType<TypeAttr>("type");
   auto loadType = op->getResult(0).getType();
   if (globalType.getValue() != loadType) {
     return op->emitOpError()
-           << "Global type mismatch; global " << globalAttr << " is "
+           << "global type mismatch; global " << globalAttr << " is "
            << globalType << " but load is " << loadType;
   }
   return success();
@@ -392,18 +433,21 @@
   auto *globalOp =
       op->getParentOfType<VM::ModuleOp>().lookupSymbol(globalAttr.getValue());
   if (!globalOp) {
-    return op->emitOpError() << "Undefined global: " << globalAttr;
+    return op->emitOpError() << "undefined global: " << globalAttr;
   }
   auto globalType = globalOp->getAttrOfType<TypeAttr>("type");
   auto storeType = op->getOperand(0).getType();
   if (globalType.getValue() != storeType) {
     return op->emitOpError()
-           << "Global type mismatch; global " << globalAttr << " is "
+           << "global type mismatch; global " << globalAttr << " is "
            << globalType << " but store is " << storeType;
   }
   if (!globalOp->getAttrOfType<UnitAttr>("is_mutable")) {
-    return op->emitOpError() << "Global " << globalAttr
-                             << " is not mutable and cannot be stored to";
+    // Allow stores to immutable globals in initializers.
+    if (!op->getParentOfType<IREE::VM::InitializerOp>()) {
+      return op->emitOpError() << "global " << globalAttr
+                               << " is not mutable and cannot be stored to";
+    }
   }
   return success();
 }
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 6abda0c..d445dbb 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -44,6 +44,7 @@
     custom<SymbolVisibility>($sym_visibility)
     $sym_name
     attr-dict-with-keyword
+    ``
     regions
   }];
 
@@ -241,6 +242,50 @@
   }];
 }
 
+def VM_InitializerOp : VM_Op<"initializer", [
+    IsolatedFromAbove,
+    HasParent<"IREE::VM::ModuleOp">,
+    NativeOpTrait<"FunctionLike">,
+    CallableOpInterface,
+  ]> {
+  let summary = [{global initialization function}];
+  let description = [{
+    A function that is called in definition order upon module initialization.
+    Must not load any globals that are defined or initialized after it in the
+    module.
+  }];
+
+  let arguments = (ins
+    TypeAttr:$type
+  );
+
+  let regions = (region AnyRegion:$body);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
+    )>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Add an entry block to an empty function and set up the block arguments
+    /// to match the signature of the function.
+    Block *addEntryBlock();
+    Block *addBlock();
+
+    unsigned getNumFuncArguments() { return 0; }
+    unsigned getNumFuncResults() { return 0; }
+
+    LogicalResult verifyType() { return success(); }
+
+    Region *getCallableRegion() { return &body(); }
+    ArrayRef<Type> getCallableResults() { return {}; }
+  }];
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Globals
 //===----------------------------------------------------------------------===//
@@ -259,7 +304,6 @@
     SymbolNameAttr:$sym_name,
     TypeAttr:$type,
     UnitAttr:$is_mutable,
-    OptionalAttr<FlatSymbolRefAttr>:$initializer,
     OptionalAttr<attr_type>:$initial_value,
     OptionalAttr<VM_Ordinal>:$ordinal
   );
@@ -269,25 +313,23 @@
     (`mutable` $is_mutable^)?
     $sym_name
     attr-dict
-    (`initializer` `(` $initializer^ `)`):(``)?
     custom<TypeOrAttr>($type, $initial_value)
   }];
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type,
-      "Optional<StringRef>":$initializer, "Optional<Attribute>":$initialValue,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    OpBuilder<(ins
+      "StringRef":$name, "bool":$isMutable, "Type":$type,
+      "Optional<Attribute>":$initialValue,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
+    ),
     [{
       $_state.addAttribute(SymbolTable::getSymbolAttrName(),
                           $_builder.getStringAttr(name));
       if (isMutable) {
         $_state.addAttribute("is_mutable", $_builder.getUnitAttr());
       }
-      if (initializer.hasValue()) {
-        $_state.addAttribute("initializer",
-                            $_builder.getSymbolRefAttr(initializer.getValue()));
-      } else if (initialValue.hasValue() &&
+      if (initialValue.hasValue() &&
                  (initialValue.getValue().isa<IntegerAttr>() ||
                   initialValue.getValue().isa<FloatAttr>())) {
         $_state.addAttribute("initial_value", initialValue.getValue());
@@ -295,25 +337,12 @@
       $_state.addAttribute("type", TypeAttr::get(type));
       $_state.attributes.append(attrs.begin(), attrs.end());
     }]>,
-    OpBuilder<(ins "StringRef":$name, "bool":$isMutable,
-      "IREE::VM::FuncOp":$initializer,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    OpBuilder<(ins
+      "StringRef":$name, "bool":$isMutable, "Type":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
+    ),
     [{
-      build($_builder, $_state, name, isMutable,
-            initializer.getType().getResult(0), initializer.getName(),
-            llvm::None, attrs);
-    }]>,
-    OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type,
-      "Attribute":$initialValue, CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      build($_builder, $_state, name, isMutable, type, llvm::None, initialValue,
-            attrs);
-    }]>,
-    OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      build($_builder, $_state, name, isMutable, type, llvm::None, llvm::None,
-            attrs);
+      build($_builder, $_state, name, isMutable, type, llvm::None, attrs);
     }]>,
   ];
 
@@ -322,9 +351,8 @@
     Type getStorageType() { return type(); }
     bool isMutable() { return is_mutable(); }
     void makeMutable() { (*this)->setAttr("is_mutable", UnitAttr::get(getContext())); }
-    Optional<StringRef> getInitializerAttr() { return initializer(); }
-    void clearInitializer() { (*this)->removeAttr("initializer"); }
     Optional<Attribute> getInitialValueAttr() { return initial_valueAttr(); }
+    void setInitialValue(Attribute value) { (*this)->setAttr("initial_value", (value)); }
     void clearInitialValue() { (*this)->removeAttr("initial_value"); }
     Optional<IntegerAttr> getOrdinalAttr() { return ordinalAttr(); }
   }];
@@ -3485,7 +3513,6 @@
 
 def VM_ReturnOp : VM_Op<"return", [
     DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
-    HasParent<"IREE::VM::FuncOp">,
     Terminator,
   ]> {
   let summary = "return operation";
@@ -3493,7 +3520,7 @@
     Represents a return operation within a function.
 
     ```
-    vm.func @foo(%0, %1) : (i32, f8) {
+    vm.func @foo(%0: i32, %1: f8) -> (i32, f8) {
       vm.return %0, %1 : i32, f8
     }
     ```
diff --git a/iree/compiler/Dialect/VM/IR/test/BUILD b/iree/compiler/Dialect/VM/IR/test/BUILD
index 7890721..53303b2 100644
--- a/iree/compiler/Dialect/VM/IR/test/BUILD
+++ b/iree/compiler/Dialect/VM/IR/test/BUILD
@@ -36,6 +36,7 @@
             "list_op_verification.mlir",
             "list_ops.mlir",
             "shift_ops.mlir",
+            "structural_folding.mlir",
             "structural_ops.mlir",
         ],
         include = ["*.mlir"],
diff --git a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt
index 2604211..4a277a3 100644
--- a/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/IR/test/CMakeLists.txt
@@ -33,6 +33,7 @@
     "list_op_verification.mlir"
     "list_ops.mlir"
     "shift_ops.mlir"
+    "structural_folding.mlir"
     "structural_ops.mlir"
   DATA
     iree::tools::IreeFileCheck
diff --git a/iree/compiler/Dialect/VM/IR/test/global_folding.mlir b/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
index 3c2dd2e..8ae7a22 100644
--- a/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/global_folding.mlir
@@ -5,10 +5,11 @@
 // CHECK-LABEL: @global_i32_folds
 vm.module @global_i32_folds {
   // CHECK: vm.global.i32 public mutable @g0 = 123 : i32
-  vm.global.i32 mutable @g0 initializer(@g0init) : i32
-  vm.func @g0init() -> i32 {
+  vm.global.i32 mutable @g0 : i32
+  vm.initializer {
     %c123 = vm.const.i32 123 : i32
-    vm.return %c123 : i32
+    vm.global.store.i32 %c123, @g0 : i32
+    vm.return
   }
 
   // CHECK: vm.global.i32 public mutable @g1 : i32
@@ -17,10 +18,11 @@
   vm.global.i32 @g2 = 0 : i32
 
   // CHECK: vm.global.i32 public mutable @g3 : i32
-  vm.global.i32 mutable @g3 initializer(@g3init) : i32
-  vm.func @g3init() -> i32 {
+  vm.global.i32 mutable @g3 : i32
+  vm.initializer {
     %c0 = vm.const.i32 0 : i32
-    vm.return %c0 : i32
+    vm.global.store.i32 %c0, @g3 : i32
+    vm.return
   }
 }
 
@@ -29,10 +31,11 @@
 // CHECK-LABEL: @global_ref_folds_null
 vm.module @global_ref_folds_null {
   // CHECK: vm.global.ref public mutable @g0 : !vm.ref<?>
-  vm.global.ref mutable @g0 initializer(@g0init) : !vm.ref<?>
-  vm.func @g0init() -> !vm.ref<?> {
+  vm.global.ref mutable @g0 : !vm.ref<?>
+  vm.initializer {
     %null = vm.const.ref.zero : !vm.ref<?>
-    vm.return %null : !vm.ref<?>
+    vm.global.store.ref %null, @g0 : !vm.ref<?>
+    vm.return
   }
 }
 
diff --git a/iree/compiler/Dialect/VM/IR/test/structural_folding.mlir b/iree/compiler/Dialect/VM/IR/test/structural_folding.mlir
new file mode 100644
index 0000000..1ca2d68
--- /dev/null
+++ b/iree/compiler/Dialect/VM/IR/test/structural_folding.mlir
@@ -0,0 +1,9 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(canonicalize)' %s | IreeFileCheck %s
+
+// CHECK-LABEL: @empty_initializer
+vm.module @empty_initializer {
+  // CHECK-NOT: vm.initializer
+  vm.initializer {
+    vm.return
+  }
+}
diff --git a/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir b/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir
index 72242f5..234fb1e 100644
--- a/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/structural_ops.mlir
@@ -67,3 +67,34 @@
   // CHECK-NEXT: vm.import @my.fn_varargs(%foo : vector<3xi32> ..., %bar : tuple<i32, i32> ...) -> i32
   vm.import @my.fn_varargs(%foo : vector<3xi32> ..., %bar : tuple<i32, i32> ...) -> i32
 }
+
+// -----
+
+// CHECK-LABEL: @initializers
+vm.module @initializers {
+  // CHECK-NEXT: vm.initializer {
+  // CHECK-NEXT:   vm.return
+  // CHECK-NEXT: }
+  vm.initializer {
+    vm.return
+  }
+
+  // CHECK-NEXT: vm.initializer attributes {foo} {
+  // CHECK-NEXT:   vm.return
+  // CHECK-NEXT: }
+  vm.initializer attributes {foo} {
+    vm.return
+  }
+
+  // CHECK-NEXT: vm.initializer {
+  vm.initializer {
+    // CHECK-NEXT: %zero = vm.const.i32 0 : i32
+    %zero = vm.const.i32 0 : i32
+    // CHECK-NEXT:   vm.br ^bb1(%zero : i32)
+    vm.br ^bb1(%zero: i32)
+    // CHECK-NEXT: ^bb1(%0: i32):
+  ^bb1(%0: i32):
+    // CHECK-NEXT:   vm.return
+    vm.return
+  }
+}
diff --git a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index 0757d9b..c513c8a 100644
--- a/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/Utils.h"
 
 namespace mlir {
@@ -31,7 +32,6 @@
 // point in the lowering though we cannot know that so we rely on dialects
 // providing their own initialization functions for those cases.
 //
-// TODO(benvanik): add initializer functions to make dialect init possible.
 // TODO(benvanik): combine i32 initializers to store more efficiently.
 class GlobalInitializationPass
     : public PassWrapper<GlobalInitializationPass, OperationPass<ModuleOp>> {
@@ -53,6 +53,7 @@
         moduleBuilder.create<FuncOp>(moduleBuilder.getUnknownLoc(), "__init",
                                      moduleBuilder.getFunctionType({}, {}));
     OpBuilder initBuilder = OpBuilder::atBlockEnd(initFuncOp.addEntryBlock());
+
     auto deinitFuncOp =
         moduleBuilder.create<FuncOp>(moduleBuilder.getUnknownLoc(), "__deinit",
                                      moduleBuilder.getFunctionType({}, {}));
@@ -64,6 +65,8 @@
     // module op order). If we ever want to make this more deterministic we
     // could gather the ops, sort them (by some rule), and then build the
     // initialization function.
+    InlinerInterface inlinerInterface(&getContext());
+    SmallVector<Operation *> deadOps;
     for (auto &op : getOperation().getBlock().getOperations()) {
       if (auto globalOp = dyn_cast<GlobalRefOp>(op)) {
         if (failed(appendRefInitialization(globalOp, initBuilder))) {
@@ -75,8 +78,21 @@
           globalOp.emitOpError() << "unable to be initialized";
           return signalPassFailure();
         }
+      } else if (auto initializerOp = dyn_cast<InitializerOp>(op)) {
+        if (failed(appendInitializer(initializerOp, inlinerInterface,
+                                     initBuilder))) {
+          initializerOp.emitOpError() << "unable to be initialized";
+          return signalPassFailure();
+        }
+        deadOps.push_back(initializerOp);
       }
     }
+    for (auto deadOp : deadOps) {
+      deadOp->erase();
+    }
+
+    // Correct mutability of all globals.
+    fixupGlobalMutability(getOperation());
 
     initBuilder.create<ReturnOp>(initBuilder.getUnknownLoc());
     deinitBuilder.create<ReturnOp>(deinitBuilder.getUnknownLoc());
@@ -112,12 +128,6 @@
                << "unable to create initializer constant for global";
       }
       globalOp.clearInitialValue();
-    } else if (globalOp.getInitializerAttr().hasValue()) {
-      auto callOp = builder.create<CallOp>(
-          globalOp.getLoc(), globalOp.getInitializerAttr().getValue(),
-          ArrayRef<Type>{globalOp.getStorageType()}, ArrayRef<Value>{});
-      value = callOp.getResult(0);
-      globalOp.clearInitializer();
     }
     if (!value) {
       // Globals are zero-initialized by default so we can just strip the
@@ -195,17 +205,90 @@
 
   LogicalResult appendRefInitialization(GlobalRefOp globalOp,
                                         OpBuilder &builder) {
-    if (globalOp.initializer().hasValue()) {
-      auto callOp = builder.create<CallOp>(
-          globalOp.getLoc(), globalOp.initializerAttr(),
-          ArrayRef<Type>{globalOp.type()}, ArrayRef<Value>{});
-      builder.create<GlobalStoreRefOp>(globalOp.getLoc(), callOp.getResult(0),
-                                       globalOp.sym_name());
-      globalOp.clearInitializer();
-      globalOp.makeMutable();
-    }
+    // NOTE: nothing yet, though if we had attribute initialization we'd do it
+    // here (for example, #vm.magic.initial.ref<foo>).
     return success();
   }
+
+  LogicalResult appendInitializer(InitializerOp initializerOp,
+                                  InlinerInterface &inlinerInterface,
+                                  OpBuilder &builder) {
+    // mlir::inlineRegion takes the op to inline _after_, which as we are
+    // building things doesn't exist yet. To work around this we create a dummy
+    // op, inline after it, and then delete it.
+    auto dummyOp =
+        builder.create<IREE::VM::ConstI32ZeroOp>(builder.getUnknownLoc());
+    auto result = mlir::inlineRegion(
+        inlinerInterface, &initializerOp.body(), dummyOp,
+        /*inlinedOperands=*/ValueRange{},
+        /*resultsToReplace=*/ValueRange{}, /*inlineLoc=*/llvm::None,
+        /*shouldCloneInlinedRegion=*/false);
+    builder.setInsertionPointToEnd(dummyOp->getBlock());
+    dummyOp.erase();
+    return result;
+  }
+
+  void fixupGlobalMutability(Operation *moduleOp) {
+    SymbolTable symbolTable(moduleOp);
+    SmallVector<Operation *> deadOps;
+    for (auto &op : moduleOp->getRegion(0).front()) {
+      auto globalOp = dyn_cast<IREE::VM::VMGlobalOp>(op);
+      if (!globalOp) continue;
+      if (!cast<SymbolOpInterface>(op).isPrivate()) {
+        // May be used outside the module; treat as used and mutable.
+        globalOp.makeMutable();
+        continue;
+      }
+      auto uses = symbolTable.getSymbolUses(globalOp, moduleOp);
+      if (!uses.hasValue()) {
+        // No uses - erase the global entirely.
+        deadOps.push_back(globalOp);
+        continue;
+      }
+      bool isIndirect = false;
+      bool isLoaded = false;
+      bool isStored = false;
+      for (auto use : uses.getValue()) {
+        if (isa<IREE::VM::GlobalAddressOp>(use.getUser())) {
+          // Can't analyze indirect variables; assume mutated.
+          isLoaded = true;
+          isStored = true;
+          isIndirect = true;
+          break;
+        } else if (isGlobalLoadOp(use.getUser())) {
+          isLoaded = true;
+        } else if (isGlobalStoreOp(use.getUser())) {
+          isStored = true;
+        }
+      }
+      // NOTE: we could erase globals never loaded if we know that computing
+      // their value has no side effects.
+      if (isStored) {
+        globalOp.makeMutable();
+      }
+    }
+    for (auto *deadOp : deadOps) {
+      deadOp->erase();
+    }
+  }
+
+  bool isGlobalLoadOp(Operation *op) const {
+    // TODO(benvanik): trait/interface to make this more generic?
+    return isa<IREE::VM::GlobalLoadI32Op>(op) ||
+           isa<IREE::VM::GlobalLoadI64Op>(op) ||
+           isa<IREE::VM::GlobalLoadF32Op>(op) ||
+           isa<IREE::VM::GlobalLoadF64Op>(op) ||
+           isa<IREE::VM::GlobalLoadRefOp>(op);
+  }
+
+  bool isGlobalStoreOp(Operation *op) const {
+    // TODO(benvanik): trait/interface to make this more generic?
+    return isa<IREE::VM::GlobalStoreI32Op>(op) ||
+           isa<IREE::VM::GlobalStoreI64Op>(op) ||
+           isa<IREE::VM::GlobalStoreF32Op>(op) ||
+           isa<IREE::VM::GlobalStoreF64Op>(op) ||
+           isa<IREE::VM::GlobalStoreRefOp>(op);
+  }
 };
 
 std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
diff --git a/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir b/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
index b652d78..f691847 100644
--- a/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
+++ b/iree/compiler/Dialect/VM/Transforms/test/global_initialization.mlir
@@ -9,22 +9,16 @@
 
 // CHECK-LABEL: @initI32
 vm.module @initI32 {
-  // CHECK: vm.global.i32 public mutable @g0 : i32
-  vm.global.i32 mutable @g0 initializer(@g0init) : i32
-  vm.func @g0init() -> i32 {
-    %c123 = vm.const.i32 123 : i32
-    vm.return %c123 : i32
-  }
+  // CHECK: vm.global.i32 private @g0
+  vm.global.i32 private @g0 : i32 = 0 : i32
 
-  // CHECK: vm.global.i32 public mutable @g1 : i32
-  vm.global.i32 mutable @g1 = 123 : i32
+  // CHECK: vm.global.i32 private mutable @g1 : i32
+  vm.global.i32 private mutable @g1 = 123 : i32
 
-  // CHECK: vm.global.i32 public mutable @g2 : i32
-  vm.global.i32 @g2 = 123 : i32
+  // CHECK: vm.global.i32 private mutable @g2 : i32
+  vm.global.i32 private @g2 = 123 : i32
 
   // CHECK: vm.func private @__init() {
-  // CHECK-NEXT:   %0 = vm.call @g0init()
-  // CHECK-NEXT:   vm.global.store.i32 %0, @g0
   // CHECK-NEXT:   %c123 = vm.const.i32 123 : i32
   // CHECK-NEXT:   vm.global.store.i32 %c123, @g1
   // CHECK-NEXT:   %c123_0 = vm.const.i32 123 : i32
@@ -37,22 +31,70 @@
 
 // CHECK-LABEL: @initRef
 vm.module @initRef {
-  // CHECK: vm.global.ref public mutable @g0 : !vm.ref<?>
-  vm.global.ref mutable @g0 initializer(@g0init) : !vm.ref<?>
-  vm.func @g0init() -> !vm.ref<?> {
-    %null = vm.const.ref.zero : !vm.ref<?>
-    vm.return %null : !vm.ref<?>
+  // CHECK: vm.global.ref private mutable @g0 : !vm.ref<?>
+  vm.global.ref private mutable @g0 : !vm.ref<?>
+
+  // CHECK: vm.global.ref private mutable @g1 : !vm.ref<?>
+  vm.global.ref private mutable @g1 : !vm.ref<?>
+
+  // CHECK: vm.global.ref private @g2 : !vm.ref<?>
+  vm.global.ref private @g2 : !vm.ref<?>
+
+  // CHECK-NOT: vm.func private @__init()
+}
+
+// -----
+
+// CHECK-LABEL: @initializers
+vm.module @initializers {
+  // CHECK: vm.global.i32 private mutable @g0 : i32
+  vm.global.i32 private @g0 : i32
+  // CHECK-NOT: vm.initializer
+  vm.initializer {
+    %c123 = vm.const.i32 123 : i32
+    vm.global.store.i32 %c123, @g0 : i32
+    vm.return
   }
 
-  // CHECK: vm.global.ref public mutable @g1 : !vm.ref<?>
-  vm.global.ref mutable @g1 : !vm.ref<?>
+  // CHECK: vm.global.ref private mutable @g1 : !vm.ref<?>
+  vm.global.ref private mutable @g1 : !vm.ref<?>
+  // CHECK-NOT: vm.initializer
+  vm.initializer {
+    %null = vm.const.ref.zero : !vm.ref<?>
+    vm.global.store.ref %null, @g1 : !vm.ref<?>
+    vm.return
+  }
 
-  // CHECK: vm.global.ref public @g2 : !vm.ref<?>
-  vm.global.ref @g2 : !vm.ref<?>
+  // CHECK: vm.global.ref private mutable @g2 : !vm.ref<?>
+  vm.global.ref private mutable @g2 : !vm.ref<?>
+  // CHECK-NOT: vm.initializer
+  vm.initializer {
+    %g1 = vm.global.load.ref @g1 : !vm.ref<?>
+    vm.global.store.ref %g1, @g2 : !vm.ref<?>
+    vm.return
+  }
 
-  // CHECK: vm.func private @__init() {
-  // CHECK-NEXT:   %ref = vm.call @g0init()
-  // CHECK-NEXT:   vm.global.store.ref %ref, @g0
+  //      CHECK: vm.func private @__init() {
+  // CHECK-NEXT:   %c123 = vm.const.i32 123 : i32
+  // CHECK-NEXT:   vm.global.store.i32 %c123, @g0 : i32
+  // CHECK-NEXT:   %null = vm.const.ref.zero : !vm.ref<?>
+  // CHECK-NEXT:   vm.global.store.ref %null, @g1 : !vm.ref<?>
+  // CHECK-NEXT:   %g1 = vm.global.load.ref @g1 : !vm.ref<?>
+  // CHECK-NEXT:   vm.global.store.ref %g1, @g2 : !vm.ref<?>
   // CHECK-NEXT:   vm.return
   // CHECK-NEXT: }
 }
+
+// -----
+
+// CHECK-LABEL: @unused_globals
+vm.module @unused_globals {
+  // CHECK: vm.global.i32 private mutable @used
+  vm.global.i32 private @used : i32 = 1 : i32
+  // CHECK-NOT: vm.global.i32 private @unused
+  vm.global.i32 private @unused : i32 = 2 : i32
+  vm.func @foo() {
+    %0 = vm.global.load.i32 @used : i32
+    vm.return
+  }
+}
diff --git a/iree/compiler/InputConversion/Common/BUILD b/iree/compiler/InputConversion/Common/BUILD
index 7e22c00..efce5e6 100644
--- a/iree/compiler/InputConversion/Common/BUILD
+++ b/iree/compiler/InputConversion/Common/BUILD
@@ -44,6 +44,7 @@
 cc_library(
     name = "Common",
     srcs = [
+        "IREEImportPublic.cpp",
         "Passes.cpp",
         "TopLevelSCFToCFG.cpp",
     ],
@@ -53,6 +54,11 @@
     deps = [
         ":PassHeaders",
         ":PassesIncGen",
+        "//iree/compiler/Dialect/Flow/IR",
+        "//iree/compiler/Dialect/HAL/IR",
+        "//iree/compiler/Dialect/Util/IR",
+        "//llvm-external-projects/iree-dialects:IREEDialect",
+        "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
diff --git a/iree/compiler/InputConversion/Common/CMakeLists.txt b/iree/compiler/InputConversion/Common/CMakeLists.txt
index 7c2e3a8..c05b2c4 100644
--- a/iree/compiler/InputConversion/Common/CMakeLists.txt
+++ b/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -39,17 +39,23 @@
   HDRS
     "Passes.h"
   SRCS
+    "IREEImportPublic.cpp"
     "Passes.cpp"
     "TopLevelSCFToCFG.cpp"
   DEPS
     ::PassHeaders
     ::PassesIncGen
+    IREEDialectsIREEDialect
+    MLIRIR
     MLIRLinalg
     MLIRPass
     MLIRSCF
     MLIRSCFToStandard
     MLIRStandard
     MLIRTransforms
+    iree::compiler::Dialect::Flow::IR
+    iree::compiler::Dialect::HAL::IR
+    iree::compiler::Dialect::Util::IR
   PUBLIC
 )
 
diff --git a/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
new file mode 100644
index 0000000..7e1df67
--- /dev/null
+++ b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -0,0 +1,308 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREE/IREEDialect.h"
+#include "iree-dialects/Dialect/IREE/IREEOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "iree/compiler/InputConversion/Common/PassDetail.h"
+#include "iree/compiler/InputConversion/Common/Passes.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace IREEPublic = mlir::iree;
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Allowlist of function attributes to retain when importing funcs.
+constexpr const char *kRetainedAttributes[] = {
+    "iree.reflection",
+    "sym_visibility",
+    "noinline",
+};
+
+struct IREEImportPublicPass
+    : public IREEImportPublicBase<IREEImportPublicPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::iree::IREEDialect, IREE::Flow::FlowDialect,
+                    IREE::HAL::HALDialect, IREE::Util::UtilDialect>();
+  }
+  void runOnOperation() override;
+};
+
+class IREETypeConverter : public TypeConverter {
+ public:
+  IREETypeConverter();
+};
+
+// Generic 1:1 conversion pattern which effectively just renames an op.
+// It does not support regions or ops with successors.
+class OneToOneConverionPattern : public ConversionPattern {
+ public:
+  OneToOneConverionPattern(TypeConverter &converter, StringRef srcName,
+                           StringRef targetName, MLIRContext *context,
+                           PatternBenefit benefit)
+      : ConversionPattern(converter, srcName, benefit, context),
+        targetName(targetName) {}
+  LogicalResult matchAndRewrite(
+      Operation *srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> resultTypes;
+    if (failed(typeConverter->convertTypes(srcOp->getResultTypes(),
+                                           resultTypes))) {
+      return srcOp->emitError()
+             << "could not convert result types to IREE internal types";
+    }
+
+    OperationState state(srcOp->getLoc(), targetName, operands, resultTypes,
+                         srcOp->getAttrs());
+    Operation *targetOp = rewriter.createOperation(state);
+    rewriter.replaceOp(srcOp, targetOp->getResults());
+    return success();
+  }
+
+ private:
+  StringRef targetName;
+};
+
+class BufferViewToTensorPattern
+    : public OpConversionPattern<IREEPublic::BufferViewToTensorOp> {
+  using OpConversionPattern<
+      IREEPublic::BufferViewToTensorOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREEPublic::BufferViewToTensorOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    IREEPublic::BufferViewToTensorOpAdaptor adaptor(operands);
+    Type resultType = typeConverter->convertType(srcOp.target().getType());
+    if (!resultType) return failure();
+    rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>(
+        srcOp, resultType, adaptor.source(), adaptor.target_dims());
+    return success();
+  }
+};
+
+class TensorToBufferViewPattern
+    : public OpConversionPattern<IREEPublic::TensorToBufferViewOp> {
+  using OpConversionPattern<
+      IREEPublic::TensorToBufferViewOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREEPublic::TensorToBufferViewOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    IREEPublic::TensorToBufferViewOpAdaptor adaptor(operands);
+    Type resultType = typeConverter->convertType(srcOp.target().getType());
+    if (!resultType) return failure();
+    rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>(
+        srcOp, resultType, adaptor.source(), adaptor.source_dims());
+    return success();
+  }
+};
+
+class BuiltinFuncOpPattern : public OpConversionPattern<FuncOp> {
+  using OpConversionPattern<FuncOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      FuncOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    FunctionType srcFuncType = srcOp.getType();
+    TypeConverter::SignatureConversion signatureConversion(
+        srcOp.getNumArguments());
+
+    // Convert function arguments.
+    for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) {
+      if (failed(getTypeConverter()->convertSignatureArg(
+              i, srcFuncType.getInput(i), signatureConversion))) {
+        return rewriter.notifyMatchFailure(srcOp, "argument failed to convert");
+      }
+    }
+
+    // Convert function results.
+    SmallVector<Type, 1> convertedResultTypes;
+    if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(),
+                                                convertedResultTypes))) {
+      return rewriter.notifyMatchFailure(srcOp, "results failed to convert");
+    }
+
+    // Create new function with converted argument and result types.
+    // Note that attributes are dropped. Consider preserving some if needed.
+    auto newFuncType = mlir::FunctionType::get(
+        srcOp.getContext(), signatureConversion.getConvertedTypes(),
+        convertedResultTypes);
+    auto newFuncOp =
+        rewriter.create<FuncOp>(srcOp.getLoc(), srcOp.getName(), newFuncType);
+    rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+
+    // Retain function attributes in the allowlist.
+    auto retainedAttributes = ArrayRef<const char *>(
+        kRetainedAttributes,
+        sizeof(kRetainedAttributes) / sizeof(kRetainedAttributes[0]));
+    for (auto retainAttrName : retainedAttributes) {
+      StringRef attrName(retainAttrName);
+      Attribute attr = srcOp->getAttr(attrName);
+      if (attr) {
+        newFuncOp->setAttr(attrName, attr);
+      }
+    }
+
+    // Tell the rewriter to convert the region signature.
+    TypeConverter &typeConverter = *getTypeConverter();
+    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                           &signatureConversion))) {
+      return failure();
+    }
+
+    rewriter.replaceOp(srcOp, llvm::None);
+    return success();
+  }
+};
+
+// Matches any op and generically converts types. Matches with benefit 0.
+class GenericTypeConvert : public ConversionPattern {
+ public:
+  GenericTypeConvert(TypeConverter &converter, MLIRContext *context,
+                     PatternBenefit benefit)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), benefit, context) {}
+  LogicalResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<NamedAttribute, 4> newAttr;
+    llvm::SmallVector<Type, 4> newResults;
+    (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, newAttr, op->getSuccessors());
+    for (Region &r : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      (void)getTypeConverter()->convertSignatureArgs(
+          newRegion->getArgumentTypes(), result);
+      rewriter.applySignatureConversion(newRegion, result);
+    }
+    Operation *newOp = rewriter.createOperation(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+}  // namespace
+
+IREETypeConverter::IREETypeConverter() {
+  addConversion([](Type t) { return t; });
+  addConversion([=](IREEPublic::BufferViewType t) {
+    return IREE::HAL::BufferViewType::get(t.getContext());
+  });
+  addConversion([=](IREEPublic::ListType t) -> IREE::Util::ListType {
+    auto subType = convertType(t.getElementType());
+    if (!subType) return nullptr;
+    return IREE::Util::ListType::get(subType);
+  });
+  addConversion([=](IREEPublic::PtrType t) -> IREE::Util::PtrType {
+    auto subType = convertType(t.getTargetType());
+    if (!subType) return nullptr;
+    return IREE::Util::PtrType::get(subType);
+  });
+  addConversion([](IREEPublic::VariantType t) {
+    return IREE::Util::VariantType::get(t.getContext());
+  });
+}
+
+void IREEImportPublicPass::runOnOperation() {
+  auto &context = getContext();
+  RewritePatternSet patterns(&getContext());
+  ConversionTarget target(getContext());
+  target.addLegalDialect<IREE::Flow::FlowDialect>();
+  target.addLegalDialect<IREE::HAL::HALDialect>();
+  target.addLegalDialect<IREE::Util::UtilDialect>();
+  target.addIllegalDialect<IREEPublic::IREEDialect>();
+
+  auto ireeDialect = context.getOrLoadDialect<IREEPublic::IREEDialect>();
+  auto isIllegalType = [&](Type t) {
+    return t.getDialect().getTypeID() == ireeDialect->getTypeID();
+  };
+
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
+    for (Type type : funcOp.getType().getInputs()) {
+      if (isIllegalType(type)) return false;
+    }
+    for (Type type : funcOp.getType().getResults()) {
+      if (isIllegalType(type)) return false;
+    }
+    return true;
+  });
+
+  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+    for (Type type : op->getResultTypes()) {
+      if (isIllegalType(type)) return false;
+    }
+    for (Type type : op->getOperandTypes()) {
+      if (isIllegalType(type)) return false;
+    }
+    return true;
+  });
+
+  IREETypeConverter typeConverter;
+  PatternBenefit specific_benefit = 100;
+  patterns.insert<GenericTypeConvert>(typeConverter, &getContext(), 0);
+  patterns.insert<BuiltinFuncOpPattern>(typeConverter, &getContext(),
+                                        specific_benefit);
+  patterns.insert<BufferViewToTensorPattern>(typeConverter, &getContext(),
+                                             specific_benefit);
+  patterns.insert<TensorToBufferViewPattern>(typeConverter, &getContext(),
+                                             specific_benefit);
+
+#define ONETOONE(SrcOpTy, TargetOpTy)             \
+  patterns.insert<OneToOneConverionPattern>(      \
+      typeConverter, SrcOpTy::getOperationName(), \
+      TargetOpTy::getOperationName(), &getContext(), specific_benefit)
+
+  ONETOONE(IREEPublic::BufferViewRankOp, IREE::HAL::BufferViewRankOp);
+  ONETOONE(IREEPublic::BufferViewDimOp, IREE::HAL::BufferViewDimOp);
+  ONETOONE(IREEPublic::ListCreateOp, IREE::Util::ListCreateOp);
+  ONETOONE(IREEPublic::ListSizeOp, IREE::Util::ListSizeOp);
+  ONETOONE(IREEPublic::ListResizeOp, IREE::Util::ListResizeOp);
+  ONETOONE(IREEPublic::ListGetOp, IREE::Util::ListGetOp);
+  ONETOONE(IREEPublic::ListSetOp, IREE::Util::ListSetOp);
+  ONETOONE(IREEPublic::NullOp, IREE::Util::NullOp);
+  ONETOONE(IREEPublic::TensorCloneOp, IREE::Flow::TensorCloneOp);
+  ONETOONE(IREEPublic::TensorLoadOp, IREE::Flow::TensorLoadOp);
+  ONETOONE(IREEPublic::TensorReshapeOp, IREE::Flow::TensorReshapeOp);
+  ONETOONE(IREEPublic::TensorSliceOp, IREE::Flow::TensorSliceOp);
+  ONETOONE(IREEPublic::TensorSplatOp, IREE::Flow::TensorSplatOp);
+  ONETOONE(IREEPublic::TensorStoreOp, IREE::Flow::TensorStoreOp);
+  ONETOONE(IREEPublic::TensorUpdateOp, IREE::Flow::TensorUpdateOp);
+  ONETOONE(IREEPublic::TensorTraceOp, IREE::Flow::TensorTraceOp);
+  ONETOONE(IREEPublic::GlobalOp, IREE::Util::GlobalOp);
+  ONETOONE(IREEPublic::GlobalAddressOp, IREE::Util::GlobalAddressOp);
+  ONETOONE(IREEPublic::GlobalLoadOp, IREE::Util::GlobalLoadOp);
+  ONETOONE(IREEPublic::GlobalLoadIndirectOp, IREE::Util::GlobalLoadIndirectOp);
+  ONETOONE(IREEPublic::GlobalStoreOp, IREE::Util::GlobalStoreOp);
+  ONETOONE(IREEPublic::GlobalStoreIndirectOp,
+           IREE::Util::GlobalStoreIndirectOp);
+
+  if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass() {
+  return std::make_unique<IREEImportPublicPass>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/InputConversion/Common/Passes.cpp b/iree/compiler/InputConversion/Common/Passes.cpp
index 84332fd..7c21a55 100644
--- a/iree/compiler/InputConversion/Common/Passes.cpp
+++ b/iree/compiler/InputConversion/Common/Passes.cpp
@@ -6,6 +6,11 @@
 
 #include "iree/compiler/InputConversion/Common/Passes.h"
 
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+
 namespace mlir {
 namespace iree_compiler {
 
@@ -14,9 +19,20 @@
 #include "iree/compiler/InputConversion/Common/Passes.h.inc"  // IWYU pragma: export
 }  // namespace
 
+void buildCommonInputConversionPassPipeline(OpPassManager &passManager) {
+  passManager.addPass(createIREEImportPublicPass());
+}
+
 void registerCommonInputConversionPasses() {
   // Generated passes.
   registerPasses();
+
+  PassPipelineRegistration<> mhlo(
+      "iree-common-input-transformation-pipeline",
+      "Runs the common input transformation pipeline",
+      [](OpPassManager &passManager) {
+        buildCommonInputConversionPassPipeline(passManager);
+      });
 }
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/InputConversion/Common/Passes.h b/iree/compiler/InputConversion/Common/Passes.h
index f67587a..d490d34 100644
--- a/iree/compiler/InputConversion/Common/Passes.h
+++ b/iree/compiler/InputConversion/Common/Passes.h
@@ -14,10 +14,19 @@
 namespace iree_compiler {
 
 //===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
+// Performs common input legalization after specific input dialect conversions
+// have taken place.
+void buildCommonInputConversionPassPipeline(OpPassManager &passManager);
+
+//===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
 
 std::unique_ptr<OperationPass<FuncOp>> createTopLevelSCFToCFGPass();
+std::unique_ptr<OperationPass<ModuleOp>> createIREEImportPublicPass();
 
 //===----------------------------------------------------------------------===//
 // Register all Passes
diff --git a/iree/compiler/InputConversion/Common/Passes.td b/iree/compiler/InputConversion/Common/Passes.td
index 7c225c6..be99a3f 100644
--- a/iree/compiler/InputConversion/Common/Passes.td
+++ b/iree/compiler/InputConversion/Common/Passes.td
@@ -15,4 +15,10 @@
   let constructor = "mlir::iree_compiler::createTopLevelSCFToCFGPass()";
 }
 
+def IREEImportPublic :
+    Pass<"iree-import-public", "ModuleOp"> {
+  let summary = "Imports IREE public dialect to internal implementation.";
+  let constructor = "mlir::iree_compiler::createIREEImportPublicPass()";
+}
+
 #endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
diff --git a/iree/compiler/InputConversion/Common/test/BUILD b/iree/compiler/InputConversion/Common/test/BUILD
index d9e87bd..22eef63 100644
--- a/iree/compiler/InputConversion/Common/test/BUILD
+++ b/iree/compiler/InputConversion/Common/test/BUILD
@@ -19,6 +19,7 @@
     name = "lit",
     srcs = enforce_glob(
         [
+            "iree_import_public.mlir",
             "top_level_scf_to_cfg.mlir",
         ],
         include = ["*.mlir"],
diff --git a/iree/compiler/InputConversion/Common/test/CMakeLists.txt b/iree/compiler/InputConversion/Common/test/CMakeLists.txt
index ab43294..da81176 100644
--- a/iree/compiler/InputConversion/Common/test/CMakeLists.txt
+++ b/iree/compiler/InputConversion/Common/test/CMakeLists.txt
@@ -14,6 +14,7 @@
   NAME
     lit
   SRCS
+    "iree_import_public.mlir"
     "top_level_scf_to_cfg.mlir"
   DATA
     iree::tools::IreeFileCheck
diff --git a/iree/compiler/InputConversion/Common/test/iree_import_public.mlir b/iree/compiler/InputConversion/Common/test/iree_import_public.mlir
new file mode 100644
index 0000000..a69d497
--- /dev/null
+++ b/iree/compiler/InputConversion/Common/test/iree_import_public.mlir
@@ -0,0 +1,233 @@
+// RUN: iree-opt -split-input-file -iree-import-public %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @bv_func
+// CHECK-SAME: (%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view)
+// CHECK: return %arg0, %arg1 : !hal.buffer_view, !hal.buffer_view
+builtin.func @bv_func(%arg0 : !iree.buffer_view, %arg1 : !iree.buffer_view) -> (!iree.buffer_view, !iree.buffer_view) {
+  return %arg0, %arg1 : !iree.buffer_view, !iree.buffer_view
+}
+
+// -----
+// CHECK-LABEL: func @list_func
+// CHECK-SAME: (%arg0: !util.list<?>) -> !util.list<?>
+builtin.func @list_func(%arg0 : !iree.list<!iree.variant>) -> !iree.list<!iree.variant> {
+  return %arg0 : !iree.list<!iree.variant>
+}
+
+// -----
+// CHECK-LABEL: func @ptr_func
+// CHECK-SAME: (%arg0: !util.ptr<!hal.buffer_view>) -> !util.ptr<!hal.buffer_view>
+builtin.func @ptr_func(%arg0 : !iree.ptr<!iree.buffer_view>) -> !iree.ptr<!iree.buffer_view> {
+  return %arg0 : !iree.ptr<!iree.buffer_view>
+}
+
+// -----
+// CHECK-LABEL: func @null_op
+// CHECK: util.null : !util.variant
+builtin.func @null_op() -> !iree.variant {
+  %0 = iree.null : !iree.variant
+  return %0 : !iree.variant
+}
+
+// -----
+// CHECK-LABEL: func @tensor_to_buffer_view
+// CHECK: hal.tensor.cast %arg0 : tensor<?x?x3xf32>{%arg1, %arg2} -> !hal.buffer_view
+builtin.func @tensor_to_buffer_view(%arg0 : tensor<?x?x3xf32>, %arg1 : index, %arg2 : index) -> !iree.buffer_view {
+  %0 = iree.cast.tensor_to_buffer_view %arg0 : tensor<?x?x3xf32> {%arg1, %arg2} -> !iree.buffer_view
+  return %0 : !iree.buffer_view
+}
+
+// -----
+// CHECK-LABEL: func @buffer_view_to_tensor
+// CHECK: hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x?x3xf32>{%arg1, %arg2}
+builtin.func @buffer_view_to_tensor(%arg0 : !iree.buffer_view, %arg1 : index, %arg2 : index) -> tensor<?x?x3xf32> {
+  %0 = iree.cast.buffer_view_to_tensor %arg0 : !iree.buffer_view -> tensor<?x?x3xf32> {%arg1, %arg2}
+  return %0 : tensor<?x?x3xf32>
+}
+
+// -----
+// CHECK-LABEL: func @buffer_view_rank
+// CHECK: hal.buffer_view.rank<%arg0 : !hal.buffer_view> : index
+builtin.func @buffer_view_rank(%arg0 : !iree.buffer_view) -> index {
+  %0 = iree.buffer_view.rank %arg0 : index
+  return %0 : index
+}
+
+// -----
+// CHECK-LABEL: func @buffer_view_dim
+// CHECK: hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
+builtin.func @buffer_view_dim(%arg0 : !iree.buffer_view) -> index {
+  %0 = iree.buffer_view.dim %arg0, 0 : index
+  return %0: index
+}
+
+// -----
+// CHECK-LABEL: func @list_create
+// CHECK: util.list.create %arg0 : !util.list<?>
+builtin.func @list_create(%arg0 : index) -> !iree.list<!iree.variant> {
+  %0 = iree.list.create %arg0 : !iree.list<!iree.variant>
+  return %0 : !iree.list<!iree.variant>
+}
+
+// -----
+// CHECK-LABEL: func @list_size
+// CHECK: util.list.size %arg0 : !util.list<?>
+builtin.func @list_size(%arg0 : !iree.list<!iree.variant>) -> index {
+  %0 = iree.list.size %arg0 : !iree.list<!iree.variant>
+  return %0 : index
+}
+
+// -----
+// CHECK-LABEL: func @list_resize
+// CHECK: util.list.resize %arg0, %arg1 : !util.list<?>
+builtin.func @list_resize(%arg0 : !iree.list<!iree.variant>, %arg1 : index) {
+  iree.list.resize %arg0, %arg1 : !iree.list<!iree.variant>
+  return
+}
+
+// -----
+// CHECK-LABEL: func @list_get
+// CHECK: util.list.get %arg0[%arg1] : !util.list<?>
+builtin.func @list_get(%arg0 : !iree.list<!iree.variant>, %arg1 : index) -> !iree.variant {
+  %0 = iree.list.get %arg0[%arg1] : !iree.list<!iree.variant> -> !iree.variant
+  return %0 : !iree.variant
+}
+
+// -----
+// CHECK-LABEL: func @list_set
+// CHECK: util.list.set %arg0[%arg1], %arg2 : !util.list<?>
+builtin.func @list_set(%arg0 : !iree.list<!iree.variant>, %arg1 : index, %arg2 : !iree.variant) {
+  iree.list.set %arg0[%arg1], %arg2 : !iree.list<!iree.variant>, !iree.variant
+  return
+}
+
+// -----
+// CHECK-LABEL: func @tensor_reshape
+// CHECK: flow.tensor.reshape %arg0 : tensor<?x?xf32>{%arg1, %arg2} -> tensor<?x?xf32>{%arg2, %arg1}
+builtin.func @tensor_reshape(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
+  %0 = iree.tensor.reshape %arg0 : tensor<?x?xf32>{%arg1, %arg2} -> tensor<?x?xf32>{%arg2, %arg1}
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @tensor_load
+// CHECK: flow.tensor.load %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1}
+builtin.func @tensor_load(%arg0 : tensor<?x3xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
+  %0 = iree.tensor.load %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1}
+  return %0 : f32
+}
+
+// -----
+// CHECK-LABEL: func @tensor_store
+// CHECK: flow.tensor.store %arg4, %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1}
+builtin.func @tensor_store(%arg0 : tensor<?x3xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : f32) {
+  iree.tensor.store %arg4, %arg0[%arg2, %arg3] : tensor<?x3xf32>{%arg1}
+  return
+}
+
+// -----
+// CHECK-LABEL: func @tensor_splat
+// CHECK: flow.tensor.splat %arg0 : tensor<?x?xf32>{%arg1, %arg2}
+builtin.func @tensor_splat(%arg0 : f32, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
+  %0 = iree.tensor.splat %arg0 : tensor<?x?xf32>{%arg1, %arg2}
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @tensor_clone
+// CHECK: flow.tensor.clone %arg0 : tensor<?x?xf32>{%arg1, %arg2}
+builtin.func @tensor_clone(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
+  %0 = iree.tensor.clone %arg0 : tensor<?x?xf32>{%arg1, %arg2}
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @tensor_slice
+// CHECK: flow.tensor.slice %arg0[%arg1 for %arg2] : tensor<?xf32>{%arg3} -> tensor<?xf32>{%arg4}
+builtin.func @tensor_slice(%arg0 : tensor<?xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?xf32> {
+  %0 = iree.tensor.slice %arg0[%arg1 for %arg2] : tensor<?xf32>{%arg3} -> tensor<?xf32>{%arg4}
+  return %0 : tensor<?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @tensor_update
+// CHECK: flow.tensor.update %arg3, %arg0[%arg1] : tensor<?xf32>{%arg2} -> %arg0 as tensor<?xf32>{%arg4}
+builtin.func @tensor_update(%arg0 : tensor<?xf32>, %arg1 : index, %arg2 : index, %arg3 : tensor<?xf32>, %arg4 : index) -> tensor<?xf32> {
+  %0 = iree.tensor.update %arg3, %arg0[%arg1] : tensor<?xf32>{%arg2} -> tensor<?xf32>{%arg4}
+  return %0 : tensor<?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @tensor_trace
+// CHECK: flow.tensor.trace {key = "FOOBAR"} %arg0, %arg1 : tensor<5xf32>, tensor<3xf32>
+builtin.func @tensor_trace(%arg0 : tensor<5xf32>, %arg1 : tensor<3xf32>) {
+  iree.tensor.trace "FOOBAR" %arg0, %arg1 : tensor<5xf32>, tensor<3xf32>
+  return
+}
+
+// -----
+// CHECK-LABEL: module @globals
+builtin.module @globals {
+  // CHECK: util.global public mutable @global1 = 50 : i32
+  iree.global mutable @global1 = 50 : i32
+  // CHECK: util.global public mutable @global2 = 50 : i32
+  iree.global public mutable @global2 = 50 : i32
+  // CHECK: util.global private mutable @global3 = 50 : i32
+  iree.global private mutable @global3 = 50 : i32
+  // CHECK: util.global private @global4 = 50 : i32
+  iree.global private @global4 = 50 : i32
+
+  // CHECK: util.global public @global5 initializer(@initializer) : tensor<4xi32>
+  iree.global @global5 initializer(@initializer) : tensor<4xi32>
+  builtin.func private @initializer() -> tensor<4xi32>
+}
+
+// -----
+// CHECK-LABEL: module @global_load
+builtin.module @global_load {
+  iree.global private @v_loaded : tensor<4xi32>
+  func @loaded() {
+    // CHECK: util.global.load @v_loaded : tensor<4xi32>
+    %0 = iree.global.load @v_loaded : tensor<4xi32>
+    return
+  }
+}
+
+// -----
+// CHECK-LABEL: module @global_store
+builtin.module @global_store {
+  iree.global private mutable @v_stored : tensor<4xi32>
+  func @stored() {
+    // CHECK: %[[CST:.*]] = constant
+    %cst = constant dense<5> : tensor<4xi32>
+    // CHECK: util.global.store %[[CST]], @v_stored : tensor<4xi32>
+    iree.global.store %cst, @v_stored : tensor<4xi32>
+    return
+  }
+}
+
+// -----
+// CHECK-LABEL: module @global_load_indirect
+builtin.module @global_load_indirect {
+  iree.global private @v_loaded : tensor<4xf32>
+  func @loaded_indirect() {
+    // CHECK: %[[ADDR:.*]] = util.global.address @v_loaded : !util.ptr<tensor<4xf32>>
+    %0 = iree.global.address @v_loaded : !iree.ptr<tensor<4xf32>>
+    // CHECK: util.global.load.indirect %[[ADDR]] : !util.ptr<tensor<4xf32>> -> tensor<4xf32>
+    %1 = iree.global.load.indirect %0 : !iree.ptr<tensor<4xf32>> -> tensor<4xf32>
+    return
+  }
+}
+
+// -----
+// CHECK-LABEL: module @global_store_indirect
+builtin.module @global_store_indirect {
+  iree.global private mutable @v_stored : tensor<4xf32>
+  func @stored_indirect(%arg0: tensor<4xf32>) {
+    // CHECK: %[[ADDR:.*]] = util.global.address @v_stored : !util.ptr<tensor<4xf32>>
+    %0 = iree.global.address @v_stored : !iree.ptr<tensor<4xf32>>
+    // CHECK: util.global.store.indirect %arg0, %ptr_v_stored : tensor<4xf32> -> !util.ptr<tensor<4xf32>>
+    iree.global.store.indirect %arg0, %0 : tensor<4xf32> -> !iree.ptr<tensor<4xf32>>
+    return
+  }
+}
diff --git a/iree/compiler/Translation/BUILD b/iree/compiler/Translation/BUILD
index 4d4bcd1..886a5a2 100644
--- a/iree/compiler/Translation/BUILD
+++ b/iree/compiler/Translation/BUILD
@@ -29,6 +29,7 @@
         "//iree/compiler/Dialect/VM/Conversion/StandardToVM",
         "//iree/compiler/Dialect/VM/Target/Bytecode",
         "//iree/compiler/Dialect/VM/Transforms",
+        "//iree/compiler/InputConversion/Common",
         "//iree/compiler/InputConversion/MHLO",
         "//iree/compiler/InputConversion/TOSA",
         "//iree/compiler/Utils",
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 4a6ef7b..20bbff2 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Dialect/Util/Transforms/Passes.h"
 #include "iree/compiler/Dialect/VM/Target/Bytecode/TranslationFlags.h"
 #include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "iree/compiler/InputConversion/Common/Passes.h"
 #include "iree/compiler/InputConversion/MHLO/Passes.h"
 #include "iree/compiler/InputConversion/TOSA/Passes.h"
 #include "iree/compiler/Utils/TracingUtils.h"
@@ -145,6 +146,7 @@
       break;
   }
 
+  buildCommonInputConversionPassPipeline(passManager);
   IREE::Flow::buildFlowTransformPassPipeline(passManager);
   IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
   IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td
index d9740ae..691359a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td
@@ -95,7 +95,7 @@
   let parameters = (ins IREE_PtrTargetTypeParameter:$targetType);
 
   let printer = [{
-    $_printer << "list<" << getTargetType() << ">";
+    $_printer << "ptr<" << getTargetType() << ">";
   }];
 
   let parser = [{
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td
index 41774ff..f937ef7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td
@@ -30,7 +30,7 @@
 // Casts
 //===----------------------------------------------------------------------===//
 
-def IREE_TensorToBufferView : IREE_PureOp<"cast.tensor_to_buffer_view"> {
+def IREE_TensorToBufferViewOp : IREE_PureOp<"cast.tensor_to_buffer_view"> {
   let summary = "Casts a tensor to a BufferView, capturing dynamic dims";
   let arguments = (ins
     IREE_Tensor:$source,
@@ -44,7 +44,7 @@
   }];
 }
 
-def IREE_BufferViewToTensor : IREE_PureOp<"cast.buffer_view_to_tensor"> {
+def IREE_BufferViewToTensorOp : IREE_PureOp<"cast.buffer_view_to_tensor"> {
   let summary = "Casts a BufferView to a tensor, providing dynamic dims";
   let arguments = (ins
     IREE_BufferViewType:$source,
@@ -81,15 +81,14 @@
     OptionalAttr<AnyAttr>:$initial_value
   );
 
-  // TODO(laurenzo): copy SymbolVisibility/TypeOrAttr from UtilOps.cpp.
-  // let assemblyFormat = [{
-  //   custom<SymbolVisibility>($sym_visibility)
-  //   (`mutable` $is_mutable^)?
-  //   $sym_name
-  //   attr-dict
-  //   (`initializer` `(` $initializer^ `)`):(``)?
-  //   custom<TypeOrAttr>($type, $initial_value)
-  // }];
+  let assemblyFormat = [{
+    custom<SymbolVisibility>($sym_visibility)
+    (`mutable` $is_mutable^)?
+    $sym_name
+    attr-dict
+    (`initializer` `(` $initializer^ `)`):(``)?
+    custom<TypeOrAttr>($type, $initial_value)
+  }];
 }
 
 def IREE_GlobalAddressOp : IREE_PureOp<"global.address"> {
@@ -485,8 +484,7 @@
     IREE_ShapeDynamicDims:$target_dims,
     Variadic<IREE_Dim>:$start_indices,
     IREE_Tensor:$update,
-    IREE_ShapeDynamicDims:$update_dims,
-    OptionalAttr<IREE_TiedOpStorageAttr>:$tied_operands
+    IREE_ShapeDynamicDims:$update_dims
   );
   let results = (outs
     IREE_Tensor:$result
@@ -495,7 +493,7 @@
   let assemblyFormat = [{
     $update `,` $target `[` $start_indices `]` `:`
     type($update) (`{` $update_dims^ `}`)? `->`
-    `(` type($result) `,` $target_dims `,` $tied_operands `)`
+    type($result) (`{` $target_dims^ `}`)?
     attr-dict-with-keyword
   }];
 
@@ -521,7 +519,7 @@
     Variadic<IREE_Tensor>:$operands
   );
 
-  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+  let assemblyFormat = "$key attr-dict ($operands^ `:` type($operands))?";
 }
 
 #endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
index 2a12751..a723b58 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
@@ -15,5 +15,80 @@
 using namespace mlir;
 using namespace mlir::iree;
 
+//===----------------------------------------------------------------------===//
+// custom<SymbolVisibility>($sym_visibility)
+//===----------------------------------------------------------------------===//
+// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
+// ->
+// some.op @foo
+// some.op private @foo
+
+static ParseResult parseSymbolVisibility(OpAsmParser &parser,
+                                         StringAttr &symVisibilityAttr) {
+  StringRef symVisibility;
+  parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"});
+  if (!symVisibility.empty()) {
+    symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
+  }
+  return success();
+}
+
+static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+                                  StringAttr symVisibilityAttr) {
+  if (!symVisibilityAttr) {
+    p << "public";
+  } else {
+    p << symVisibilityAttr.getValue();
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// custom<TypeOrAttr>($type, $attr)
+//===----------------------------------------------------------------------===//
+// some.op custom<TypeOrAttr>($type, $attr)
+// ->
+// some.op : i32
+// some.op = 42 : i32
+// some.op : i32 = 42 : index
+
+static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
+                                   Attribute &attr) {
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+    typeAttr = TypeAttr::get(attr.getType());
+    return success();
+  }
+
+  Type type;
+  if (failed(parser.parseColonType(type))) {
+    return parser.emitError(parser.getCurrentLocation()) << "expected type";
+  }
+  typeAttr = TypeAttr::get(type);
+
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parser.parseAttribute(attr))) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected attribute";
+    }
+  }
+
+  return success();
+}
+
+static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
+                            Attribute attr) {
+  if (!attr || attr.getType() != type.getValue()) {
+    p << " : ";
+    p.printAttribute(type);
+  }
+  if (attr) {
+    p << " = ";
+    p.printAttribute(attr);
+  }
+}
+
 #define GET_OP_CLASSES
 #include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"