Merge pull request #3172 from ScottTodd:main-to-google

PiperOrigin-RevId: 332123717
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 19a0fd3..7c0f678 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -32,6 +32,7 @@
     "@llvm-project//mlir:Affine": ["MLIRAffineOps"],
     "@llvm-project//mlir:AffineToStandardTransforms": ["MLIRAffineToStandard"],
     "@llvm-project//mlir:CFGTransforms": ["MLIRSCFToStandard"],
+    "@llvm-project//mlir:DialectUtils": [""],
     "@llvm-project//mlir:ExecutionEngineUtils": ["MLIRExecutionEngine"],
     "@llvm-project//mlir:GPUDialect": ["MLIRGPU"],
     "@llvm-project//mlir:GPUTransforms": ["MLIRGPU"],
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index 64a88a9..9c6c084 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -27,7 +27,10 @@
     name = "lit",
     srcs = glob(
         ["*.cpp"],
-        exclude = ["Bench*"],
+        exclude = [
+            "Bench*",
+            "TestMatMulVulkan.cpp",  # b/168746423
+        ],
     ),
     data = [
         # runtime libraries
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index 4252271..61a01c1 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -97,13 +97,14 @@
   ModelRunner runner(modelBuilder.getModuleRef(),
                      ModelRunner::Target::GPUTarget);
   CompilationOptions options;
+  mlir::iree_compiler::SPIRVCodegenOptions codegenOptions;
   SmallVector<Type, 3> args = {typeA, typeB, typeC};
-  SmallVector<int64_t, 4> vWorkgroupSizes(workgroupSize.begin(),
-                                          workgroupSize.end());
-  SmallVector<int64_t, 4> vTileSizes(tileSizes.begin(), tileSizes.end());
+  codegenOptions.workgroupSize.assign(workgroupSize.begin(),
+                                      workgroupSize.end());
+  codegenOptions.tileSizes.assign(tileSizes.begin(), tileSizes.end());
   auto lowering = [&](mlir::PassManager &pm) {
-    pm.addPass(mlir::iree_compiler::createLinalgTileAndFusePass(
-        vWorkgroupSizes, vTileSizes, useWorkgroupMemory));
+    pm.addPass(
+        mlir::iree_compiler::createLinalgTileAndFusePass(codegenOptions));
     pm.addPass(mlir::iree_compiler::createConvertToGPUPass());
     pm.addPass(mlir::createLowerAffinePass());
     pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
@@ -119,14 +120,14 @@
     spirvModulePM.addPass(
         mlir::spirv::createUpdateVersionCapabilityExtensionPass());
 
-    int numWorkgroupX =
-        vWorkgroupSizes.empty()
-            ? 1
-            : (width + vWorkgroupSizes[0] - 1) / vWorkgroupSizes[0];
-    int numWorkgroupY =
-        vWorkgroupSizes.size() < 2
-            ? 1
-            : (height + vWorkgroupSizes[1] - 1) / vWorkgroupSizes[1];
+    int numWorkgroupX = codegenOptions.tileSizes.empty()
+                            ? 1
+                            : (width + codegenOptions.tileSizes[0] - 1) /
+                                  codegenOptions.tileSizes[0];
+    int numWorkgroupY = codegenOptions.tileSizes.size() < 2
+                            ? 1
+                            : (height + codegenOptions.tileSizes[1] - 1) /
+                                  codegenOptions.tileSizes[1];
     pm.addPass(mlir::createAddVulkanLaunchWrapperPass(
         {numWorkgroupX, numWorkgroupY, 1}, args));
     mlir::LowerToLLVMOptions llvmOptions = {
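For reference, the workgroup-count computation in the hunk above is a plain ceiling division of the problem extent by the tile size. A minimal standalone sketch of that arithmetic (the sizes below are illustrative, not values from this test):

```cpp
#include <cstdint>
#include <vector>

// Ceiling division: how many tiles of size `tileSize` cover `extent`.
static int64_t ceilDiv(int64_t extent, int64_t tileSize) {
  return (extent + tileSize - 1) / tileSize;
}

int main() {
  const int64_t width = 257, height = 64;     // illustrative problem sizes
  std::vector<int64_t> tileSizes = {32, 32};  // illustrative tile sizes
  // Mirrors numWorkgroupX/numWorkgroupY above: default to 1 when the
  // corresponding tile size is missing.
  int64_t numWorkgroupX = tileSizes.empty() ? 1 : ceilDiv(width, tileSizes[0]);
  int64_t numWorkgroupY =
      tileSizes.size() < 2 ? 1 : ceilDiv(height, tileSizes[1]);
  // width=257, tile=32 -> 9 workgroups; height=64, tile=32 -> 2 workgroups.
  return (numWorkgroupX == 9 && numWorkgroupY == 2) ? 0 : 1;
}
```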
diff --git a/iree/compiler/Conversion/BUILD b/iree/compiler/Conversion/BUILD
index 35f424e..bfa7317 100644
--- a/iree/compiler/Conversion/BUILD
+++ b/iree/compiler/Conversion/BUILD
@@ -27,5 +27,6 @@
         "//iree/compiler/Conversion/HLOToLinalg",
         "//iree/compiler/Conversion/LinalgToLLVM",
         "//iree/compiler/Conversion/LinalgToSPIRV",
+        "//iree/compiler/Conversion/LinalgToVector",
     ],
 )
diff --git a/iree/compiler/Conversion/CMakeLists.txt b/iree/compiler/Conversion/CMakeLists.txt
index 6a022f2..79ead57 100644
--- a/iree/compiler/Conversion/CMakeLists.txt
+++ b/iree/compiler/Conversion/CMakeLists.txt
@@ -23,5 +23,6 @@
     iree::compiler::Conversion::HLOToLinalg
     iree::compiler::Conversion::LinalgToLLVM
     iree::compiler::Conversion::LinalgToSPIRV
+    iree::compiler::Conversion::LinalgToVector
   PUBLIC
 )
diff --git a/iree/compiler/Conversion/HLOToLinalg/BUILD b/iree/compiler/Conversion/HLOToLinalg/BUILD
index da3a73e..a1f9de8 100644
--- a/iree/compiler/Conversion/HLOToLinalg/BUILD
+++ b/iree/compiler/Conversion/HLOToLinalg/BUILD
@@ -38,6 +38,7 @@
         "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Dialect/Shape/IR",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 87ed137..be68c3a 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Function.h"
@@ -79,13 +80,11 @@
 /// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
 /// are "parallel" except the last `nReduction` elements, which are "reduction"
 /// attributes.
-// TODO(hanchung): Use helpers in StructuredOpsUtils.h instead of hardcoded
-// strings once the build system is set up.
 static ArrayAttr getParallelAndReductionIterAttrs(Builder b, unsigned nLoops,
                                                   unsigned nReduction) {
-  SmallVector<Attribute, 3> attrs(nLoops - nReduction,
-                                  b.getStringAttr("parallel"));
-  attrs.append(nReduction, b.getStringAttr("reduction"));
+  SmallVector<Attribute, 3> attrs(
+      nLoops - nReduction, b.getStringAttr(getParallelIteratorTypeName()));
+  attrs.append(nReduction, b.getStringAttr(getReductionIteratorTypeName()));
   return b.getArrayAttr(attrs);
 }
 
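The helper above now pulls the iterator-type strings from StructuredOpsUtils.h instead of hardcoding them. A dependency-free sketch of the list it produces, assuming the MLIR helpers resolve to "parallel" and "reduction" as they do upstream:

```cpp
#include <string>
#include <vector>

// Mirrors getParallelAndReductionIterAttrs: `nLoops` iterator types where the
// trailing `nReduction` loops are reductions.
static std::vector<std::string> iterTypes(unsigned nLoops,
                                          unsigned nReduction) {
  std::vector<std::string> attrs(nLoops - nReduction, "parallel");
  attrs.insert(attrs.end(), nReduction, "reduction");
  return attrs;
}

int main() {
  // A matmul has 3 loops (m, n, k), with the k loop being the reduction:
  std::vector<std::string> a = iterTypes(3, 1);
  // a == {"parallel", "parallel", "reduction"}
  return (a.size() == 3 && a.back() == "reduction") ? 0 : 1;
}
```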
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index dd6717a..cd0d823 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -49,6 +49,7 @@
         "//iree/compiler/Conversion/CodegenUtils",
         "//iree/compiler/Conversion/HLOToHLO",
         "//iree/compiler/Conversion/HLOToLinalg",
+        "//iree/compiler/Conversion/LinalgToVector",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/IREE/IR",
         "//iree/compiler/Dialect/Shape/IR",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index ce9a6cb..42538e3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -65,6 +65,7 @@
     iree::compiler::Conversion::CodegenUtils
     iree::compiler::Conversion::HLOToHLO
     iree::compiler::Conversion::HLOToLinalg
+    iree::compiler::Conversion::LinalgToVector
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::IREE::IR
     iree::compiler::Dialect::Shape::IR
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index ea85ec2..01b35f3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -14,19 +14,24 @@
 
 //===- KernelDispatchUtils.cpp - Utilities for generating dispatch info ---===//
 //
-// This file defines utility functions that can be used to create information
-// the dispatch on the host side needs to execute an entry point function, like
-// the number of workgroups to use for launch, etc.
+// This file defines utility functions that compute the tile sizes used to
+// partition work across workgroups, the workgroup sizes, and the information
+// the host-side dispatch needs to execute an entry point function (e.g., the
+// total number of workgroups).
 //
 //===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
+
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
@@ -40,6 +45,10 @@
 namespace mlir {
 namespace iree_compiler {
 
+//===----------------------------------------------------------------------===//
+// Number of workgroups computation
+//===----------------------------------------------------------------------===//
+
 FuncOp getNumWorkgroupsFn(FuncOp entryPointFn) {
   SymbolRefAttr attr =
       entryPointFn.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName());
@@ -153,5 +162,271 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Launch config calculation.
+//===----------------------------------------------------------------------===//
+
+/// Given `nprocs`, try to distribute it evenly across the two logical
+/// dimensions x and y.
+static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
+  int64_t nprocs_x = std::max<int64_t>(
+      1, static_cast<int64_t>(
+             llvm::PowerOf2Ceil(static_cast<uint64_t>(std::sqrt(nprocs)))));
+  return std::make_tuple(nprocs_x, nprocs / nprocs_x);
+}
+
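A self-contained sketch of this distribution heuristic, with std::bit_ceil standing in for llvm::PowerOf2Ceil (an assumption made so the snippet compiles without LLVM):

```cpp
#include <algorithm>
#include <bit>  // std::bit_ceil (C++20)
#include <cmath>
#include <cstdint>
#include <tuple>

// Splits `nprocs` into an (x, y) grid: x is the power-of-two ceiling of
// sqrt(nprocs), y is whatever remains after dividing x out.
static std::tuple<int64_t, int64_t> distributeProcs2D(int64_t nprocs) {
  int64_t x = std::max<int64_t>(
      1, static_cast<int64_t>(std::bit_ceil(
             static_cast<uint64_t>(std::sqrt(static_cast<double>(nprocs))))));
  return {x, nprocs / x};
}

int main() {
  // 128 invocations -> sqrt = 11.3 -> bit_ceil(11) = 16 -> (16, 8).
  auto [x, y] = distributeProcs2D(128);
  return (x == 16 && y == 8) ? 0 : 1;
}
```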
+/// For a given operation `op`, the `options`, and the hardware
+/// `resourceLimits`, compute the
+/// 1) number of tiling levels and tile sizes to use (updates `tileSizes`),
+/// 2) workgroup size to use (updates `workgroupSize`),
+/// 3) number of subgroups to use if two level tiling is used (updates
+///    `numSubgroups`).
+template <typename T>
+static LogicalResult getOpLaunchConfig(T op, const SPIRVCodegenOptions &options,
+                                       spirv::ResourceLimitsAttr resourceLimits,
+                                       TileSizesListType &tileSizes,
+                                       std::array<int64_t, 3> &workgroupSize,
+                                       std::array<int64_t, 3> &numSubgroups) {
+  return op.emitError("undefined launch config for tiled operation");
+}
+
+/// Launch config for `linalg.batch_matmul`.
+template <>
+LogicalResult getOpLaunchConfig(linalg::BatchMatmulOp op,
+                                const SPIRVCodegenOptions &options,
+                                spirv::ResourceLimitsAttr resourceLimits,
+                                TileSizesListType &tileSizes,
+                                std::array<int64_t, 3> &workgroupSize,
+                                std::array<int64_t, 3> &numSubgroups) {
+  unsigned maxWorkgroupSize =
+      resourceLimits.max_compute_workgroup_invocations().getInt();
+  std::tie(workgroupSize[0], workgroupSize[1]) =
+      distributeProcs2D(maxWorkgroupSize);
+  workgroupSize[2] = 1;
+  // TODO(#3131): This is just being hard-wired for now to be minimally
+  // viable, but this can be decided better when we have better estimates of
+  // device characteristics.
+  const int64_t nRowsPerWorkitem = 1;
+  const int64_t nColsPerWorkitem = 1;
+  const int64_t nBatchesPerWorkitem = 1;
+  int64_t tileSizeK = 0;
+  if (options.useWorkgroupMemory) {
+    // TODO(#3131): This number should be decided based on the amount of
+    // shared memory available (maybe). For now, just hard-wire it.
+    tileSizeK = 32;
+  }
+  assert(tileSizes.empty());
+  SmallVector<int64_t, 4> ts = {nBatchesPerWorkitem,
+                                nRowsPerWorkitem * workgroupSize[1],
+                                nColsPerWorkitem * workgroupSize[0], tileSizeK};
+  tileSizes.emplace_back(std::move(ts));
+  return success();
+}
+
+/// The size of the cooperative matrix multiply operations on the device.
+// TODO(#3131): This needs to be queried from the device.
+Optional<std::array<int64_t, 3>> getCooperativeMatmulSubgroupSize(
+    Type dataType, Type accumulatorType) {
+  if (dataType.isInteger(8) && accumulatorType.isInteger(32)) {
+    return std::array<int64_t, 3>{8, 8, 32};
+  }
+  if (dataType.isF16() &&
+      (accumulatorType.isF32() || accumulatorType.isF16())) {
+    return std::array<int64_t, 3>{8, 8, 16};
+  }
+  return {};
+}
+
+/// Launch configuration for using spv.CooperativeMatrixMulAddNV
+/// operations. Needs two levels of tiling.
+static LogicalResult getConfigForCooperativeMatmul(
+    linalg::MatmulOp op, spirv::ResourceLimitsAttr resourceLimits,
+    TileSizesListType &tileSizes, std::array<int64_t, 3> &workgroupSize,
+    std::array<int64_t, 3> &numSubgroups) {
+  auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
+  if (!targetEnv.allows(spirv::Capability::CooperativeMatrixNV) ||
+      !targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix))
+    return failure();
+
+  ShapedType lhsType = op.getOperand(0).getType().cast<ShapedType>();
+  ArrayRef<int64_t> lhsShape = lhsType.getShape();
+  ShapedType rhsType = op.getOperand(1).getType().cast<ShapedType>();
+  ArrayRef<int64_t> rhsShape = rhsType.getShape();
+  ShapedType outputType = op.getOperand(2).getType().cast<ShapedType>();
+
+  Optional<std::array<int64_t, 3>> coopMatmulSize =
+      getCooperativeMatmulSubgroupSize(lhsType.getElementType(),
+                                       outputType.getElementType());
+  if (!coopMatmulSize) return failure();
+
+  // Check that the matmul sizes are multiples of the tile size.
+  auto isMultipleOf = [](int64_t s, int64_t ts) {
+    return !ShapedType::isDynamic(s) && (s % ts) == 0;
+  };
+  if (!isMultipleOf(lhsShape[0], (*coopMatmulSize)[0]) ||
+      !isMultipleOf(rhsShape[1], (*coopMatmulSize)[1]) ||
+      !isMultipleOf(lhsShape[1], (*coopMatmulSize)[2]) ||
+      !isMultipleOf(rhsShape[0], (*coopMatmulSize)[2]))
+    return failure();
+
+  // TODO(ravishankarm, antiagainst): For now hardwire the subgroup size.
+  const int64_t subgroupSize = 32;
+  unsigned maxWorkgroupSize =
+      resourceLimits.max_compute_workgroup_invocations().getInt();
+  std::tie(numSubgroups[0], numSubgroups[1]) =
+      distributeProcs2D(maxWorkgroupSize / subgroupSize);
+  numSubgroups[2] = 1;
+  // TODO(#3131): This is just being hard-wired for now to be minimally
+  // viable, but this can be decided better when we have better estimates of
+  // device characteristics.
+  const int64_t numVecMatmulPerSubgroupX = 1;
+  const int64_t numVecMatmulPerSubgroupY = 1;
+  SmallVector<int64_t, 4> ts = {
+      numVecMatmulPerSubgroupY * (*coopMatmulSize)[0] * numSubgroups[1],
+      numVecMatmulPerSubgroupX * (*coopMatmulSize)[1] * numSubgroups[0]};
+  tileSizes.emplace_back(std::move(ts));
+
+  workgroupSize[0] = numSubgroups[0] * subgroupSize;
+  workgroupSize[1] = numSubgroups[1];
+  workgroupSize[2] = 1;
+  // Subgroup tile sizes
+  SmallVector<int64_t, 4> subgroupTs = {
+      numVecMatmulPerSubgroupY * (*coopMatmulSize)[0],
+      numVecMatmulPerSubgroupX * (*coopMatmulSize)[1], (*coopMatmulSize)[2]};
+  tileSizes.emplace_back(std::move(subgroupTs));
+  return success();
+}
+
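To make the cooperative-matmul arithmetic concrete, here is a hedged walk-through under assumed device numbers; the real values come from the SPIR-V target environment, not from this code:

```cpp
#include <array>
#include <cstdint>

int main() {
  // Assumed device numbers (illustrative only): f16 cooperative matmul of
  // shape M x N x K = 8 x 8 x 16, subgroup size 32, and 128 invocations per
  // workgroup.
  std::array<int64_t, 3> coop = {8, 8, 16};
  const int64_t subgroupSize = 32, maxInvocations = 128;
  // distributeProcs2D(maxInvocations / subgroupSize = 4) -> (2, 2).
  std::array<int64_t, 3> numSubgroups = {2, 2, 1};
  // First-level (workgroup) tile, one coop matmul per subgroup per dimension:
  int64_t tileM = coop[0] * numSubgroups[1];  // 8 * 2 = 16
  int64_t tileN = coop[1] * numSubgroups[0];  // 8 * 2 = 16
  // Workgroup size: x carries whole subgroups, y counts subgroups.
  std::array<int64_t, 3> workgroupSize = {numSubgroups[0] * subgroupSize,
                                          numSubgroups[1], 1};  // {64, 2, 1}
  return (maxInvocations / subgroupSize == 4 && tileM == 16 && tileN == 16 &&
          workgroupSize[0] == 64 && workgroupSize[2] == 1)
             ? 0
             : 1;
}
```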
+template <>
+LogicalResult getOpLaunchConfig(linalg::MatmulOp op,
+                                const SPIRVCodegenOptions &options,
+                                spirv::ResourceLimitsAttr resourceLimits,
+                                TileSizesListType &tileSizes,
+                                std::array<int64_t, 3> &workgroupSize,
+                                std::array<int64_t, 3> &numSubgroups) {
+  if (options.useVectorization &&
+      succeeded(getConfigForCooperativeMatmul(op, resourceLimits, tileSizes,
+                                              workgroupSize, numSubgroups))) {
+    return success();
+  }
+  unsigned maxWorkgroupSize =
+      resourceLimits.max_compute_workgroup_invocations().getInt();
+  std::tie(workgroupSize[0], workgroupSize[1]) =
+      distributeProcs2D(maxWorkgroupSize);
+  workgroupSize[2] = 1;
+  const int nRowsPerWorkitem = 1;
+  const int nColsPerWorkitem = 1;
+  int64_t tileSizeK = 0;
+  if (options.useWorkgroupMemory) {
+    // TODO(#3131): This number should be decided based on the amount of shared
+    // memory available (maybe). For now, just hard-wire it.
+    tileSizeK = 32;
+  }
+  assert(tileSizes.empty());
+  SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * workgroupSize[1],
+                                nColsPerWorkitem * workgroupSize[0], tileSizeK};
+  tileSizes.emplace_back(std::move(ts));
+  return success();
+}
+
+template <>
+LogicalResult getOpLaunchConfig(linalg::ConvOp op,
+                                const SPIRVCodegenOptions &options,
+                                spirv::ResourceLimitsAttr resourceLimits,
+                                TileSizesListType &tileSizes,
+                                std::array<int64_t, 3> &workgroupSize,
+                                std::array<int64_t, 3> &numSubgroups) {
+  unsigned maxWorkgroupSize =
+      resourceLimits.max_compute_workgroup_invocations().getInt();
+  const int64_t tileSizeX = 32;
+  int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
+  SmallVector<int64_t, 4> ts = {1, tileSizeY, tileSizeX};
+  tileSizes.emplace_back(std::move(ts));
+  workgroupSize = {tileSizeX, tileSizeY, 1};
+  return success();
+}
+
+static LogicalResult getPoolingOpLaunchConfig(
+    const SPIRVCodegenOptions &options,
+    spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
+    std::array<int64_t, 3> &workgroupSize,
+    std::array<int64_t, 3> &numSubgroups) {
+  unsigned maxWorkgroupSize =
+      resourceLimits.max_compute_workgroup_invocations().getInt();
+  const int64_t tileSizeX = 32;
+  int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
+  SmallVector<int64_t, 4> ts = {tileSizeY, tileSizeX};
+  tileSizes.emplace_back(std::move(ts));
+  workgroupSize = {tileSizeX, tileSizeY, 1};
+  return success();
+}
+
+#define DEFINE_POOLING_OP_CONFIG(opName)                                      \
+  template <>                                                                 \
+  LogicalResult getOpLaunchConfig(                                            \
+      opName op, const SPIRVCodegenOptions &options,                          \
+      spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes, \
+      std::array<int64_t, 3> &workgroupSize,                                  \
+      std::array<int64_t, 3> &numSubgroups) {                                 \
+    return getPoolingOpLaunchConfig(options, resourceLimits, tileSizes,       \
+                                    workgroupSize, numSubgroups);             \
+  }
+
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingMaxOp)
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingMinOp)
+DEFINE_POOLING_OP_CONFIG(linalg::PoolingSumOp)
+
+#undef DEFINE_POOLING_OP_CONFIG
+
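For readability, the mechanical expansion of DEFINE_POOLING_OP_CONFIG(linalg::PoolingMaxOp) is spelled out below; the macro generates exactly this kind of forwarding specialization:

```cpp
// Expansion of DEFINE_POOLING_OP_CONFIG(linalg::PoolingMaxOp), shown for
// illustration only; it forwards to the shared pooling launch config.
template <>
LogicalResult getOpLaunchConfig(
    linalg::PoolingMaxOp op, const SPIRVCodegenOptions &options,
    spirv::ResourceLimitsAttr resourceLimits, TileSizesListType &tileSizes,
    std::array<int64_t, 3> &workgroupSize,
    std::array<int64_t, 3> &numSubgroups) {
  return getPoolingOpLaunchConfig(options, resourceLimits, tileSizes,
                                  workgroupSize, numSubgroups);
}
```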
+LogicalResult LaunchConfig::init(const SPIRVCodegenOptions &options,
+                                 ArrayRef<linalg::LinalgOp> linalgOps) {
+  if (!options.workgroupSize.empty()) {
+    for (linalg::LinalgOp op : linalgOps)
+      tileSizes[op.getOperation()->getName().getStringRef()] = {};
+    workgroupSize = {1, 1, 1};
+    for (unsigned i = 0,
+                  e = std::min<unsigned>(3, options.workgroupSize.size());
+         i != e; ++i)
+      workgroupSize[i] = options.workgroupSize[i];
+    return success();
+  }
+
+  if (linalgOps.empty()) return success();
+
+  spirv::ResourceLimitsAttr resourceLimits =
+      spirv::lookupTargetEnv(*linalgOps.begin()).getResourceLimits();
+
+  for (linalg::LinalgOp op : linalgOps) {
+    StringRef key = op.getOperation()->getName().getStringRef();
+    if (tileSizes.count(key)) {
+      return op.emitError("unexpected multiple ")
+             << key << " operations within dispatch region";
+    }
+
+    TileSizesListType &tileSizesInfo = tileSizes[key];
+
+#define DISPATCH(opName)                                                      \
+  if (auto lOp = dyn_cast<opName>(op.getOperation())) {                       \
+    if (failed(getOpLaunchConfig(lOp, options, resourceLimits, tileSizesInfo, \
+                                 workgroupSize, numSubgroups))) {             \
+      return failure();                                                       \
+    }                                                                         \
+    continue;                                                                 \
+  }
+
+    DISPATCH(linalg::BatchMatmulOp)
+    DISPATCH(linalg::ConvOp)
+    DISPATCH(linalg::MatmulOp)
+    DISPATCH(linalg::PoolingMaxOp)
+    DISPATCH(linalg::PoolingMinOp)
+    DISPATCH(linalg::PoolingSumOp)
+
+#undef DISPATCH
+  }
+
+  // TODO(ravishankarm): Verify that the chosen configuration is within the
+  // device limits.
+  return success();
+}
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index 1573c55..d690a0a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -22,11 +22,17 @@
 #ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
 #define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
 
+#include <array>
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
 class FuncOp;
 class LogicalResult;
+class Operation;
 class PatternRewriter;
 class ShapedType;
 class Value;
@@ -34,6 +40,9 @@
 namespace linalg {
 class LinalgOp;
 }
+namespace iree_compiler {
+struct SPIRVCodegenOptions;
+}
 
 namespace iree_compiler {
 
@@ -61,6 +70,68 @@
 /// workgroups to use at launch time.
 FuncOp getNumWorkgroupsFn(FuncOp entryPointFn);
 
+/// Store the tile sizes to use at different levels of tiling as a vector of
+/// vectors.
+/// - First level tiling maps to workgroups.
+/// - Second level tiling maps to subgroups.
+using TileSizesListType = SmallVector<SmallVector<int64_t, 4>, 1>;
+
+/// Based on the linalg operations in a dispatch region, the number of levels
+/// of tiling, the tile sizes needed, the workgroup size, etc. need to be
+/// decided. Together these parameters form the `LaunchConfig`. This class
+/// implements one heuristic to compute them for the different linalg
+/// operations on buffers. It can be adapted later to support multiple
+/// configurations picked based on device and problem-size information. It
+/// exposes the information needed by the code generators and hides the
+/// implementation from the rest of the pipeline.
+class LaunchConfig {
+ public:
+  LaunchConfig() : workgroupSize({1, 1, 1}), numSubgroups({1, 1, 1}) {}
+
+  /// Given the sequence of `linalgOps` (and `options`), determine the launch
+  /// configuration by deciding
+  /// - the number of levels of tiling,
+  /// - tile sizes for each level,
+  /// - the workgroup size, and
+  /// - number of subgroups to use.
+  LogicalResult init(const SPIRVCodegenOptions &options,
+                     ArrayRef<linalg::LinalgOp> linalgOps);
+
+  /// Gets the tile sizes computed for an operation at all levels.
+  TileSizesListType getTileSizes(Operation *op) const {
+    return tileSizes.lookup(op->getName().getStringRef());
+  }
+
+  /// Gets the tile sizes computed for an operation at a given level.
+  ArrayRef<int64_t> getTileSizes(Operation *op, size_t level) const {
+    auto it = tileSizes.find(op->getName().getStringRef());
+    if (it == tileSizes.end() || level >= it->second.size()) return {};
+    return it->second[level];
+  }
+
+  /// Returns the workgroup size to use based on the tile sizes.
+  ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
+
+  /// Returns the number of subgroups to use.
+  ArrayRef<int64_t> getNumSubgroups() const { return numSubgroups; }
+
+ protected:
+  /// Current tile size configuration per operation.
+
+  // TODO: For now just use the operation name for the mapping. The tile sizes
+  // will be selected only for operations like matmul, conv, pool, etc. and
+  // assume that there is only one such operation per dispatch
+  // region. Eventually this might need to be relaxed, and some name-marker
+  // based mechanism might be needed.
+  llvm::StringMap<TileSizesListType> tileSizes;
+
+  /// Workgroup size to use.
+  std::array<int64_t, 3> workgroupSize;
+
+  /// Number of subgroups that are logically distributed along x, y & z.
+  std::array<int64_t, 3> numSubgroups;
+};
+
 }  // namespace iree_compiler
 }  // namespace mlir
 #endif  // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
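A simplified, self-contained model of the lookup LaunchConfig performs, with std::map standing in for llvm::StringMap; this illustrates the keying scheme only, not the class itself:

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

using TileSizesList = std::vector<std::vector<int64_t>>;

int main() {
  // Keyed by op name, one entry per tiling level, as LaunchConfig stores it.
  std::map<std::string, TileSizesList> tileSizes;
  tileSizes["linalg.matmul"] = {{16, 16}, {8, 8, 16}};  // workgroup, subgroup

  // Mirrors getTileSizes(op, level): empty when the op or level is missing.
  auto lookup = [&](const std::string &key, size_t level) {
    auto it = tileSizes.find(key);
    if (it == tileSizes.end() || level >= it->second.size())
      return std::vector<int64_t>{};
    return it->second[level];
  };
  return (lookup("linalg.matmul", 1).size() == 3 &&
          lookup("linalg.conv", 0).empty())
             ? 0
             : 1;
}
```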
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 4c7417e..135961a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -18,6 +18,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
@@ -29,7 +30,6 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/Matchers.h"
@@ -37,7 +37,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/FoldUtils.h"
 
-#define DEBUG_TYPE "iree-linalg-tile-and-fuse-buffer"
+#define DEBUG_TYPE "iree-linalg-tile-and-fuse"
 
 namespace mlir {
 namespace iree_compiler {
@@ -46,18 +46,6 @@
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
-  if (vector.empty()) return vector;
-  auto numTrailingOnes = 0;
-  for (unsigned i = vector.size() - 1; i > 0; --i) {
-    if (vector[i] != 1) {
-      break;
-    }
-    numTrailingOnes++;
-  }
-  return vector.drop_back(numTrailingOnes);
-}
-
 /// Returns true if the linalg op has a padding attribute with non-zero
 /// entries.
 template <typename OpTy>
@@ -68,129 +56,6 @@
                       [](APInt v) -> bool { return !v.isNullValue(); });
 }
 
-namespace {
-
-/// Computes tile sizes (and workgroup size) to use based on operations within
-/// the function, and resource constraints on the module.
-class TileSizeCalculator {
- public:
-  TileSizeCalculator(FuncOp funcOp)
-      : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {
-    if (DenseIntElementsAttr attr = spirv::lookupLocalWorkGroupSize(funcOp)) {
-      for (auto val : attr.getValues<APInt>())
-        workgroupSize.push_back(val.getSExtValue());
-    }
-    workgroupSize.resize(3, 1);
-  }
-
-  /// Set tile sizes to use.
-  void setTileSizes(ArrayRef<int64_t> sizes) {
-    tileSizes.assign(sizes.begin(), sizes.end());
-  }
-
-  /// Set workgroup size to use.
-  void setWorkgroupSize(ArrayRef<int64_t> sizes) {
-    workgroupSize.assign(sizes.begin(), sizes.end());
-  }
-
-  /// Compute the tile sizes based on the Linalg Ops within the dispatch region.
-  LogicalResult inferTileAndWorkgroupSize(ArrayRef<linalg::LinalgOp> linalgOps);
-
-  /// Get the current tile size computed.
-  ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
-
-  /// Returns the workgroup size to use based on the tile sizes.
-  ArrayRef<int64_t> getWorkgroupSize() const { return workgroupSize; }
-
- private:
-  /// Current tile size configuration.
-  SmallVector<int64_t, 4> tileSizes;
-
-  /// Workgroup size to use.
-  SmallVector<int64_t, 3> workgroupSize;
-
-  /// Attribute for device constraints.
-  spirv::ResourceLimitsAttr resourceLimits;
-};
-}  // namespace
-
-LogicalResult TileSizeCalculator::inferTileAndWorkgroupSize(
-    ArrayRef<linalg::LinalgOp> linalgOps) {
-  tileSizes.clear();
-  if (linalgOps.empty()) {
-    tileSizes = {1, 1, 1};
-    workgroupSize = {1, 1, 1};
-    return success();
-  }
-  // The tile size will be driven by operations like matmul, conv, etc. within
-  // the list. So see what operation exists in the list to decide the tile size.
-  // If there are two such operations in the list, return error.
-  enum OpInfo : uint32_t {
-    None = 0x0,
-    Convolution = 0x1,
-    Matmul = 0x2,
-    Pooling = 0x4,
-    BatchMatmul = 0x8,
-  };
-  uint32_t opInfo = OpInfo::None;
-  for (linalg::LinalgOp linalgOp : linalgOps) {
-    Operation *op = linalgOp.getOperation();
-    if (isa<linalg::ConvOp>(op))
-      opInfo |= OpInfo::Convolution;
-    else if (isa<linalg::MatmulOp>(op))
-      opInfo |= OpInfo::Matmul;
-    else if (isa<linalg::BatchMatmulOp>(op))
-      opInfo |= OpInfo::BatchMatmul;
-    else if (isa<linalg::PoolingMaxOp>(op))
-      opInfo |= OpInfo::Pooling;
-    else if (isa<linalg::PoolingMinOp>(op))
-      opInfo |= OpInfo::Pooling;
-    else if (isa<linalg::PoolingSumOp>(op))
-      opInfo |= OpInfo::Pooling;
-  }
-  // If there are no tilable ops, there is nothing to do here.
-  if (!opInfo) return success();
-
-  Operation *linalgOp = *(linalgOps.begin());
-  if (llvm::countPopulation(opInfo) != 1)
-    return linalgOp->getParentOfType<FuncOp>().emitError(
-        "unhandled fusion of ops in dispatch function");
-
-  // TODO(ravishanarm, antiagainst): Only the maximum workgroup size is used
-  // here for computing tile sizes. In reality we also need the maximum
-  // workgroup memory size available (per workgroup) to compute the tile sizes
-  // effectively.
-  unsigned maxWorkgroupSize =
-      resourceLimits.max_compute_workgroup_invocations().getInt();
-  if (opInfo & OpInfo::Convolution) {
-    int64_t tileSizeX = 32;
-    int64_t tileSizeY = maxWorkgroupSize / 32;
-    tileSizes = {1, tileSizeY, tileSizeX};
-    workgroupSize = {tileSizeX, tileSizeY, 1};
-    return success();
-  }
-  if (opInfo & OpInfo::Matmul) {
-    // TODO: For now just hard wire this, but we can do better.
-    tileSizes = {8, 8, 4};
-    workgroupSize = {8, 8, 1};
-    return success();
-  }
-  if (opInfo & OpInfo::BatchMatmul) {
-    tileSizes = {2, 8, 8, 4};
-    workgroupSize = {8, 8, 2};
-    return success();
-  }
-  if (opInfo & OpInfo::Pooling) {
-    int64_t tileSizeX = 32;
-    int64_t tileSizeY = maxWorkgroupSize / 32;
-    tileSizes = {tileSizeY, tileSizeX};
-    workgroupSize = {tileSizeX, tileSizeY, 1};
-    return success();
-  }
-  return linalgOp->getParentOfType<FuncOp>().emitError(
-      "unable to find tile size for ops in this dispatch function");
-}
-
 //===----------------------------------------------------------------------===//
 // Pass and patterns
 //===----------------------------------------------------------------------===//
@@ -199,37 +64,53 @@
 /// Function pass that implements tiling and fusion in Linalg on buffers.
 struct LinalgTileAndFusePass
     : public PassWrapper<LinalgTileAndFusePass, OperationPass<ModuleOp>> {
-  LinalgTileAndFusePass(ArrayRef<int64_t> workgroupSize = {},
-                        ArrayRef<int64_t> tileSizes = {},
-                        bool useWorkgroupMem = false) {
-    this->workgroupSize = workgroupSize;
-    this->tileSizes = tileSizes;
-    this->useWorkgroupMemory = useWorkgroupMem;
+  LinalgTileAndFusePass() = default;
+  LinalgTileAndFusePass(const SPIRVCodegenOptions &passedOptions) {
+    options = passedOptions;
   }
   LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
-                    scf::SCFDialect, ShapeDialect>();
+                    scf::SCFDialect, ShapeDialect, vector::VectorDialect>();
   }
   void runOnOperation() override;
 
  private:
-  Option<bool> useWorkgroupMemory{
-      *this, "use-workgroup-memory",
-      llvm::cl::desc("Promote subviews to use workgroup memory"),
-      llvm::cl::init(false)};
+  SPIRVCodegenOptions options;
 
-  ListOption<int64_t> workgroupSize{
-      *this, "workgroup-size",
-      llvm::cl::desc("Override the default workgroup size"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-
+  // TODO: Find a common place to put these options. They are defined three
+  // times: once here, once for the pass pipeline, and once for the binary.
   ListOption<int64_t> tileSizes{
       *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+  ListOption<int64_t> workgroupSize{
+      *this, "workgroup-size",
+      llvm::cl::desc(
+          "Workgroup size to use for the SPIR-V module; at most three "
+          "integers standing for the x, y, and z dimensions; additional "
+          "arguments will be ignored (used only for testing)"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+  Option<bool> useWorkgroupMemory{
+      *this, "use-workgroup-memory",
+      llvm::cl::desc(
+          "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+      llvm::cl::init(false)};
+
+  Option<bool> useVectorization{
+      *this, "use-vectorization",
+      llvm::cl::desc(
+          "Enable use of vectorization in SPIR-V code generation pipeline"),
+      llvm::cl::init(false)};
 };
 
+//===----------------------------------------------------------------------===//
+// Patterns to tile computation to map to workgroups.
+//===----------------------------------------------------------------------===//
+
+/// Distribution options for linalg.matmul when targeting workgroups.
 static linalg::LinalgLoopDistributionOptions matmulDistributionOptions = {
     [](OpBuilder &builder, Location loc,
        ArrayRef<SubViewOp::Range> parallelLoopRanges) {
@@ -237,6 +118,7 @@
           builder, loc, parallelLoopRanges.size());
     },
     {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+     linalg::DistributionMethod::CyclicNumProcsEqNumIters,
      linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
 
 /// Pattern for tiling operations. Updates the workgroup size in the surrounding
@@ -245,8 +127,8 @@
 struct TileMatmulPattern : public linalg::LinalgBaseTilingPattern {
   using Base = linalg::LinalgBaseTilingPattern;
   TileMatmulPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
-                    ArrayRef<int64_t> tileSizes,
-                    ArrayRef<int64_t> workgroupSize, PatternBenefit benefit = 1)
+                    const LaunchConfig &launchConfig,
+                    PatternBenefit benefit = 1)
       : Base(MatmulOp::getOperationName(), context,
              options.setDistributionOptions(matmulDistributionOptions),
              linalg::LinalgMarker(
@@ -254,8 +136,7 @@
                  Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
                                  context)),
              benefit),
-        tileSizes(tileSizes.begin(), tileSizes.end()),
-        workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+        launchConfig(launchConfig) {}
 
   virtual LogicalResult matchAndRewrite(Operation *op,
                                         PatternRewriter &rewriter) const {
@@ -263,20 +144,37 @@
     // erased.
     FuncOp funcOp = op->getParentOfType<FuncOp>();
     if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
-        failed(updateWorkGroupSize(funcOp, workgroupSize)) ||
+        failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())) ||
         (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
          failed(createNumWorkgroupsFromResultShape(
-             rewriter, cast<linalg::LinalgOp>(op), funcOp, tileSizes)))) {
+             rewriter, cast<linalg::LinalgOp>(op), funcOp,
+             launchConfig.getTileSizes(op, 0))))) {
       return failure();
     }
     rewriter.eraseOp(op);
     return success();
   }
 
-  SmallVector<int64_t, 3> tileSizes;
-  SmallVector<int64_t, 3> workgroupSize;
+  const LaunchConfig &launchConfig;
 };
 
+/// Pattern to tile linalg.matmul for subgroups.
+struct TileMatmulSubgroupPattern
+    : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
+  using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+  TileMatmulSubgroupPattern(MLIRContext *context,
+                            linalg::LinalgTilingOptions options,
+                            PatternBenefit benefit = 1)
+      : Base(context, options,
+             linalg::LinalgMarker(
+                 Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
+                                 context),
+                 Identifier::get(getVectorizeMarker(), context)),
+             benefit) {}
+};
+
+/// Distribution options for convolution and pooling operations when targeting
+/// workgroups.
 static linalg::LinalgLoopDistributionOptions convPoolDistributionOptions = {
     [](OpBuilder &builder, Location loc,
        ArrayRef<SubViewOp::Range> parallelLoopRanges) {
@@ -286,14 +184,12 @@
     {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
      linalg::DistributionMethod::Cyclic}};
 
-/// Pattern for tiling convolution and pooling operations. Currently is just a
-/// way to not tile when the operation has padding.
+/// Pattern for tiling convolution and pooling operations.
 template <typename OpTy>
 struct TileConvPoolPattern : public linalg::LinalgTilingPattern<OpTy> {
   using Base = linalg::LinalgTilingPattern<OpTy>;
   TileConvPoolPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
-                      ArrayRef<int64_t> tileSizes,
-                      ArrayRef<int64_t> workgroupSize,
+                      const LaunchConfig &launchConfig,
                       PatternBenefit benefit = 1)
       : Base(context,
              options.setDistributionOptions(convPoolDistributionOptions),
@@ -301,38 +197,58 @@
                  ArrayRef<Identifier>(),
                  Identifier::get(getWorkgroupMarker(), context)),
              benefit),
-        tileSizes(tileSizes.begin(), tileSizes.end()),
-        workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
+        launchConfig(launchConfig) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (hasPadding(cast<OpTy>(op))) return failure();
     FuncOp funcOp = op->getParentOfType<FuncOp>();
     if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
-        failed(updateWorkGroupSize(funcOp, this->workgroupSize)))
+        failed(updateWorkGroupSize(funcOp, launchConfig.getWorkgroupSize())))
       return failure();
     funcOp.removeAttr(getNumWorkgroupsFnAttrName());
     return success();
   }
 
-  SmallVector<int64_t, 3> tileSizes;
-  SmallVector<int64_t, 3> workgroupSize;
+  const LaunchConfig &launchConfig;
 };
 
+/// Populate patterns for first-level tiling.
+static void populateTilingToWorkgroupPatterns(
+    MLIRContext *context, const LaunchConfig &launchConfig,
+    OwningRewritePatternList &patterns) {
+  // Function to compute the first-level tile sizes as constant index values.
+  std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+      getOuterTileSizeFn =
+          [&launchConfig](OpBuilder &builder,
+                          Operation *operation) -> SmallVector<Value, 4> {
+    ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 0);
+    if (tileSizes.empty()) return {};
+    SmallVector<Value, 4> tileSizesVal;
+    tileSizesVal.reserve(tileSizes.size());
+    for (auto val : tileSizes) {
+      tileSizesVal.push_back(
+          builder.create<ConstantIndexOp>(operation->getLoc(), val));
+    }
+    return tileSizesVal;
+  };
+  patterns.insert<TileConvPoolPattern<linalg::ConvOp>,
+                  TileMatmulPattern<linalg::MatmulOp>,
+                  TileMatmulPattern<linalg::BatchMatmulOp>,
+                  TileConvPoolPattern<linalg::PoolingMaxOp>,
+                  TileConvPoolPattern<linalg::PoolingMinOp>,
+                  TileConvPoolPattern<linalg::PoolingSumOp>>(
+      context,
+      linalg::LinalgTilingOptions()
+          .setTileSizeComputationFunction(getOuterTileSizeFn)
+          .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
+      launchConfig);
+}
+
 //===----------------------------------------------------------------------===//
 // Patterns to promote subviews to workgroup memory
 //===----------------------------------------------------------------------===//
 
-/// Function used as callback for copyin/copyout in promotion pattern used to
-/// promote subviews to workgroup memory when the number of threads is known to
-/// be greater than equal to the number of iteration of loops the copy is
-/// lowered to.
-static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
-  auto copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
-  setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
-  return success();
-}
-
 /// Pattern to promote matmul operands to workgroup memory.
 struct PromoteMatmulSubviewsPattern
     : public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
@@ -381,10 +297,114 @@
 };
 }  // namespace
 
+static void populatePromotionPatterns(MLIRContext *context,
+                                      OwningRewritePatternList &patterns) {
+  patterns
+      .insert<PromoteMatmulSubviewsPattern, PromoteConvolutionSubviewsPattern>(
+          context,
+          linalg::LinalgPromotionOptions()
+              .setAllocationDeallocationFns(allocateWorkgroupMemory,
+                                            deallocateWorkgroupMemory)
+              .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns and methods for subgroup tiling.
+//===----------------------------------------------------------------------===//
+
+/// Computes the subgroup ID along each dimension given the number of
+/// subgroups `numSubgroups` along each dimension (x-first, y-second,
+/// z-third).
+static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+    OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) {
+  Type indexType = builder.getIndexType();
+  Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType);
+  SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size());
+
+  // subgroupID
+  //   = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x
+  using edsc::op::operator%;
+  for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) {
+    Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]);
+    Value procId = subgroupId % nprocs;
+    procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs};
+    subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs);
+  }
+  return procInfo;
+}
+
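A standalone sketch of the delinearization above: repeated mod/div peels the x dimension off first, and writing results back-to-front leaves procInfo ordered slowest-varying first:

```cpp
#include <cstdint>
#include <vector>

struct ProcInfo {
  int64_t procId, nprocs;
};

// Peels a flattened ID (x fastest-varying) into per-dimension IDs, writing
// results back-to-front so procInfo[0] is the slowest-varying dimension.
static std::vector<ProcInfo> delinearize(int64_t flatId,
                                         const std::vector<int64_t> &counts) {
  std::vector<ProcInfo> procInfo(counts.size());
  for (size_t i = 0, e = counts.size(); i != e; ++i) {
    procInfo[e - i - 1] = {flatId % counts[i], counts[i]};
    flatId /= counts[i];
  }
  return procInfo;
}

int main() {
  // counts = {nx=4, ny=2}; flatId 6 = y(1) * 4 + x(2).
  auto info = delinearize(6, {4, 2});
  // info[0] is y = 1, info[1] is x = 2.
  return (info[0].procId == 1 && info[1].procId == 2) ? 0 : 1;
}
```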
+/// Patterns for second level tiling to target subgroups.
+static void populateTilingToSubgroupPatterns(
+    MLIRContext *context, const LaunchConfig &launchConfig,
+    OwningRewritePatternList &patterns) {
+  std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>
+      getInnerTileSizeFn =
+          [&launchConfig](OpBuilder &builder,
+                          Operation *operation) -> SmallVector<Value, 4> {
+    ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 1);
+    if (tileSizes.empty()) return {};
+    SmallVector<Value, 4> tileSizesVal;
+    tileSizesVal.reserve(tileSizes.size());
+    for (auto val : tileSizes) {
+      tileSizesVal.push_back(
+          builder.create<ConstantIndexOp>(operation->getLoc(), val));
+    }
+    return tileSizesVal;
+  };
+
+  auto getSubgroupProcInfoFn =
+      [&launchConfig](OpBuilder &builder, Location loc,
+                      ArrayRef<SubViewOp::Range> parallelLoopRanges) {
+        ArrayRef<int64_t> numSubgroups =
+            launchConfig.getNumSubgroups().take_front(
+                parallelLoopRanges.size());
+        return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
+      };
+  linalg::LinalgLoopDistributionOptions subgroupDistributionOptions = {
+      getSubgroupProcInfoFn,
+      {linalg::DistributionMethod::CyclicNumProcsEqNumIters,
+       linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
+  patterns.insert<TileMatmulSubgroupPattern>(
+      context, linalg::LinalgTilingOptions()
+                   .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
+                   .setTileSizeComputationFunction(getInnerTileSizeFn)
+                   .setDistributionOptions(subgroupDistributionOptions));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns for vectorization
+//===----------------------------------------------------------------------===//
+
+static void populateVectorizationPatterns(MLIRContext *context,
+                                          const LaunchConfig &launchConfig,
+                                          OwningRewritePatternList &patterns) {
+  patterns.insert<linalg::LinalgVectorizationPattern<linalg::MatmulOp>>(
+      context,
+      linalg::LinalgMarker(Identifier::get(getVectorizeMarker(), context)));
+}
+
+/// Apply canonicalizations related to tiling to make promotion/vectorization
+/// easier.
+static void applyCanonicalizationPatterns(MLIRContext *context, Operation *op) {
+  OwningRewritePatternList canonicalizationPatterns;
+  canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
+  AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  applyPatternsAndFoldGreedily(op, canonicalizationPatterns);
+}
+
 void LinalgTileAndFusePass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
+  // Override options with command line values.
+  if (!tileSizes.empty())
+    options.tileSizes.assign(tileSizes.begin(), tileSizes.end());
+  if (!workgroupSize.empty())
+    options.workgroupSize.assign(workgroupSize.begin(), workgroupSize.end());
+  if (useWorkgroupMemory) options.useWorkgroupMemory = true;
+  if (useVectorization) options.useVectorization = true;
+
   LLVM_DEBUG(
       llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";);
   for (FuncOp funcOp : module.getOps<FuncOp>()) {
@@ -397,58 +417,62 @@
     }
     Block &block = body.front();
     auto linalgOps = block.getOps<linalg::LinalgOp>();
-    if (linalgOps.empty()) return;
+    if (linalgOps.empty()) continue;
 
-    TileSizeCalculator tileSizeCalculator(funcOp);
-    if (tileSizes.empty()) {
-      // Get the tile sizes to use for the lowering.
-      SmallVector<int64_t, 3> tileSizes;
-      SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(),
-                                              linalgOps.end());
-      if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
-        return signalPassFailure();
-    } else {
-      tileSizeCalculator.setTileSizes(tileSizes);
-      if (!workgroupSize.empty())
-        tileSizeCalculator.setWorkgroupSize(workgroupSize);
+    LaunchConfig launchConfig;
+    SmallVector<linalg::LinalgOp, 4> linalgOpsVec(linalgOps.begin(),
+                                                  linalgOps.end());
+    if (failed(launchConfig.init(options, linalgOpsVec))) {
+      funcOp.emitError("unable to find launch configuration");
+      return signalPassFailure();
     }
 
     LLVM_DEBUG({
       llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
-      interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
-      llvm::dbgs() << "]\ntile sizes: [";
-      interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
+      interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
       llvm::dbgs() << "]\n";
+      for (auto op : linalgOps) {
+        llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
+        TileSizesListType const &tileSizes = launchConfig.getTileSizes(op);
+        llvm::dbgs() << "{";
+        std::string sep = "";
+        for (auto &level : enumerate(tileSizes)) {
+          llvm::dbgs() << sep << level.index() << " : [";
+          sep = ", ";
+          interleaveComma(level.value(), llvm::dbgs());
+          llvm::dbgs() << "]";
+        }
+        llvm::dbgs() << "}\n";
+      }
     });
 
-    OwningRewritePatternList tilingPatterns;
-    tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
-                          TileMatmulPattern<linalg::MatmulOp>,
-                          TileMatmulPattern<linalg::BatchMatmulOp>,
-                          TileConvPoolPattern<linalg::PoolingMaxOp>,
-                          TileConvPoolPattern<linalg::PoolingMinOp>,
-                          TileConvPoolPattern<linalg::PoolingSumOp>>(
-        context,
-        linalg::LinalgTilingOptions()
-            .setTileSizes(tileSizeCalculator.getTileSizes())
-            .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
-        tileSizeCalculator.getTileSizes(),
-        tileSizeCalculator.getWorkgroupSize());
-    applyPatternsAndFoldGreedily(funcOp, tilingPatterns);
+    OwningRewritePatternList firstLevelTilingPatterns;
+    populateTilingToWorkgroupPatterns(context, launchConfig,
+                                      firstLevelTilingPatterns);
+    applyPatternsAndFoldGreedily(funcOp, firstLevelTilingPatterns);
+    applyCanonicalizationPatterns(context, funcOp);
 
-    if (useWorkgroupMemory) {
+    if (options.useWorkgroupMemory) {
       // The promotion patterns are kept separate from the tiling patterns to
       // make sure that the allocated scratchspace memory has constant sizes,
       // which requires some folding to trigger.
       OwningRewritePatternList promotionPatterns;
-      promotionPatterns.insert<PromoteMatmulSubviewsPattern,
-                               PromoteConvolutionSubviewsPattern>(
-          context,
-          linalg::LinalgPromotionOptions()
-              .setAllocationDeallocationFns(allocateWorkgroupMemory,
-                                            deallocateWorkgroupMemory)
-              .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+      populatePromotionPatterns(context, promotionPatterns);
       applyPatternsAndFoldGreedily(funcOp, promotionPatterns);
+      applyCanonicalizationPatterns(context, funcOp);
+    }
+
+    if (options.useVectorization) {
+      OwningRewritePatternList secondLevelTilingPatterns;
+      populateTilingToSubgroupPatterns(context, launchConfig,
+                                       secondLevelTilingPatterns);
+      applyPatternsAndFoldGreedily(funcOp, secondLevelTilingPatterns);
+      applyCanonicalizationPatterns(context, funcOp);
+
+      OwningRewritePatternList vectorizationPatterns;
+      populateVectorizationPatterns(context, launchConfig,
+                                    vectorizationPatterns);
+      applyPatternsAndFoldGreedily(funcOp, vectorizationPatterns);
     }
   }
 }
@@ -458,10 +482,8 @@
 //===----------------------------------------------------------------------===//
 
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
-    ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> tileSizes,
-    bool useWorkgroupMemory) {
-  return std::make_unique<LinalgTileAndFusePass>(workgroupSize, tileSizes,
-                                                 useWorkgroupMemory);
+    const SPIRVCodegenOptions &options) {
+  return std::make_unique<LinalgTileAndFusePass>(options);
 }
 
 static PassRegistration<LinalgTileAndFusePass> pass(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index cc90ccc..04044a7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -22,7 +22,9 @@
 
 #include "iree/compiler/Conversion/HLOToHLO/Passes.h"
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "llvm/Support/CommandLine.h"
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
@@ -50,22 +52,21 @@
 namespace iree_compiler {
 
 namespace {
-/// Command line options for use with SPIR-V code-generation pass pipeline.
-struct SPIRVCodegenClOpts : public PassPipelineOptions<SPIRVCodegenClOpts> {
-  ListOption<int64_t> workgroupSize{
-      *this, "workgroup-size",
+
+/// Linalg to SPIR-V pass pipeline options. In theory this is a superset of
+/// all options of all passes in the pipeline. Options are added here on an
+/// as-needed basis, and they should be needed only for testing.
+struct LinalgToSPIRVPassPipelineOptions
+    : public PassPipelineOptions<LinalgToSPIRVPassPipelineOptions> {
+  Option<bool> useVectorization{
+      *this, "use-vectorization",
       llvm::cl::desc(
-          "Number of workgroups to dispatch for the SPIR-V module; at most "
-          "three integers standarding for the x, y, and z dimension; "
-          "additional arguments will be ignored (used only for testing)"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  ListOption<int64_t> tileSizes{
-      *this, "tile-sizes", llvm::cl::desc("Set tile sizes to use"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-  Option<bool> useWorkgroupMemory{
-      *this, "use-workgroup-memory",
-      llvm::cl::desc(
-          "Enable use of workgroup memory in SPIR-V code generation pipeline"),
+          "Enable use of vectorization in SPIR-V code generation pipeline"),
+      llvm::cl::init(false)};
+  Option<bool> useVectorPass{
+      *this, "use-vector-pass",
+      llvm::cl::desc("Enable use of Linalg vectorization in SPIR-V code "
+                     "generation pipeline"),
       llvm::cl::init(false)};
 };
 
@@ -101,8 +102,10 @@
   //   with the second tile and fuse pass.
   //===--------------------------------------------------------------------===//
   pm.addPass(createSplitDispatchFunctionPass());
-  pm.addPass(createLinalgTileAndFusePass(
-      options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
+  pm.addPass(createLinalgTileAndFusePass(options));
+  if (options.useVectorPass) {
+    pm.addPass(createLoadStoreVectorizationPass());
+  }
   pm.addPass(createCanonicalizerPass());
 
   //===--------------------------------------------------------------------===//
@@ -252,30 +255,33 @@
   addLinalgToSPIRVPasses(pm, options);
 }
 
-static SPIRVCodegenOptions getSPIRVCodegenOptions(
-    const SPIRVCodegenClOpts &clOpts) {
-  SPIRVCodegenOptions options;
-  options.workgroupSize.assign(clOpts.workgroupSize.begin(),
-                               clOpts.workgroupSize.end());
-  options.tileSizes.assign(clOpts.tileSizes.begin(), clOpts.tileSizes.end());
-  options.useWorkgroupMemory = clOpts.useWorkgroupMemory;
-  return options;
+static SPIRVCodegenOptions initializeCodegenOptions(
+    const LinalgToSPIRVPassPipelineOptions &options) {
+  SPIRVCodegenOptions codegenOptions;
+  codegenOptions.useVectorization = options.useVectorization;
+  return codegenOptions;
 }
 
-static PassPipelineRegistration<SPIRVCodegenClOpts> linalgToSPIRVPipeline(
-    "iree-codegen-linalg-to-spirv-pipeline",
-    "Runs the progressive lowering pipeline from Linalg to SPIR-V",
-    [](OpPassManager &passManager, const SPIRVCodegenClOpts &options) {
-      addLinalgToSPIRVPasses(passManager, getSPIRVCodegenOptions(options));
-    });
+static PassPipelineRegistration<LinalgToSPIRVPassPipelineOptions>
+    linalgToSPIRVPipeline(
+        "iree-codegen-linalg-to-spirv-pipeline",
+        "Runs the progressive lowering pipeline from Linalg to SPIR-V",
+        [](OpPassManager &passManager,
+           const LinalgToSPIRVPassPipelineOptions &options) {
+          addLinalgToSPIRVPasses(passManager,
+                                 initializeCodegenOptions(options));
+        });
 
-static PassPipelineRegistration<SPIRVCodegenClOpts> hloToLinalgSPIRVPipeline(
-    "iree-codegen-hlo-to-spirv-pipeline",
-    "Runs the progressive lowering pipeline from XLA HLO to Linalg to SPIR-V",
-    [](OpPassManager &passManager, const SPIRVCodegenClOpts &options) {
-      buildSPIRVTransformPassPipeline(passManager,
-                                      getSPIRVCodegenOptions(options));
-    });
+static PassPipelineRegistration<LinalgToSPIRVPassPipelineOptions>
+    hloToLinalgSPIRVPipeline(
+        "iree-codegen-hlo-to-spirv-pipeline",
+        "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
+        "SPIR-V",
+        [](OpPassManager &passManager,
+           const LinalgToSPIRVPassPipelineOptions &options) {
+          buildSPIRVTransformPassPipeline(passManager,
+                                          initializeCodegenOptions(options));
+        });
 
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 6e07cb0..80805e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassOptions.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -27,6 +28,8 @@
   SmallVector<int64_t, 3> workgroupSize = {};
   SmallVector<int64_t, 3> tileSizes = {};
   bool useWorkgroupMemory = false;
+  bool useVectorization = false;
+  bool useVectorPass = false;
 };
 
 /// Pass to initialize the function that computes the number of workgroups for
@@ -40,8 +43,7 @@
 /// it exists) and along "z" for the next loop (if it exists). The workgroup
 /// size is expected to be of size at-most 3.
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
-    ArrayRef<int64_t> workGroupSize = {}, ArrayRef<int64_t> tileSizes = {},
-    bool useWorkgroupMem = false);
+    const SPIRVCodegenOptions &options);
 
 /// Pass to add the synchronizations and attributes needed to lower from PLoops
 /// to GPU dialect.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
index 47cc37d..2c14351 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
@@ -107,6 +107,11 @@
   Value blockId = builder.create<gpu::BlockIdOp>(loc, indexType, attr);
   Value blockDim = builder.create<gpu::BlockDimOp>(loc, indexType, attr);
   Value threadId = builder.create<gpu::ThreadIdOp>(loc, indexType, attr);
+  // TODO(ravishankarm): Using affine_maps here would be beneficial, and we can
+  // do this because the blockDim is constant. But it would introduce an
+  // ordering issue, because it assumes that the workgroup size has already
+  // been set. If using affine_map does help, make sure that the workgroup size
+  // is set beforehand.
   return {builder.create<AddIOp>(
               loc, builder.create<MulIOp>(loc, blockId, blockDim), threadId),
           builder.create<MulIOp>(loc, blockDim, gridDim)};
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index 934728d..682583a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -62,7 +62,6 @@
 SmallVector<linalg::ProcInfo, 2> getGPUProcessorIdsAndCounts(OpBuilder &builder,
                                                              Location loc,
                                                              unsigned numDims);
-
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index e06eae5..f6604bf 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -64,7 +64,6 @@
 //   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
 //   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
 //   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-//   CHECK-DAG:   %[[C0:.+]] = constant 0
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
@@ -78,9 +77,9 @@
 //       CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) = (%[[BIDZ]], %[[LBY]], %[[LBX]])
 //  CHECK-SAME:     step (%[[NBLOCKSZ]], %[[STEPY]], %[[STEPX]])
 //       CHECK:     %[[VIEW1:.+]] = subview %[[ARG1]]
-//  CHECK-SAME:       [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+//  CHECK-SAME:       [%[[IV0]], %[[IV1]], %[[IV2]], 0]
 //       CHECK:     %[[VIEW2:.+]] = subview %[[RET0]]
-//  CHECK-SAME:       [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+//  CHECK-SAME:       [%[[IV0]], %[[IV1]], %[[IV2]], 0]
 //       CHECK:     linalg.conv
 //  CHECK-SAME:       %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
 //  CHECK-SAME:       "workgroup"
@@ -116,37 +115,38 @@
 }
 
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 16)>
 //       CHECK: func @matmul()
-//  CHECK-SAME:   local_size = dense<[8, 8, 1]>
+//  CHECK-SAME:   local_size = dense<[16, 8, 1]>
 //  CHECK-SAME:   vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:.[a-zA-Z0-9_]+]]
 //   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
 //   CHECK-DAG:   %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
 //   CHECK-DAG:   %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
-//   CHECK-DAG:   %[[C0:.+]] = constant 0
-//   CHECK-DAG:   %[[C4:.+]] = constant 4
 //   CHECK-DAG:   %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
 //   CHECK-DAG:   %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
 //   CHECK-NOT:   scf.parallel
-//       CHECK:   scf.for %[[IV:.+]] = %[[C0]] to %{{.+}} step %[[C4]]
-//       CHECK:     %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:     %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], %[[IV]]
-//       CHECK:     %[[LBX:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-//       CHECK:     %[[VIEW1:.+]] = subview %[[ARG1]][%[[IV]], %[[LBX]]]
-//       CHECK:     %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-//       CHECK:     %[[LBX_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-//       CHECK:     %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
-//       CHECK:     linalg.matmul
-//  CHECK-SAME:       "workgroup_numprocs_ge_numiters"
-//  CHECK-SAME:       %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
+//   CHECK-NOT:   scf.for
+//       CHECK:   %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[VIEW0:.+]] = subview %[[ARG0]][%[[LBY]], 0]
+//       CHECK:   %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+//       CHECK:   %[[VIEW1:.+]] = subview %[[ARG1]][0, %[[LBX]]]
+//       CHECK:   %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//       CHECK:   %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
+//       CHECK:   %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+//       CHECK:   linalg.matmul
+//  CHECK-SAME:     "workgroup_numprocs_ge_numiters"
+//  CHECK-SAME:     %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
 //       CHECK: func @[[NUM_WORKGROUPS_FN]]
 //   CHECK-DAG:   %[[C8:.+]] = constant 8 : index
 //   CHECK-DAG:   %[[C7:.+]] = constant 7 : index
 //   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//   CHECK-DAG:   %[[C15:.+]] = constant 15 : index
 //       CHECK:   %[[DIM0:.+]] = dim %{{.*}}, %[[C0]]
 //       CHECK:   %[[DIM1:.+]] = dim %{{.*}}, %[[C1]]
-//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C7]]
-//       CHECK:   %[[T1:.+]] = divi_signed %[[T0]], %[[C8]]
+//       CHECK:   %[[T0:.+]] = addi %[[DIM1]], %[[C15]]
+//       CHECK:   %[[T1:.+]] = divi_signed %[[T0]], %[[C16]]
 //       CHECK:   %[[T2:.+]] = addi %[[DIM0]], %[[C7]]
 //       CHECK:   %[[T3:.+]] = divi_signed %[[T2]], %[[C8]]
 //       CHECK:   return %[[T1]], %[[T3]], %[[C1]]
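
The num-workgroups function checked above is plain ceiling division: one
workgroup per tile, with the tile size along x now 16 and along y still 8. A
standalone C++ sketch (illustrative only; the helper name is hypothetical) of
the arithmetic that each addi/divi_signed pair computes:

#include <cstdint>

// ceilDiv(dim, tile) == (dim + tile - 1) / tile for positive operands.
constexpr int64_t ceilDiv(int64_t dim, int64_t tile) {
  return (dim + tile - 1) / tile;
}

static_assert(ceilDiv(100, 16) == 7, "x count: addi %c15, divi_signed %c16");
static_assert(ceilDiv(100, 8) == 13, "y count: addi %c7, divi_signed %c8");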
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
new file mode 100644
index 0000000..8421f81
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_subgroup.mlir
@@ -0,0 +1,77 @@
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse=use-vectorization %s | IreeFileCheck %s
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+      [Shader, CooperativeMatrixNV],
+      [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+      {max_compute_workgroup_invocations = 512 : i32,
+       max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @matmul_static_shape()
+    attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+    %arg0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<128x64xf16>
+    %arg1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<64x256xf16>
+    %ret0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<128x256xf16>
+    linalg.matmul %arg0, %arg1, %ret0 :
+      (memref<128x64xf16>, memref<64x256xf16>, memref<128x256xf16>)
+    return
+  }
+  func @matmul_static_shape__num_workgroups__
+    (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
+     !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 32)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>
+//      CHECK: func @matmul_static_shape
+//  CHECK-DAG:  %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+//  CHECK-DAG:  %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+//  CHECK-DAG:  %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
+//  CHECK-DAG:  %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:  %[[CST:.+]] = constant 0.0
+//  CHECK-DAG:  %[[C4:.+]] = constant 4 : index
+//      CHECK:  %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+//      CHECK:  %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+//      CHECK:  %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:  %[[SUBVIEW_LHS:.+]] = subview %[[ARG0]]
+// CHECK-SAME:    [%[[BOFFSET_Y]], 0] [32, 64]
+//      CHECK:  %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:  %[[SUBVIEW_RHS:.+]] = subview %[[ARG1]]
+// CHECK-SAME:    [0, %[[BOFFSET_X]]] [64, 32]
+//      CHECK:  %[[BOFFSET_Y_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+//      CHECK:  %[[BOFFSET_X_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+//      CHECK:  %[[SUBVIEW_RESULT:.+]] = subview %[[RET0]]
+// CHECK-SAME:    [%[[BOFFSET_Y_2]], %[[BOFFSET_X_2]]] [32, 32]
+//      CHECK:  %[[SGID:.+]] = gpu.subgroup_id
+//      CHECK:  %[[SGID_Y:.+]] = divi_signed %[[SGID]], %[[C4]]
+//      CHECK:  scf.for %[[IV2:.+]] =
+//      CHECK:    %[[SGOFFSET_Y:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
+//      CHECK:    %[[SUBVIEW2_LHS:.+]] = subview %[[SUBVIEW_LHS]]
+// CHECK-SAME:      [%[[SGOFFSET_Y]], %[[IV2]]] [8, 16]
+//      CHECK:    %[[SGOFFSET_X:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
+//      CHECK:    %[[SUBVIEW2_RHS:.+]] = subview %[[SUBVIEW_RHS]]
+// CHECK-SAME:      [%[[IV2]], %[[SGOFFSET_X]]] [16, 8]
+//      CHECK:    %[[SGOFFSET_Y_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID_Y]]]
+//      CHECK:    %[[SGOFFSET_X_2:.+]] = affine.apply #[[MAP3]]()[%[[SGID]]]
+//      CHECK:    %[[SUBVIEW2_RESULT:.+]] = subview %[[SUBVIEW_RESULT]]
+// CHECK-SAME:      [%[[SGOFFSET_Y_2]], %[[SGOFFSET_X_2]]] [8, 8]
+//      CHECK:    %[[VTR_LHS:.+]] = vector.transfer_read %[[SUBVIEW2_LHS]]
+// CHECK-SAME:      [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+//      CHECK:    %[[VTR_RHS:.+]] = vector.transfer_read %[[SUBVIEW2_RHS]]
+// CHECK-SAME:      [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+//      CHECK:    %[[VTR_RESULT:.+]] = vector.transfer_read %[[SUBVIEW2_RESULT]]
+// CHECK-SAME:      [%[[C0]], %[[C0]]], %[[CST]] {masked = [false, false]}
+//      CHECK:    %[[VECTOR_CONTRACT:.+]] = vector.contract
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME:      vector<8x16xf16>, vector<16x8xf16> into vector<8x8xf16>
+//      CHECK:    vector.transfer_write %[[VECTOR_CONTRACT]], %[[SUBVIEW2_RESULT]]
+// CHECK-SAME:      masked = [false, false]
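
For readers decoding the affine maps in this new test: #MAP3 is the folded form
of (s0 mod 4) * 8, so consecutive subgroups own adjacent 8x8 tiles of the 32x32
workgroup tile. A plain C++ sketch (illustrative, not compiler code; it assumes
eight subgroups per workgroup, which this diff does not state):

#include <cassert>

int main() {
  for (int sgid = 0; sgid < 8; ++sgid) {
    int sgidY = sgid / 4;                     // divi_signed %sgid, %c4
    int offX = sgid * 8 - (sgid / 4) * 32;    // #MAP3 applied to sgid
    int offY = sgidY * 8 - (sgidY / 4) * 32;  // #MAP3 applied to sgid_y
    assert(offX == (sgid % 4) * 8);           // columns 0, 8, 16, 24
    assert(offY == sgidY * 8);                // rows 0 and 8
  }
  return 0;
}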
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
new file mode 100644
index 0000000..ddd2530
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-codegen-linalg-to-spirv-pipeline=use-vectorization %s | IreeFileCheck %s
+
+module attributes {
+  spv.target_env =
+    #spv.target_env<#spv.vce<v1.3,
+      [Float16, Shader, CooperativeMatrixNV],
+      [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>,
+      {max_compute_workgroup_invocations = 512 : i32,
+       max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+  func @matmul_static_shape()
+    attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+    %0 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg0, operand_result_num = 0} : memref<128x64xf16>
+    %1 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@arg1, operand_result_num = 1} : memref<64x256xf16>
+    %2 = iree.placeholder for "interface buffer"
+      {binding = @legacy_io::@ret0, operand_result_num = 2} : memref<128x256xf16>
+    linalg.matmul %0, %1, %2 :
+      (memref<128x64xf16>, memref<64x256xf16>, memref<128x256xf16>)
+    return
+  }
+  func @matmul_static_shape__num_workgroups__
+    (!shapex.ranked_shape<[128, 64]>, !shapex.ranked_shape<[64, 256]>,
+     !shapex.ranked_shape<[128, 256]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+  }
+}
+// CHECK-LABEL: spv.func @matmul_static_shape
+//       CHECK:   spv.CooperativeMatrixLoadNV
+//       CHECK:   spv.CooperativeMatrixLoadNV
+//       CHECK:   spv.CooperativeMatrixLoadNV
+//       CHECK:   spv.CooperativeMatrixMulAddNV
+//       CHECK:   spv.CooperativeMatrixStoreNV
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 9ff21ba..5b1d440 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,8 +5,7 @@
     #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
                     {max_compute_workgroup_invocations = 128 : i32,
                      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @matmul_tile()
-    attributes {signature = (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)} {
+  func @matmul_tile() attributes {vkspv.num_workgroups_fn = @matmul_tile__num_workgroups__} {
     %0 = iree.placeholder for "interace buffer"
       {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
     %1 = iree.placeholder for "interace buffer"
@@ -16,6 +15,10 @@
     linalg.matmul %0, %1, %2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
     return
   }
+  func @matmul_tile__num_workgroups__
+    (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+     !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+    attributes {sym_visibility = "private"}
   hal.interface @legacy_io attributes {sym_visibility = "private"} {
     hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
     hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -30,9 +33,9 @@
 //       CHECK:     %[[ARG0SV:.+]] = subview %[[ARG0]]
 //       CHECK:     %[[ARG1SV:.+]] = subview %[[ARG1]]
 //       CHECK:     %[[RET0SV:.+]] = subview %[[RET0]]
-//       CHECK:     %[[ALLOC1:.+]] = alloc() : memref<8x4xf32, 3>
+//       CHECK:     %[[ALLOC1:.+]] = alloc() : memref<8x32xf32, 3>
 //       CHECK:     %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
-//       CHECK:     %[[ALLOC2:.+]] = alloc() : memref<4x8xf32, 3>
+//       CHECK:     %[[ALLOC2:.+]] = alloc() : memref<32x16xf32, 3>
 //       CHECK:     %[[SUBVIEW2:.+]] = subview %[[ALLOC2]]
 //       CHECK:     linalg.copy(%[[ARG0SV]], %[[SUBVIEW1]])
 //  CHECK-SAME:       "copy_to_workgroup_memory"
@@ -41,8 +44,8 @@
 //       CHECK:     linalg.matmul
 //  CHECK-SAME:       "workgroup_memory_numprocs_ge_numiters"
 //  CHECK-SAME:       %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
-//   CHECK-DAG:     dealloc %[[ALLOC1]] : memref<8x4xf32, 3>
-//   CHECK-DAG:     dealloc %[[ALLOC2]] : memref<4x8xf32, 3>
+//   CHECK-DAG:     dealloc %[[ALLOC1]] : memref<8x32xf32, 3>
+//   CHECK-DAG:     dealloc %[[ALLOC2]] : memref<32x16xf32, 3>
 
 // -----
 
@@ -51,8 +54,7 @@
     #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
                     {max_compute_workgroup_invocations = 128 : i32,
                      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
-  func @conv_no_padding_tile()
-    attributes {signature = (tensor<3x4x3x2xf32>, tensor<?x?x?x3xf32>) -> (tensor<?x?x?x2xf32>)}  {
+  func @conv_no_padding_tile() {
     %0 = iree.placeholder for "interace buffer"
       {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x3x2xf32>
     %1 = iree.placeholder for "interace buffer"
@@ -63,6 +65,11 @@
       : memref<3x4x3x2xf32>, memref<?x?x?x3xf32>, memref<?x?x?x2xf32>
     return
   }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
 }
 //       CHECK: func @conv_no_padding_tile()
 //   CHECK-DAG:   %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0
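
On the alloc-shape updates in the matmul case above, a minimal sketch assuming
workgroup tile sizes (M, N, K) = (8, 16, 32), inferred from the CHECKs rather
than stated by this diff: the promoted LHS copy is M x K and the RHS copy is
K x N.

// Hypothetical constants for illustration; inferred, not from the diff.
constexpr int kTileM = 8, kTileN = 16, kTileK = 32;
static_assert(kTileM == 8 && kTileK == 32, "ALLOC1: memref<8x32xf32, 3>");
static_assert(kTileK == 32 && kTileN == 16, "ALLOC2: memref<32x16xf32, 3>");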
diff --git a/iree/compiler/Conversion/LinalgToVector/BUILD b/iree/compiler/Conversion/LinalgToVector/BUILD
new file mode 100644
index 0000000..95676fb
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/BUILD
@@ -0,0 +1,44 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "LinalgToVector",
+    srcs = [
+        "LoadStoreVectorization.cpp",
+    ],
+    hdrs = [
+        "Passes.h",
+    ],
+    deps = [
+        "//iree/compiler/Conversion/CodegenUtils",
+        "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Dialect/Shape/IR",
+        "//iree/compiler/Dialect/Shape/Transforms",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:VectorOps",
+    ],
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
new file mode 100644
index 0000000..1c7e227
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+  NAME
+    LinalgToVector
+  HDRS
+    "Passes.h"
+  SRCS
+    "LoadStoreVectorization.cpp"
+  DEPS
+    LLVMSupport
+    MLIRIR
+    MLIRLinalgOps
+    MLIRLinalgTransforms
+    MLIRPass
+    MLIRStandardOps
+    MLIRSupport
+    MLIRTransforms
+    MLIRVector
+    iree::compiler::Conversion::CodegenUtils
+    iree::compiler::Dialect::IREE::IR
+    iree::compiler::Dialect::Shape::IR
+    iree::compiler::Dialect::Shape::Transforms
+  PUBLIC
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
new file mode 100644
index 0000000..65c3da7
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
@@ -0,0 +1,328 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Returns the bitwidth of a scalar or vector type.
+static Optional<unsigned> getBitWidth(Type type) {
+  if (type.isIntOrFloat()) {
+    return type.getIntOrFloatBitWidth();
+  } else if (type.isa<VectorType>()) {
+    auto vecType = type.cast<VectorType>();
+    auto elementType = vecType.getElementType();
+    return elementType.getIntOrFloatBitWidth() * vecType.getNumElements();
+  }
+  return {};
+}
+
+constexpr int kVectorizationSizeInBits = 128;
+constexpr int kVecSize = kVectorizationSizeInBits / (sizeof(float) * 8);
+
+/// Returns a VectorType of `kVectorizationSizeInBits` total bits if `t` is a
+/// 32-bit scalar type.
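+/// For example (illustrative): a 32-bit float yields vector<4xf32>, while
+/// 64-bit and sub-32-bit scalars are rejected.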
+static VectorType getVecType(OpBuilder &builder, Type t) {
+  if (!t.isa<IntegerType, FloatType>()) return {};
+  if (t.getIntOrFloatBitWidth() != 32) return {};
+  Type newElemType = t.isa<IntegerType>() ? builder.getI32Type().cast<Type>()
+                                          : builder.getF32Type().cast<Type>();
+  return VectorType::get(kVecSize, newElemType);
+}
+
+/// Returns the vectorized memref type corresponding to `type`, i.e. a memref
+/// of vectors with the innermost dimension shrunk accordingly.
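+/// For example (illustrative): memref<3x4xf32> has 32-bit elements, so
+/// vecSize = 128 / 32 = 4; the innermost dimension 4 becomes 1, yielding
+/// memref<3x1xvector<4xf32>> (cf. the tests below).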
+static MemRefType getVectorizedMemRefType(OpBuilder &builder, MemRefType type) {
+  Type elemType = type.getElementType();
+  VectorType vecType = getVecType(builder, elemType);
+  if (!vecType) return {};
+  unsigned elemSize = elemType.getIntOrFloatBitWidth();
+  unsigned vecSize = kVectorizationSizeInBits / elemSize;
+  SmallVector<int64_t, 2> newShape(type.getShape().begin(),
+                                   type.getShape().end());
+  if (newShape.empty()) return {};
+  if (newShape.back() % vecSize != 0) return {};
+  newShape.back() = newShape.back() / vecSize;
+  return MemRefType::get(newShape, vecType, {}, type.getMemorySpace());
+}
+
+/// Returns a vectorized `val`, i.e., a value whose result type is a
+/// VectorType.
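+/// Scalar constants are legalized to splat vector constants (e.g., 1.0f
+/// becomes dense<1.000000e+00> : vector<4xf32>); other scalars go through
+/// vector.broadcast.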
+static Value legalizeToVectorType(OpBuilder &builder, Value val) {
+  Type type = val.getType();
+  if (type.isa<VectorType>()) {
+    return val;
+  } else if (type.isIntOrFloat()) {
+    auto vecType = getVecType(builder, type);
+    if (!vecType) return nullptr;
+    // TODO(hanchung): Add a folder on vector::BroadcastOp so we don't need to
+    // create the constant manually.
+    if (auto cst = val.getDefiningOp<ConstantOp>()) {
+      auto cstVecValue = DenseElementsAttr::get(vecType, cst.value());
+      return builder.create<ConstantOp>(val.getLoc(), vecType, cstVecValue)
+          .getResult();
+    }
+    return builder.create<vector::BroadcastOp>(val.getLoc(), vecType, val)
+        .getResult();
+  }
+  return nullptr;
+}
+
+/// Base class to vectorize std ops. If a generic op is vectorized, all the std
+/// ops in the region should be vectorized as well.
+///
+/// This base class handles the check on operands and vectorization for all the
+/// operands.
+///
+/// All derived classes must implement an apply method with the following
+/// signature:
+///
+/// ```c++
+/// LogicalResult apply(SrcOpTy op, ArrayRef<Value> args,
+///                     ConversionPatternRewriter& rewriter) const;
+/// ```
+template <typename DerivedTy, typename SrcOpTy>
+struct VectorizeOpBase : public OpConversionPattern<SrcOpTy> {
+  using OpConversionPattern<SrcOpTy>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      SrcOpTy op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const override {
+    if (llvm::all_of(args, [](Value arg) {
+          return arg.getType().isIntOrIndexOrFloat();
+        })) {
+      return failure();
+    }
+    SmallVector<Value, 4> vecArgs;
+    for (Value arg : args) {
+      Value val = legalizeToVectorType(rewriter, arg);
+      if (!val) return failure();
+      vecArgs.push_back(val);
+    }
+    return static_cast<DerivedTy const *>(this)->apply(op, vecArgs, rewriter);
+  }
+};
+
+template <typename OpTy>
+struct VectorizeElementwiseOp
+    : public VectorizeOpBase<VectorizeElementwiseOp<OpTy>, OpTy> {
+  using VectorizeOpBase<VectorizeElementwiseOp<OpTy>, OpTy>::VectorizeOpBase;
+  LogicalResult apply(OpTy op, ArrayRef<Value> args,
+                      ConversionPatternRewriter &rewriter) const {
+    auto vecType = getVecType(rewriter, op.getResult().getType());
+    if (!vecType) return failure();
+    auto newOp = rewriter.create<OpTy>(op.getLoc(), vecType, args);
+    rewriter.replaceOp(op, newOp.getOperation()->getResults());
+    return success();
+  }
+};
+
+template <typename OpTy>
+struct VectorizeCmpOp : public VectorizeOpBase<VectorizeCmpOp<OpTy>, OpTy> {
+  using VectorizeOpBase<VectorizeCmpOp<OpTy>, OpTy>::VectorizeOpBase;
+  LogicalResult apply(OpTy op, ArrayRef<Value> args,
+                      ConversionPatternRewriter &rewriter) const {
+    auto newOp =
+        rewriter.create<OpTy>(op.getLoc(), op.predicate(), args[0], args[1]);
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+struct VectorizeSelectOp
+    : public VectorizeOpBase<VectorizeSelectOp, mlir::SelectOp> {
+  using VectorizeOpBase<VectorizeSelectOp, mlir::SelectOp>::VectorizeOpBase;
+  LogicalResult apply(mlir::SelectOp op, ArrayRef<Value> args,
+                      ConversionPatternRewriter &rewriter) const {
+    auto newOp =
+        rewriter.create<SelectOp>(op.getLoc(), args[0], args[1], args[2]);
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+struct VectorizeGenericOp : public OpConversionPattern<linalg::GenericOp> {
+  using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      linalg::GenericOp genericOp, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const override {
+    if (llvm::any_of(genericOp.iterator_types(), [](Attribute attr) {
+          return attr.cast<StringAttr>().getValue() !=
+                 getParallelIteratorTypeName();
+        })) {
+      return failure();
+    }
+
+    // Do not vectorize if one of the operands is 0-D or is not iterated over
+    // contiguous memory.
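+    // For example (illustrative), affine_map<(d0, d1) -> (d1, d0)> is
+    // rejected: its innermost result, d0, is not the last iteration dimension
+    // d1 (see the @not_contiguous test).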
+    for (auto map : genericOp.getIndexingMaps()) {
+      if (map.getNumResults() == 0) return failure();
+      AffineDimExpr innerMostExpr =
+          map.getResults().back().dyn_cast<AffineDimExpr>();
+      if (!innerMostExpr ||
+          innerMostExpr.getPosition() != map.getNumDims() - 1) {
+        return failure();
+      }
+    }
+
+    SmallVector<IREE::PlaceholderOp, 4> operands;
+    SmallVector<MemRefType, 4> vecMemRefs;
+    for (auto operand : args) {
+      auto op = operand.getDefiningOp<IREE::PlaceholderOp>();
+      if (!op) return failure();
+      if (!op.getOperation()->hasOneUse()) return failure();
+      auto memrefType = op.getResult().getType().dyn_cast<MemRefType>();
+      if (!memrefType) return failure();
+      auto vecMemRef = getVectorizedMemRefType(rewriter, memrefType);
+      if (!vecMemRef) return failure();
+      operands.push_back(op);
+      vecMemRefs.push_back(vecMemRef);
+    }
+
+    SmallVector<Value, 4> newArgs;
+    for (auto it : llvm::zip(operands, vecMemRefs)) {
+      IREE::PlaceholderOp placeholder = std::get<0>(it);
+      MemRefType vecMemRef = std::get<1>(it);
+      auto arg = rewriter.create<IREE::PlaceholderOp>(placeholder.getLoc(),
+                                                      vecMemRef, ValueRange{},
+                                                      placeholder.getAttrs());
+      rewriter.replaceOp(placeholder, arg.getResult());
+      newArgs.push_back(arg.getResult());
+    }
+
+    auto newOp = rewriter.create<linalg::GenericOp>(
+        genericOp.getLoc(), genericOp.getResultTypes(), newArgs,
+        rewriter.getI64IntegerAttr(genericOp.getNumInputs()),
+        rewriter.getI64IntegerAttr(genericOp.getNumOutputs()),
+        genericOp.indexing_mapsAttr(), genericOp.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr,
+        /*symbol_source=*/nullptr);
+
+    Region &newRegion = newOp.region();
+    rewriter.inlineRegionBefore(genericOp.getRegion(), newRegion,
+                                newRegion.end());
+    Block &newBlock = newOp.region().front();
+    TypeConverter::SignatureConversion signatureConverter(
+        newBlock.getNumArguments());
+    for (auto arg : llvm::enumerate(vecMemRefs)) {
+      signatureConverter.addInputs(arg.index(), arg.value().getElementType());
+    }
+    rewriter.applySignatureConversion(&newOp.region(), signatureConverter);
+    rewriter.replaceOp(genericOp, newOp.getResults());
+    return success();
+  }
+};
+
+struct LoadStoreVectorizationPass
+    : public PassWrapper<LoadStoreVectorizationPass, OperationPass<FuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    // clang-format off
+    patterns.insert<
+        VectorizeGenericOp,
+        VectorizeCmpOp<CmpFOp>,
+        VectorizeCmpOp<CmpIOp>,
+        VectorizeSelectOp,
+        VectorizeElementwiseOp<AbsFOp>,
+        VectorizeElementwiseOp<AndOp>,
+        VectorizeElementwiseOp<OrOp>,
+        VectorizeElementwiseOp<XOrOp>,
+        VectorizeElementwiseOp<AddFOp>,
+        VectorizeElementwiseOp<AddIOp>,
+        VectorizeElementwiseOp<CeilFOp>,
+        VectorizeElementwiseOp<CosOp>,
+        VectorizeElementwiseOp<DivFOp>,
+        VectorizeElementwiseOp<ExpOp>,
+        VectorizeElementwiseOp<FPExtOp>,
+        VectorizeElementwiseOp<FPToSIOp>,
+        VectorizeElementwiseOp<FPTruncOp>,
+        VectorizeElementwiseOp<FloorFOp>,
+        VectorizeElementwiseOp<LogOp>,
+        VectorizeElementwiseOp<MulFOp>,
+        VectorizeElementwiseOp<MulIOp>,
+        VectorizeElementwiseOp<NegFOp>,
+        VectorizeElementwiseOp<RemFOp>,
+        VectorizeElementwiseOp<RsqrtOp>,
+        VectorizeElementwiseOp<SIToFPOp>,
+        VectorizeElementwiseOp<ShiftLeftOp>,
+        VectorizeElementwiseOp<SignExtendIOp>,
+        VectorizeElementwiseOp<SignedDivIOp>,
+        VectorizeElementwiseOp<SignedShiftRightOp>,
+        VectorizeElementwiseOp<SinOp>,
+        VectorizeElementwiseOp<SqrtOp>,
+        VectorizeElementwiseOp<SubFOp>,
+        VectorizeElementwiseOp<SubIOp>,
+        VectorizeElementwiseOp<TanhOp>,
+        VectorizeElementwiseOp<TruncateIOp>,
+        VectorizeElementwiseOp<UnsignedDivIOp>,
+        VectorizeElementwiseOp<UnsignedRemIOp>,
+        VectorizeElementwiseOp<UnsignedShiftRightOp>>(context);
+    // clang-format on
+
+    ConversionTarget target(*context);
+    // Mark the vector dialect and the placeholder op legal.
+    target.addLegalDialect<vector::VectorDialect>();
+    target.addLegalOp<IREE::PlaceholderOp>();
+
+    // If a generic op is vectorized, it is legal.
+    target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
+      if (!op.hasBufferSemantics()) return false;
+      for (auto arg : op.getOperands()) {
+        if (arg.getType()
+                .cast<MemRefType>()
+                .getElementType()
+                .isSignlessIntOrFloat())
+          return false;
+      }
+      return true;
+    });
+
+    // Mark all standard ops legal if they are operating on vector types.
+    target.addDynamicallyLegalDialect<mlir::StandardOpsDialect>(
+        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+            [](Operation *op) {
+              auto isVectorType = [](Type t) { return t.isa<VectorType>(); };
+              return llvm::any_of(op->getOperandTypes(), isVectorType) ||
+                     llvm::any_of(op->getResultTypes(), isVectorType);
+            }));
+    if (failed(applyPartialConversion(getOperation(), target, patterns)))
+      return signalPassFailure();
+  }
+};
+}  // namespace
+
+std::unique_ptr<Pass> createLoadStoreVectorizationPass() {
+  return std::make_unique<LoadStoreVectorizationPass>();
+}
+
+static PassRegistration<LoadStoreVectorizationPass> pass(
+    "iree-codegen-vectorize-linalg-ops", "Vectorize Linalg operations");
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToVector/Passes.h b/iree/compiler/Conversion/LinalgToVector/Passes.h
new file mode 100644
index 0000000..98f07bd
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/Passes.h
@@ -0,0 +1,29 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOVECTOR_PASSES_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOVECTOR_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Creates a pass to vectorize Linalg operations.
+std::unique_ptr<Pass> createLoadStoreVectorizationPass();
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CONVERSION_LINALGTOVECTOR_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToVector/test/BUILD b/iree/compiler/Conversion/LinalgToVector/test/BUILD
new file mode 100644
index 0000000..e124cf5
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for the Linalg-to-vector conversion.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/tools:IreeFileCheck",
+        "//iree/tools:iree-opt",
+    ],
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
new file mode 100644
index 0000000..fcc538b
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "${_GLOB_X_MLIR}"
+  DATA
+    iree::tools::IreeFileCheck
+    iree::tools::iree-opt
+)
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
new file mode 100644
index 0000000..8b72941
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
@@ -0,0 +1,115 @@
+// RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-ops -canonicalize -cse %s | IreeFileCheck %s
+
+func @broadcast_add() {
+  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4xf32>
+  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x4xf32>
+  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<3x4xf32>
+  linalg.generic {args_in = 2 : i64,
+                  args_out = 1 : i64,
+                  indexing_maps = [affine_map<(d0, d1) -> (d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d0, d1)>],
+                  iterator_types = ["parallel", "parallel"]
+  } %0, %1, %2 {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):  // no predecessors
+    %3 = addf %arg0, %arg1 : f32
+    linalg.yield %3 : f32
+  }: memref<4xf32>, memref<3x4xf32>, memref<3x4xf32>
+  return
+}
+// CHECK-LABEL: func @broadcast_add
+//   CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1xvector<4xf32>>
+//   CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x1xvector<4xf32>>
+//   CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<3x1xvector<4xf32>>
+//       CHECK: linalg.generic
+//  CHECK-SAME:   %[[BUF0]], %[[BUF1]], %[[BUF2]]
+//       CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ARG2:.+]]: vector<4xf32>)
+//       CHECK:   %[[RES:.+]] = addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+//       CHECK:   linalg.yield %[[RES]] : vector<4xf32>
+
+// -----
+
+func @log_plus_one() {
+  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xf32>
+  %c0 = constant 0 : index
+  %cst = constant 1.000000e+00 : f32
+  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4xf32>
+  linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} %1, %0 {
+  ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
+    %2 = addf %arg0, %cst : f32
+    %3 = log %2 : f32
+    linalg.yield %3 : f32
+  }: memref<4xf32>, memref<4xf32>
+  return
+}
+// CHECK-LABEL: func @log_plus_one
+//   CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1xvector<4xf32>>
+//   CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1xvector<4xf32>>
+//   CHECK-DAG: %[[CST:.+]] = constant dense<1.000000e+00> : vector<4xf32>
+//       CHECK: linalg.generic
+//  CHECK-SAME:   %[[BUF0]], %[[BUF1]]
+//       CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
+//       CHECK:   %[[T1:.+]] = addf %[[ARG0]], %[[CST]] : vector<4xf32>
+//       CHECK:   %[[T2:.+]] = log %[[T1]] : vector<4xf32>
+//       CHECK:   linalg.yield %[[T2]] : vector<4xf32>
+
+// -----
+
+func @cmp_and_select() {
+  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xi32>
+  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4xi32>
+  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<4xi32>
+  linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} %1, %2, %0 {
+  ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+    %3 = cmpi "sgt", %arg0, %arg1 : i32
+    %4 = select %3, %arg0, %arg1 : i32
+    linalg.yield %4 : i32
+  }: memref<4xi32>, memref<4xi32>, memref<4xi32>
+  return
+}
+// CHECK-LABEL: func @cmp_and_select
+//   CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1xvector<4xi32>>
+//   CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<1xvector<4xi32>>
+//   CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1xvector<4xi32>>
+//       CHECK: linalg.generic
+//  CHECK-SAME:   %[[BUF0]], %[[BUF1]], %[[BUF2]]
+//       CHECK: ^bb0(%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
+//       CHECK:   %[[T1:.+]] = cmpi "sgt", %[[ARG0]], %[[ARG1]] : vector<4xi32>
+//       CHECK:   %[[T2:.+]] = select %[[T1]], %[[ARG0]], %[[ARG1]] : vector<4xi1>, vector<4xi32>
+//       CHECK:   linalg.yield %[[T2]] : vector<4xi32>
+
+// -----
+
+func @not_contiguous() {
+  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
+  %c0 = constant 0 : index
+  %cst = constant 1.000000e+00 : f32
+  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
+  linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} %1, %0 {
+  ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
+    %2 = addf %arg0, %cst : f32
+    linalg.yield %2 : f32
+  }: memref<4x4xf32>, memref<4x4xf32>
+  return
+}
+// CHECK-LABEL: func @not_contiguous
+//   CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
+//   CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
+
+// -----
+
+func @not_4s() {
+  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x3xf32>
+  %c0 = constant 0 : index
+  %cst = constant 1.000000e+00 : f32
+  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x3xf32>
+  linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} %1, %0 {
+  ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
+    %2 = addf %arg0, %cst : f32
+    linalg.yield %2 : f32
+  }: memref<4x3xf32>, memref<4x3xf32>
+  return
+}
+// CHECK-LABEL: func @not_4s
+//   CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x3xf32>
+//   CHECK-DAG: iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x3xf32>
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 2978c69..c06e2d8 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -19,6 +19,7 @@
 #include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
 #include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
 #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -33,11 +34,19 @@
   createHLOToLinalgOnTensorsPass();
 }
 
+inline void registerLinalgToVectorPasses() {
+  static bool init_once = []() {
+    createLoadStoreVectorizationPass();
+    return true;
+  }();
+  (void)init_once;
+}
+
 inline void registerLinalgToSPIRVPasses() {
   static bool init_once = []() {
     // LinalgToSPIRV
     createConvertToGPUPass();
-    createLinalgTileAndFusePass();
+    createLinalgTileAndFusePass(SPIRVCodegenOptions());
     createSplitDispatchFunctionPass();
     createVectorToGPUPass();
     createMatMulTileAndVectorizeGPUPass();
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
index 2017d60..1226491 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp
@@ -221,10 +221,11 @@
 
     // Ordinals are fixed based on the precomputed schedule, so use
     // CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
+    int32_t entryPointOrdinal = it.index();
     builder.create<IREE::HAL::CommandBufferDispatchOp>(
         loc, commandBuffer, executable,
-        builder.getI32IntegerAttr(/*entryPointOrdinal=*/it.index()),
-        workgroupCount[0], workgroupCount[1], workgroupCount[2]);
+        builder.getI32IntegerAttr(entryPointOrdinal), workgroupCount[0],
+        workgroupCount[1], workgroupCount[2]);
     if (it.index() + 1 != spvEntryPointFns.size()) {
       recordFullExecutionBarrier(commandBuffer, loc, builder);
     }
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 4a9ac26..582c165 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -48,6 +48,12 @@
   // llvm::cl::OptionCategory halVulkanSPIRVOptionsCategory(
   //     "IREE Vulkan/SPIR-V backend options");
 
+  static llvm::cl::opt<bool> clUseVectorPass(
+      "iree-spirv-use-vector-pass",
+      llvm::cl::desc(
+          "Enable use of Linalg vectorization in SPIR-V code generation"),
+      llvm::cl::init(false));
+
   static llvm::cl::opt<bool> clUseWorkgroupMemory(
       "iree-spirv-use-workgroup-memory",
       llvm::cl::desc(
@@ -77,6 +83,7 @@
   targetOptions.codegenOptions.tileSizes.assign(clTileSizes.begin(),
                                                 clTileSizes.end());
   targetOptions.codegenOptions.useWorkgroupMemory = clUseWorkgroupMemory;
+  targetOptions.codegenOptions.useVectorPass = clUseVectorPass;
   targetOptions.vulkanTargetEnv = clVulkanTargetEnv;
   return targetOptions;
 }
diff --git a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
index ae197cc..0b0b4da 100644
--- a/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/C/CMakeLists.txt
@@ -15,7 +15,7 @@
 
 if(${IREE_ENABLE_EMITC})
   iree_add_all_subdirs()
-  
+
   iree_cc_library(
     NAME
       C
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 9686098..f4741ac 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -45,3 +45,15 @@
     driver = "vulkan",
     target_backend = "vulkan-spirv",
 )
+
+iree_check_single_backend_test_suite(
+    name = "check_vulkan-spirv_vulkan_vector",
+    srcs = [
+        "compare.mlir",
+        "log_plus_one.mlir",
+        "pw_add_multiwg.mlir",
+    ],
+    compiler_flags = ["-iree-spirv-use-vector-pass"],
+    driver = "vulkan",
+    target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index 4260b07..d5bd481 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -41,3 +41,18 @@
   COMPILER_FLAGS
     "-iree-spirv-use-workgroup-memory"
 )
+
+iree_check_single_backend_test_suite(
+  NAME
+    check_vulkan-spirv_vulkan_vector
+  SRCS
+    "compare.mlir"
+    "log_plus_one.mlir"
+    "pw_add_multiwg.mlir"
+  TARGET_BACKEND
+    vulkan-spirv
+  DRIVER
+    vulkan
+  COMPILER_FLAGS
+    "-iree-spirv-use-vector-pass"
+)
diff --git a/iree/test/e2e/vulkan_specific/compare.mlir b/iree/test/e2e/vulkan_specific/compare.mlir
new file mode 100644
index 0000000..099670a
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/compare.mlir
@@ -0,0 +1,164 @@
+func @compare_tensor() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_scalar() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i8() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i8>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i8>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i8>, tensor<i8>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i16() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i16>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i16>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i16>, tensor<i16>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i32() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i32>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_i64() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1> : tensor<i64>
+  %rhs = iree.unfoldable_constant dense<5> : tensor<i64>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_f32() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1.0> : tensor<f32>
+  %rhs = iree.unfoldable_constant dense<5.0> : tensor<f32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
+func @compare_f64() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<1.0> : tensor<f64>
+  %rhs = iree.unfoldable_constant dense<5.0> : tensor<f64>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<f64>, tensor<f64>) -> tensor<i1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<i8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<i8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
+  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
+  return
+}
+
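+// A 3-element tensor exercises a length that is not a multiple of common
+// vector widths and may take a different lowering path than the 4-element
+// tests below.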
+func @compare_tensor_odd_length() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7]> : tensor<3xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3]> : tensor<3xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<3xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<3xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<3xi1>, tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 0]> : tensor<3xi8>) : tensor<3xi8>
+  return
+}
+
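+// The remaining tests cover every comparison_direction (EQ, NE, LT, LE, GT,
+// GE) over the same operands: dense<[1, 2, 7, 4]> vs dense<[5, 2, 3, 4]>.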
+func @compare_eq() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_ne() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[1, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_lt() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[1, 0, 0, 0]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_le() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[1, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_gt() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
+
+func @compare_ge() attributes { iree.module.export } {
+  %lhs = iree.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
+  %rhs = iree.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
+  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %c0 = iree.unfoldable_constant dense<0> : tensor<4xi8>
+  %c1 = iree.unfoldable_constant dense<1> : tensor<4xi8>
+  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+  check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8>
+  return
+}
diff --git a/iree/test/e2e/vulkan_specific/log_plus_one.mlir b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
new file mode 100644
index 0000000..3bba4a9
--- /dev/null
+++ b/iree/test/e2e/vulkan_specific/log_plus_one.mlir
@@ -0,0 +1,8 @@
+func @log_plus_one() attributes { iree.module.export } {
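+  // mhlo.log_plus_one computes ln(1 + x) elementwise; the expected constants
+  // below are ln(1.0), ln(1.5), ln(2.0), and ln(6.0).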
+  %input = iree.unfoldable_constant dense<[0.0, 0.5, 1.0, 5.0]> : tensor<4xf32>
+  %result = "mhlo.log_plus_one"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+  check.expect_almost_eq_const(%result, dense<[0.0, 0.4054651, 0.6931472, 1.7917595]> : tensor<4xf32>) : tensor<4xf32>
+  return
+}