[spirv] Add a pipeline to use workgroup memory (#8425)

This commit adds a pipeline for using workgroup memory in
SPIR-V CodeGen. It moves the existing LLVMGPU shared memory
distribution and software pipelining code into the Common/
directory so both backends can use it. This pipeline is meant
to support desktop/server-class GPUs, where using shared
memory can give nice performance gains. Along the way, it also
adds CodeGen configuration stubs for NVIDIA/AMD GPUs; they are
not tuned yet, just enough to get started.
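
At a high level, the SPIR-V side drives the now-shared GPU passes through a
new pass pipeline. A rough sketch of the ordering, following the SPIRV/Passes.cpp
changes later in this diff (note createGPUPipeliningPass is still commented out
there):

    // Sketch of addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline:
    pm.addNestedPass<func::FuncOp>(createSPIRVTileAndPromotePass());
    pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
    pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
    pm.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());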
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index d681159..f4783b4 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -43,6 +43,8 @@
         "FoldAffineMinInDistributedLoops.cpp",
         "FoldTensorExtractOpPass.cpp",
         "ForOpCanonicalizationPass.cpp",
+        "GPUDistributeSharedMemoryCopy.cpp",
+        "GPUPipelining.cpp",
         "IREEComprehensiveBufferizePass.cpp",
         "InsertDistributionInfoPass.cpp",
         "LinalgBufferizePass.cpp",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index c613a6b..80ef1e8 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -36,6 +36,8 @@
     "FoldAffineMinInDistributedLoops.cpp"
     "FoldTensorExtractOpPass.cpp"
     "ForOpCanonicalizationPass.cpp"
+    "GPUDistributeSharedMemoryCopy.cpp"
+    "GPUPipelining.cpp"
     "IREEComprehensiveBufferizePass.cpp"
     "InsertDistributionInfoPass.cpp"
     "LinalgBufferizePass.cpp"
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp b/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
similarity index 91%
rename from iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
rename to iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
index c0702f5..0499715 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUDistributeSharedMemoryCopy.cpp
+++ b/iree/compiler/Codegen/Common/GPUDistributeSharedMemoryCopy.cpp
@@ -8,12 +8,12 @@
 #include <numeric>
 
 #include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
@@ -195,9 +195,9 @@
 
 namespace {
 
-class LLVMGPUDistributeSharedMemoryCopyPass
-    : public LLVMGPUDistributeSharedMemoryCopyBase<
-          LLVMGPUDistributeSharedMemoryCopyPass> {
+class GPUDistributeSharedMemoryCopyPass
+    : public GPUDistributeSharedMemoryCopyBase<
+          GPUDistributeSharedMemoryCopyPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<vector::VectorDialect, scf::SCFDialect>();
   }
@@ -220,20 +220,8 @@
         copiesToWorkgroupMem, [flatWorkgroupSize](linalg::GenericOp copyOp) {
           auto shape =
               copyOp.getOperand(0).getType().cast<MemRefType>().getShape();
-          // Verify that each dimension of the shape can be distributed on the
-          // threads
-          int64_t threadsAvailable = flatWorkgroupSize;
-          for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
-            int64_t numElementPerThread =
-                dim.index() == 0 ? targetVectorSize : 1;
-            int64_t numThreads = dim.value() / numElementPerThread;
-            if (numThreads == 0) return false;
-            numThreads = std::min(numThreads, threadsAvailable);
-            if (threadsAvailable % numThreads != 0) return false;
-            threadsAvailable = threadsAvailable / numThreads;
-            if (threadsAvailable == 1) break;
-          }
-          return threadsAvailable == 1;
+          return canPerformVectorAccessUsingAllThreads(shape, flatWorkgroupSize,
+                                                       targetVectorSize);
         });
     if (isAligned) {
       // Step 1. Vectorize the shared memory copy.
@@ -292,8 +280,8 @@
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-createLLVMGPUDistributeSharedMemoryCopy() {
-  return std::make_unique<LLVMGPUDistributeSharedMemoryCopyPass>();
+createGPUDistributeSharedMemoryCopy() {
+  return std::make_unique<GPUDistributeSharedMemoryCopyPass>();
 }
 
 }  // namespace iree_compiler
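
The inlined distribution check deleted above now lives behind
canPerformVectorAccessUsingAllThreads in Utils/GPUUtils so that
SPIRVTileAndPromote.cpp (below) can reuse it. Assuming the helper mirrors the
deleted lines one-to-one, it looks like:

    static bool canPerformVectorAccessUsingAllThreads(
        ArrayRef<int64_t> shape, int64_t flatWorkgroupSize,
        int64_t targetVectorSize) {
      // Walk dims innermost-first; the innermost dim is accessed with
      // `targetVectorSize` elements per thread, all others with one.
      int64_t threadsAvailable = flatWorkgroupSize;
      for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
        int64_t numElementPerThread = dim.index() == 0 ? targetVectorSize : 1;
        int64_t numThreads = dim.value() / numElementPerThread;
        if (numThreads == 0) return false;
        numThreads = std::min(numThreads, threadsAvailable);
        if (threadsAvailable % numThreads != 0) return false;
        threadsAvailable = threadsAvailable / numThreads;
        if (threadsAvailable == 1) break;
      }
      // True iff the copy exactly saturates the flat workgroup size.
      return threadsAvailable == 1;
    }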
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp b/iree/compiler/Codegen/Common/GPUPipelining.cpp
similarity index 94%
rename from iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
rename to iree/compiler/Codegen/Common/GPUPipelining.cpp
index 8e6355c..1ff2865 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUPipelining.cpp
+++ b/iree/compiler/Codegen/Common/GPUPipelining.cpp
@@ -8,7 +8,6 @@
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -83,9 +82,8 @@
 }
 
 namespace {
-struct LLVMGPUPipeliningPass
-    : public LLVMGPUPipeliningBase<LLVMGPUPipeliningPass> {
-  LLVMGPUPipeliningPass(unsigned depth) : depth(depth) {}
+struct GPUPipeliningPass : public GPUPipeliningBase<GPUPipeliningPass> {
+  GPUPipeliningPass(unsigned depth) : depth(depth) {}
   void runOnOperation() override {
     auto funcOp = getOperation();
     MLIRContext* context = &getContext();
@@ -155,9 +153,9 @@
 };
 }  // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUPipeliningPass(
+std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
     unsigned depth) {
-  return std::make_unique<LLVMGPUPipeliningPass>(depth);
+  return std::make_unique<GPUPipeliningPass>(depth);
 }
 
 }  // namespace iree_compiler
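
The rename keeps the pipelining depth parameter, so callers only change the
name; for example (the depth value here is illustrative only):

    // Pipeline memory operations, overlapping two loop iterations.
    pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass(/*depth=*/2));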
diff --git a/iree/compiler/Codegen/Common/test/BUILD b/iree/compiler/Codegen/Common/test/BUILD
index d6abde2..b78d822 100644
--- a/iree/compiler/Codegen/Common/test/BUILD
+++ b/iree/compiler/Codegen/Common/test/BUILD
@@ -24,6 +24,7 @@
             "canonicalize_interface_load_store.mlir",
             "convert_to_destination_passing_style.mlir",
             "dead_alloc.mlir",
+            "distribute_gpu_shared_memory.mlir",
             "flatten_memref_subspan.mlir",
             "fold_affine_min_in_distributed_loops.mlir",
             "fold_tensor_extract_op.mlir",
diff --git a/iree/compiler/Codegen/Common/test/CMakeLists.txt b/iree/compiler/Codegen/Common/test/CMakeLists.txt
index f45b163..7e9b53b 100644
--- a/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -19,6 +19,7 @@
     "canonicalize_interface_load_store.mlir"
     "convert_to_destination_passing_style.mlir"
     "dead_alloc.mlir"
+    "distribute_gpu_shared_memory.mlir"
     "flatten_memref_subspan.mlir"
     "fold_affine_min_in_distributed_loops.mlir"
     "fold_tensor_extract_op.mlir"
diff --git a/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir b/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
similarity index 91%
rename from iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
rename to iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
index dde4b51..51cbc0f 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/distribute_wg_copy.mlir
+++ b/iree/compiler/Codegen/Common/test/distribute_gpu_shared_memory.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-distribute-shared-memory-copy))))' -cse %s | FileCheck %s
+// RUN: iree-opt -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-gpu-distribute-shared-memory-copy))))' -cse %s | FileCheck %s
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
@@ -49,10 +49,10 @@
     //     CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
     //     CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
 
-        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], 
+        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
-          ins(%m0 : memref<64x16xf32>) 
-          outs(%sm0 : memref<64x16xf32, 3>) 
+          ins(%m0 : memref<64x16xf32>)
+          outs(%sm0 : memref<64x16xf32, 3>)
           attrs= {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
           ^bb0(%arg4: f32, %s: f32):  // no predecessors
             linalg.yield %arg4 : f32
@@ -65,10 +65,10 @@
     //     CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
     //     CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
 
-        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], 
+        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
-          ins(%m1 : memref<256x4xf32>) 
-          outs(%sm1 : memref<256x4xf32, 3>) 
+          ins(%m1 : memref<256x4xf32>)
+          outs(%sm1 : memref<256x4xf32, 3>)
           attrs= {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
           ^bb0(%arg4: f32, %s: f32):  // no predecessors
             linalg.yield %arg4 : f32
@@ -82,10 +82,10 @@
     //     CHECK: vector.transfer_write %[[R5]], %{{.*}}[%c1, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
     //     CHECK: vector.transfer_write %[[R6]], %{{.*}}[%c2, %15] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
 
-        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], 
+        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
-          ins(%m2 : memref<3x512xf32>) 
-          outs(%sm2 : memref<3x512xf32, 3>) 
+          ins(%m2 : memref<3x512xf32>)
+          outs(%sm2 : memref<3x512xf32, 3>)
           attrs= {__internal_linalg_transform__ = "copy_to_workgroup_memory"} {
           ^bb0(%arg4: f32, %s: f32):  // no predecessors
             linalg.yield %arg4 : f32
diff --git a/iree/compiler/Codegen/Dialect/LoweringConfig.td b/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 9293a89..e4108de 100644
--- a/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -42,6 +42,8 @@
     : I32EnumAttrCase<"SPIRVVectorize", 12>;
 def SPIRV_VectorizeToCooperativeOps
     : I32EnumAttrCase<"SPIRVVectorizeToCooperativeOps", 13>;
+def SPIRV_VectorizeWithWorkgroupMemory
+    : I32EnumAttrCase<"SPIRVVectorizeWithWorkgroupMemory", 14>;
 
 def None
     : I32EnumAttrCase<"None", 0xff>;
@@ -56,7 +58,8 @@
      CPU_BufferOpsTileAndVectorize, Linalg_TransformInterpCodegen,
      LLVMGPU_SimpleDistribute, LLVMGPU_Vectorize, LLVMGPU_MatmulSimt,
      LLVMGPU_MatmulTensorCore, SPIRV_Distribute, SPIRV_Vectorize,
-     SPIRV_VectorizeToCooperativeOps, None]> {
+     SPIRV_VectorizeToCooperativeOps, SPIRV_VectorizeWithWorkgroupMemory,
+     None]> {
   let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
   // Don't generate a C++ class! We want to use the AttrDef
   let genSpecializedAttr = 0;
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index 2bb5615..038d634 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -17,14 +17,11 @@
         "ConvertToNVVM.cpp",
         "ConvertToROCDL.cpp",
         "KernelConfig.cpp",
-        "LLVMGPUDistributeSharedMemoryCopy.cpp",
         "LLVMGPULowerExecutableTarget.cpp",
         "LLVMGPUMultiBuffering.cpp",
-        "LLVMGPUPipelining.cpp",
         "LLVMGPUReduceBankConflicts.cpp",
         "LLVMGPUTensorCoreVectorization.cpp",
         "LLVMGPUTileAndDistribute.cpp",
-        "LLVMGPUUtils.cpp",
         "LLVMGPUVectorLowering.cpp",
         "LLVMGPUVectorToGPU.cpp",
         "LLVMGPUVectorization.cpp",
@@ -34,7 +31,6 @@
     hdrs = [
         "ConvertToLLVM.h",
         "KernelConfig.h",
-        "LLVMGPUUtils.h",
     ],
     deps = [
         "//iree/compiler/Codegen:PassHeaders",
diff --git a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 586059d..2abeefd 100644
--- a/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -16,20 +16,16 @@
   HDRS
     "ConvertToLLVM.h"
     "KernelConfig.h"
-    "LLVMGPUUtils.h"
   SRCS
     "ConvertToLLVM.cpp"
     "ConvertToNVVM.cpp"
     "ConvertToROCDL.cpp"
     "KernelConfig.cpp"
-    "LLVMGPUDistributeSharedMemoryCopy.cpp"
     "LLVMGPULowerExecutableTarget.cpp"
     "LLVMGPUMultiBuffering.cpp"
-    "LLVMGPUPipelining.cpp"
     "LLVMGPUReduceBankConflicts.cpp"
     "LLVMGPUTensorCoreVectorization.cpp"
     "LLVMGPUTileAndDistribute.cpp"
-    "LLVMGPUUtils.cpp"
     "LLVMGPUVectorLowering.cpp"
     "LLVMGPUVectorToGPU.cpp"
     "LLVMGPUVectorization.cpp"
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index 828069c..8f0c24f 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -8,10 +8,10 @@
 #include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
 #include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
 #include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -178,38 +178,6 @@
   return success();
 }
 
-static Optional<Value> allocateWorkgroupMemory(
-    OpBuilder &b, memref::SubViewOp subview,
-    ArrayRef<Value> boundingSubViewSize, DataLayout &layout) {
-  OpBuilder::InsertionGuard guard(b);
-  func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
-  if (!funcOp) {
-    subview.emitError("expected op to be within std.func");
-    return llvm::None;
-  }
-
-  // The bounding subview size is expected to be constant. This specified the
-  // shape of the allocation.
-  SmallVector<int64_t, 2> shape = llvm::to_vector<2>(
-      llvm::map_range(boundingSubViewSize, [](Value v) -> int64_t {
-        APInt value;
-        if (matchPattern(v, m_ConstantInt(&value))) return value.getSExtValue();
-        return -1;
-      }));
-  if (llvm::any_of(shape, [](int64_t v) { return v == -1; })) return {};
-  MemRefType allocType =
-      MemRefType::get(shape, subview.getType().getElementType(), {},
-                      gpu::GPUDialect::getWorkgroupAddressSpace());
-  b.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
-  Value buffer = b.create<memref::AllocOp>(funcOp.getLoc(), allocType);
-  return buffer;
-}
-
-static LogicalResult deallocateWorkgroupMemory(OpBuilder &b, Value buffer) {
-  // Nothing to do.
-  return success();
-}
-
 static void populatePromotionPatterns(MLIRContext *context,
                                       RewritePatternSet &patterns) {
   patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
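
The allocation/deallocation callbacks deleted here are not gone:
SPIRVTileAndPromote.cpp below passes allocateWorkgroupMemory and
deallocateWorkgroupMemory to setAllocationDeallocationFns, so they presumably
moved into Utils/GPUUtils essentially as-is. A sketch following the deleted
body:

    static Optional<Value> allocateWorkgroupMemory(
        OpBuilder &b, memref::SubViewOp subview,
        ArrayRef<Value> boundingSubViewSize, DataLayout &layout) {
      // The bounding subview sizes must be constants; they give the static
      // shape of the allocation.
      SmallVector<int64_t, 2> shape;
      for (Value v : boundingSubViewSize) {
        APInt size;
        if (!matchPattern(v, m_ConstantInt(&size))) return llvm::None;
        shape.push_back(size.getSExtValue());
      }
      func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
      if (!funcOp) return llvm::None;
      // Allocate in the GPU workgroup (shared memory) address space, at the
      // top of the function so the buffer dominates all uses.
      MemRefType allocType =
          MemRefType::get(shape, subview.getType().getElementType(), {},
                          gpu::GPUDialect::getWorkgroupAddressSpace());
      OpBuilder::InsertionGuard guard(b);
      b.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
      Value buffer = b.create<memref::AllocOp>(funcOp.getLoc(), allocType);
      return buffer;
    }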
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.cpp
deleted file mode 100644
index c7ab1aa..0000000
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
-
-#include "mlir/Dialect/GPU/Passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
-    mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
-    llvm::ArrayRef<int64_t> workgroupSize) {
-  assert(numDims <= kNumGPUDims);
-  llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
-  std::array<gpu::Dimension, kNumGPUDims> dimAttr{
-      gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
-  mlir::Type indexType = builder.getIndexType();
-  for (unsigned i = 0; i < numDims; ++i) {
-    procInfo[numDims - 1 - i] = {
-        builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]),
-        builder.create<mlir::arith::ConstantOp>(
-            loc, builder.getIndexAttr(workgroupSize[i]))};
-  }
-  return procInfo;
-}
-
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
-    mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
-    llvm::ArrayRef<int64_t> numSubgroups) {
-  assert(numDims <= kNumGPUDims);
-  llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
-  std::array<gpu::Dimension, kNumGPUDims> dimAttr{
-      gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
-  mlir::Type indexType = builder.getIndexType();
-  for (unsigned i = 0; i < numDims; ++i) {
-    mlir::Value subgroupId =
-        builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]);
-    if (i == 0) {
-      mlir::AffineExpr d0 = builder.getAffineDimExpr(0);
-      subgroupId = mlir::makeComposedAffineApply(
-          builder, loc, d0.floorDiv(builder.getAffineConstantExpr(kWarpSize)),
-          {subgroupId});
-    }
-    procInfo[numDims - 1 - i] = {
-        subgroupId, builder.create<mlir::arith::ConstantOp>(
-                        loc, builder.getIndexAttr(numSubgroups[i]))};
-  }
-  return procInfo;
-}
-
-std::array<int64_t, 3> getWorkgroupSize(mlir::func::FuncOp funcOp) {
-  std::array<int64_t, 3> workgroupSize;
-  auto entryPointOp = mlir::iree_compiler::getEntryPoint(funcOp);
-  llvm::Optional<mlir::ArrayAttr> workgroupSizeAttr =
-      entryPointOp.workgroup_size();
-  assert(workgroupSizeAttr.hasValue());
-  for (auto it : llvm::enumerate(workgroupSizeAttr.getValue())) {
-    workgroupSize[it.index()] =
-        it.value().cast<mlir::IntegerAttr>().getValue().getZExtValue();
-  }
-  return workgroupSize;
-}
-
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h b/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h
deleted file mode 100644
index aaca6bd..0000000
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_LLVMGPUUTILS_H_
-#define IREE_COMPILER_CODEGEN_LLVMGPU_LLVMGPUUTILS_H_
-
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-static constexpr int32_t kNumGPUDims = 3;
-static constexpr int32_t kWarpSize = 32;
-
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
-    mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
-    llvm::ArrayRef<int64_t> workgroupSize);
-
-/// Compute subgroup ID. CUDA doesn't have a subgroupId equivalent so we are are
-/// computing the subgroup ID based on the threadID.
-/// When tiling to warp we assume each warp is full and we pick a workgroup
-/// size so that `workgroupSize.x % warpSize == 0`. This is why we can have
-/// warpId = { threadId.x / warpSize, threadId.y, threadId.z }.
-llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
-    mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
-    llvm::ArrayRef<int64_t> numSubgroups);
-
-/// return the workgroup size associated to the funcOp entry point.
-std::array<int64_t, 3> getWorkgroupSize(mlir::func::FuncOp funcOp);
-
-}  // namespace iree_compiler
-}  // namespace mlir
-#endif  // IREE_COMPILER_CODEGEN_LLVMGPU_LLVMGPUUTILS_H_
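
These declarations are assumed to reappear in
iree/compiler/Codegen/Utils/GPUUtils.h, the header every call site in this
commit now includes instead. Inferred from the call sites in this diff (the
header itself is not shown), it plausibly carries:

    // Moved from LLVMGPUUtils.h.
    static constexpr int32_t kNumGPUDims = 3;
    static constexpr int32_t kWarpSize = 32;
    llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
        mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
        llvm::ArrayRef<int64_t> workgroupSize);
    llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
        mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
        llvm::ArrayRef<int64_t> numSubgroups);

    // New shared helpers referenced by GPUDistributeSharedMemoryCopy.cpp and
    // SPIRVTileAndPromote.cpp in this commit (signatures are assumptions).
    bool canPerformVectorAccessUsingAllThreads(llvm::ArrayRef<int64_t> shape,
                                               int64_t threadCount,
                                               int64_t vectorSize);
    llvm::Optional<mlir::Value> allocateWorkgroupMemory(
        mlir::OpBuilder &b, mlir::memref::SubViewOp subview,
        llvm::ArrayRef<mlir::Value> boundingSubViewSize,
        mlir::DataLayout &layout);
    mlir::LogicalResult deallocateWorkgroupMemory(mlir::OpBuilder &b,
                                                  mlir::Value buffer);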
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index a70bc33..a1224ed 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -88,7 +88,7 @@
   // Distribute linalg onto threads within the workgroup.
   pm.addNestedPass<func::FuncOp>(createLLVMGPUTileAndDistribute());
   pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
-  pm.addNestedPass<func::FuncOp>(createLLVMGPUDistributeSharedMemoryCopy());
+  pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
@@ -106,7 +106,7 @@
   pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
 
   // Pipeline memory operations.
-  pm.addNestedPass<func::FuncOp>(createLLVMGPUPipeliningPass());
+  pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass());
 }
 
 void addGPUMatmulTensorCorePassPipeline(OpPassManager &pm) {
@@ -118,7 +118,7 @@
   if (pipelineDepth > 1)
     pm.addNestedPass<func::FuncOp>(createLLVMGPUMultiBuffering(pipelineDepth));
   pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
-  pm.addNestedPass<func::FuncOp>(createLLVMGPUDistributeSharedMemoryCopy());
+  pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
@@ -142,7 +142,7 @@
   pm.addPass(createCSEPass());
 
   // Pipeline memory operations.
-  pm.addNestedPass<func::FuncOp>(createLLVMGPUPipeliningPass(pipelineDepth));
+  pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass(pipelineDepth));
 }
 
 void addGPUSimpleDistributePassPipeline(OpPassManager &pm) {
diff --git a/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp b/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
index c4854ea..e832517 100644
--- a/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
@@ -4,9 +4,9 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/LLVMGPU/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 
 namespace mlir {
diff --git a/iree/compiler/Codegen/LLVMGPU/test/BUILD b/iree/compiler/Codegen/LLVMGPU/test/BUILD
index 5011132..36d1d95 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/test/BUILD
@@ -22,7 +22,6 @@
             "convert_to_nvvm.mlir",
             "convert_to_rocdl.mlir",
             "distribute_to_thread.mlir",
-            "distribute_wg_copy.mlir",
             "gpu_set_num_workgroups.mlir",
             "nvvm_pipeline_test.mlir",
             "reduce_bank_conflicts.mlir",
diff --git a/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 7069e4f..547de8a 100644
--- a/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -17,7 +17,6 @@
     "convert_to_nvvm.mlir"
     "convert_to_rocdl.mlir"
     "distribute_to_thread.mlir"
-    "distribute_wg_copy.mlir"
     "gpu_set_num_workgroups.mlir"
     "illegal_configuration.mlir"
     "legalize.mlir"
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index a89bbb6..b84ec6b 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -153,6 +153,15 @@
 /// Creates a pass to convert memref.copy to linalg op.
 std::unique_ptr<OperationPass<func::FuncOp>> createMemrefCopyToLinalgPass();
 
+/// Convert GPU shared memory copies to distributed
+/// transfer_read/transfer_write.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createGPUDistributeSharedMemoryCopy();
+
+/// Apply software pipelining.
+std::unique_ptr<OperationPass<func::FuncOp>> createGPUPipeliningPass(
+    unsigned depth = 1);
+
 /// Converts vector ops to gpu dialect.
 std::unique_ptr<OperationPass<func::FuncOp>> createWorkGroupSwizzle(
     unsigned swizzleLogTile = 0);
@@ -352,14 +361,6 @@
 /// Lower vector ops before convertion to LLVM.
 std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUVectorLoweringPass();
 
-/// Convert shared memory copies to distributed transfer_read/transfer_write.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLLVMGPUDistributeSharedMemoryCopy();
-
-/// Apply software pipelining.
-std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUPipeliningPass(
-    unsigned depth = 1);
-
 /// Apply multi-buffering transformation.
 std::unique_ptr<OperationPass<func::FuncOp>> createLLVMGPUMultiBuffering(
     unsigned numBuffers = 5);
@@ -391,6 +392,12 @@
 /// performs distribution to threads with vectorization.
 void addSPIRVTileAndVectorizeToCooperativeOpsPassPipeline(OpPassManager &pm);
 
+/// Pass pipeline to lower IREE HAL executables with workgroup tiled and
+/// distributed Linalg ops to SPIR-V scalar and vector code. Additionally
+/// performs distribution to threads with vectorization and promotion to use
+/// workgroup memory.
+void addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline(OpPassManager &pm);
+
 /// Pass to perform the final conversion to SPIR-V dialect.
 ///
 /// This pass converts remaining interface ops into SPIR-V global variables,
@@ -411,6 +418,10 @@
 /// Pass to tile and distribute Linalg ops with buffer semantics to invocations.
 std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndDistributePass();
 
+/// Pass to promote Linalg ops with buffer semantics to use workgroup memory and
+/// then tile to invocations.
+std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndPromotePass();
+
 /// Pass to tile Linalg ops with buffer semantics to subgroups and vectorize to
 /// vector ops suitable for lowering to SPIR-V cooperative ops.
 std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index d4f23b6..05e8fcb 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -163,6 +163,17 @@
       "mlir::iree_compiler::createMemrefCopyToLinalgPass()";
 }
 
+def GPUDistributeSharedMemoryCopy :
+    Pass<"iree-gpu-distribute-shared-memory-copy", "func::FuncOp"> {
+  let summary = "Pass to distribute shared memory copies to threads.";
+  let constructor = "mlir::iree_compiler::createGPUDistributeSharedMemoryCopy()";
+}
+
+def GPUPipelining : Pass<"iree-gpu-pipelining", "func::FuncOp"> {
+  let summary = "Pass to do software pipelining.";
+  let constructor = "mlir::iree_compiler::createGPUPipeliningPass()";
+}
+
 def WorkGroupSwizzle :
     Pass<"iree-workgroup-swizzle", "func::FuncOp"> {
   let summary = "swizzle the workgroup ids for better cache reuse";
@@ -293,18 +304,6 @@
   let constructor = "mlir::iree_compiler::createLLVMGPUVectorLoweringPass()";
 }
 
-def LLVMGPUDistributeSharedMemoryCopy :
-    Pass<"iree-llvmgpu-distribute-shared-memory-copy", "func::FuncOp"> {
-  let summary = "Pass to distribute shared memory copies to threads.";
-  let constructor = "mlir::iree_compiler::createLLVMGPUDistributeSharedMemoryCopy()";
-}
-
-def LLVMGPUPipelining :
-    Pass<"iree-llvmgpu-pipelining", "func::FuncOp"> {
-  let summary = "Pass to do software pipelining.";
-  let constructor = "mlir::iree_compiler::createLLVMGPUPipeliningPass()";
-}
-
 def LLVMGPUMultiBuffering :
     Pass<"iree-llvmgpu-multi-buffering", "func::FuncOp"> {
   let summary = "Pass to do multi buffering.";
@@ -366,6 +365,13 @@
     "mlir::iree_compiler::createSPIRVTileAndVectorizeToCooperativeOpsPass()";
 }
 
+def SPIRVTileAndPromote : Pass<"iree-spirv-tile-and-promote", "func::FuncOp"> {
+  let summary = "Promote tiled Linalg ops with buffer semantics to use "
+                "workgroup memory and then tile to invocations";
+  let constructor =
+    "mlir::iree_compiler::createSPIRVTileAndPromotePass()";
+}
+
 def SPIRVVectorize : Pass<"iree-spirv-vectorize", "func::FuncOp"> {
   let summary = "Vectorize Linalg ops with buffer semantics";
   let constructor = "mlir::iree_compiler::createSPIRVVectorizePass()";
diff --git a/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
new file mode 100644
index 0000000..4fb689a
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -0,0 +1,52 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- AMDConfig.cpp - AMD CodeGen Configurations -------------------------===//
+//
+// This file contains CodeGen configurations for AMD GPUs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
+#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/BuiltinOps.h"
+
+#define DEBUG_TYPE "iree-spirv-amd-config"
+
+namespace mlir {
+namespace iree_compiler {
+namespace detail {
+
+// RDNA architecture:
+// https://gpuopen.com/wp-content/uploads/2019/08/RDNA_Architecture_public.pdf
+//
+// Workgroup Processor (WGP) is the block for workgroups in RDNA; it has its own
+// instruction/constant cache, L0 cache x2, Local Data Share (LDS, a.k.a. shared
+// memory), SALU x4, SIMD32 x4.
+//
+// * 1024 registers per SIMD32
+// * 128KB LDS per WGP
+// * Max 20 waves per SIMD32
+// * Max 64KB LDS per workgroup
+
+LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv,
+                                  Operation *rootOp) {
+  int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
+  if (auto matmulOp = dyn_cast<linalg::MatmulOp>(rootOp)) {
+    std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 8};
+    std::array<int64_t, 3> threadMNK = {8, 4, 32};
+    return setMatmulOpConfig(matmulOp, subgroupSize, workgroupXY, threadMNK,
+                             /*useWorkgroupMemory=*/true);
+  }
+  return success();
+}
+
+}  // namespace detail
+}  // namespace iree_compiler
+}  // namespace mlir
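
These numbers are stubs, but a quick sanity check under assumed conditions
shows they fit the limits listed in the file:

    // Assuming the driver reports subgroupSize = 64 (common for AMD via
    // Vulkan): workgroupXY = {64 / 2, 8} = {32, 8}, i.e. 256 threads.
    // If the workgroup tile is derived as (M, N, K) =
    //   (workgroupY * threadM, workgroupX * threadN, threadK)
    //   = (8 * 8, 32 * 4, 32) = (64, 128, 32),
    // promoting both f32 operands takes roughly
    //   (64 * 32 + 32 * 128) * 4 bytes = 24KB of LDS,
    // well under the 64KB-per-workgroup limit noted above. (The tile
    // derivation here is an assumption about setMatmulOpConfig.)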
diff --git a/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
index d4511cd..bfbd2f4 100644
--- a/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
@@ -29,10 +29,10 @@
                                      Operation *rootOp) {
   int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
   return TypeSwitch<Operation *, LogicalResult>(rootOp)
-      .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([](auto op) {
-        std::array<int64_t, 2> workgroupXY = {32, 2};
+      .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([subgroupSize](auto op) {
+        std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 2};
         std::array<int64_t, 3> threadMNK = {16, 4, 4};
-        return setMatmulOpConfig(op, workgroupXY, threadMNK);
+        return setMatmulOpConfig(op, subgroupSize, workgroupXY, threadMNK);
       })
       .Case<linalg::Conv2DNhwcHwcfOp>([subgroupSize](auto op) {
         return setConvOpConfig(op, subgroupSize,
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index 12115b0..e2543c5 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -13,6 +13,7 @@
 cc_library(
     name = "SPIRV",
     srcs = [
+        "AMDConfig.cpp",
         "AdrenoConfig.cpp",
         "ConvertToSPIRVPass.cpp",
         "KernelConfig.cpp",
@@ -25,6 +26,7 @@
         "SPIRVLowerExecutableTargetPass.cpp",
         "SPIRVTile.cpp",
         "SPIRVTileAndDistribute.cpp",
+        "SPIRVTileAndPromote.cpp",
         "SPIRVTileAndVectorizeToCooperativeOps.cpp",
         "SPIRVVectorToCooperativeOps.cpp",
         "SPIRVVectorize.cpp",
diff --git a/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index cea4c3b..e222de5 100644
--- a/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -18,6 +18,7 @@
     "MemorySpace.h"
     "Utils.h"
   SRCS
+    "AMDConfig.cpp"
     "AdrenoConfig.cpp"
     "ConvertToSPIRVPass.cpp"
     "KernelConfig.cpp"
@@ -30,6 +31,7 @@
     "SPIRVLowerExecutableTargetPass.cpp"
     "SPIRVTile.cpp"
     "SPIRVTileAndDistribute.cpp"
+    "SPIRVTileAndPromote.cpp"
     "SPIRVTileAndVectorizeToCooperativeOps.cpp"
     "SPIRVVectorToCooperativeOps.cpp"
     "SPIRVVectorize.cpp"
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index d8e6e2c..6dd8f80 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -6,6 +6,9 @@
 
 #include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
 
+#include <functional>
+#include <numeric>
+
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
@@ -169,9 +172,10 @@
 
 namespace detail {
 
-LogicalResult setMatmulOpConfig(linalg::LinalgOp op,
+LogicalResult setMatmulOpConfig(linalg::LinalgOp op, int64_t subgroupSize,
                                 std::array<int64_t, 2> bestWorkgroupSizeXY,
-                                std::array<int64_t, 3> bestThreadTileSizeMNK) {
+                                std::array<int64_t, 3> bestThreadTileSizeMNK,
+                                bool useWorkgroupMemory) {
   auto lhsType = op.inputs()[0].getType().cast<ShapedType>();
   auto elementBits = lhsType.getElementType().getIntOrFloatBitWidth();
   if (elementBits != 16 && elementBits != 32) return success();
@@ -261,11 +265,20 @@
   }
   if (reductionTileSizes[2 + isBM] == 0) return success();
 
-  auto pipeline = IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize;
+  auto totalThreads =
+      std::accumulate(workgroupSize.begin(), workgroupSize.end(), 1,
+                      std::multiplies<int64_t>());
+  auto pipeline =
+      (useWorkgroupMemory && totalThreads > subgroupSize)
+          ? IREE::Codegen::DispatchLoweringPassPipeline::
+                SPIRVVectorizeWithWorkgroupMemory
+          : IREE::Codegen::DispatchLoweringPassPipeline::SPIRVVectorize;
+
   TileSizesListType tileSizes;
   tileSizes.push_back(workgroupTileSizes);
   tileSizes.push_back(invocationTileSizes);
   tileSizes.push_back(reductionTileSizes);
+
   return setOpConfigAndEntryPointFnTranslation(
       op->getParentOfType<func::FuncOp>(), op, tileSizes, pipeline,
       workgroupSize);
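
For intuition on the pipeline selection above, a hypothetical configuration
(all numbers illustrative):

    // Say workgroupSize = {32, 8, 1} and subgroupSize = 32:
    //   totalThreads = 32 * 8 * 1 = 256 > 32,
    // so with useWorkgroupMemory = true this picks
    // SPIRVVectorizeWithWorkgroupMemory. A single-subgroup workgroup such as
    // {32, 1, 1} falls back to plain SPIRVVectorize, matching the
    // "only promote when there are multiple warps" rule in
    // SPIRVTileAndPromote.cpp below.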
@@ -512,6 +525,9 @@
   // First try to find a proper CodeGen configuration to tile and vectorize for
   // the current target architecture.
   switch (targetEnv.getVendorID()) {
+    case spirv::Vendor::AMD:
+      result = detail::setAMDCodeGenConfig(targetEnv, rootOp);
+      break;
     case spirv::Vendor::ARM:
       result = detail::setMaliCodeGenConfig(targetEnv, rootOp);
       break;
@@ -534,10 +550,12 @@
   spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
   return TypeSwitch<Operation *, LogicalResult>(rootOp)
       .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([limits](auto op) {
-        // Try to tile and vectorize first.
+        // Try to tile and vectorize first. It's common to see 32 threads
+        // per subgroup for GPUs.
         std::array<int64_t, 2> workgroupXY = {32, 2};
         std::array<int64_t, 3> threadMNK = {8, 8, 4};
-        auto result = detail::setMatmulOpConfig(op, workgroupXY, threadMNK);
+        auto result = detail::setMatmulOpConfig(op, /*subgroupSize=*/32,
+                                                workgroupXY, threadMNK);
         if (failed(result)) return result;
         if (getLoweringConfig(op)) return result;
 
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.h b/iree/compiler/Codegen/SPIRV/KernelConfig.h
index a11c0aa..3d27354 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.h
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.h
@@ -35,9 +35,10 @@
 
 /// Sets CodeGen configurations via attributes to the given matmul `linalgOp`
 /// with the given best workgroup size and tile size hints.
-LogicalResult setMatmulOpConfig(linalg::LinalgOp linalgOp,
+LogicalResult setMatmulOpConfig(linalg::LinalgOp linalgOp, int64_t subgroupSize,
                                 std::array<int64_t, 2> bestWorkgroupSizeXY,
-                                std::array<int64_t, 3> bestThreadTileSizeMNK);
+                                std::array<int64_t, 3> bestThreadTileSizeMNK,
+                                bool useWorkgroupMemory = false);
 
 /// Sets CodeGen configuration for GPUs from a specific vendor.
 ///
@@ -51,6 +52,8 @@
 
 LogicalResult setAdrenoCodeGenConfig(const spirv::TargetEnv &targetEnv,
                                      Operation *rootOp);
+LogicalResult setAMDCodeGenConfig(const spirv::TargetEnv &targetEnv,
+                                  Operation *rootOp);
 LogicalResult setMaliCodeGenConfig(const spirv::TargetEnv &targetEnv,
                                    Operation *rootOp);
 LogicalResult setNVIDIACodeGenConfig(const spirv::TargetEnv &targetEnv,
diff --git a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index 9411ec5..caf65c9 100644
--- a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -29,8 +29,8 @@
                                    Operation *rootOp) {
   int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
   return TypeSwitch<Operation *, LogicalResult>(rootOp)
-      .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([](auto op) {
-        std::array<int64_t, 2> workgroupXY = {8, 2};
+      .Case<linalg::BatchMatmulOp, linalg::MatmulOp>([subgroupSize](auto op) {
+        std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 2};
         std::array<int64_t, 3> threadMNK;
         auto inputType = op.inputs()[0].getType().template cast<ShapedType>();
         if (inputType.getElementType().isF16()) {
@@ -38,7 +38,7 @@
         } else {
           threadMNK = {6, 4, 4};
         }
-        return setMatmulOpConfig(op, workgroupXY, threadMNK);
+        return setMatmulOpConfig(op, subgroupSize, workgroupXY, threadMNK);
       })
       .Case<linalg::Conv2DNhwcHwcfOp>([subgroupSize](auto op) {
         bool hasPaddedInput =
diff --git a/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp b/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
index c4b121a..bd7aeb4 100644
--- a/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp
@@ -109,10 +109,48 @@
       workgroupSize);
 }
 
+// Volta architecture:
+// https://docs.nvidia.com/cuda/volta-tuning-guide/index.html#sm-occupancy
+//
+// * 64K 32-bit registers per SM
+// * 96KB shared memory per SM
+// * Max 32 thread blocks per SM
+// * Max 64 concurrent warps per SM
+// * Max 255 registers per thread
+
+// Turing architecture:
+// https://docs.nvidia.com/cuda/turing-tuning-guide/index.html#sm-occupancy
+//
+// * 64K 32-bit registers per SM
+// * 64KB shared memory per SM
+// * Max 16 thread blocks per SM
+// * Max 32 concurrent warps per SM
+// * Max 255 registers per thread
+
+// Ampere architecture:
+// https://docs.nvidia.com/cuda/ampere-tuning-guide/index.html#sm-occupancy
+//
+// * 64K 32-bit registers per SM
+// * 164KB/96KB shared memory for compute capability 8.0/8.6
+// * Max 32/16 thread blocks per SM for compute capability 8.0/8.6
+// * Max 64 concurrent warps per SM
+// * Max 255 registers per thread
+
+// Note that the above numbers are from CUDA docs; for Vulkan the driver can
+// expose slightly different numbers, e.g., max shared memory size is smaller.
+
 LogicalResult setNVIDIACodeGenConfig(const spirv::TargetEnv &targetEnv,
                                      Operation *rootOp) {
+  int64_t subgroupSize = targetEnv.getResourceLimits().subgroup_size().getInt();
   if (auto matmulOp = dyn_cast<linalg::MatmulOp>(rootOp)) {
-    return setOpConfig(targetEnv, matmulOp);
+    // First try to see if we can use tensor cores.
+    if (failed(setOpConfig(targetEnv, matmulOp))) return failure();
+    if (getLoweringConfig(rootOp)) return success();
+
+    std::array<int64_t, 2> workgroupXY = {subgroupSize, 8};
+    std::array<int64_t, 3> threadMNK = {16, 4, 32};
+    return setMatmulOpConfig(matmulOp, subgroupSize, workgroupXY, threadMNK,
+                             /*useWorkgroupMemory=*/true);
   }
   return success();
 }
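
As with the AMD stub, these are starting points rather than tuned numbers. A
rough budget check, assuming subgroupSize = 32 and the same tile derivation as
sketched for AMD above:

    // workgroupXY = {32, 8} -> 256 threads per workgroup.
    // Workgroup tile (M, N, K) = (8 * 16, 32 * 4, 32) = (128, 128, 32);
    // promoting both f32 operands needs about
    //   (128 * 32 + 32 * 128) * 4 bytes = 32KB of shared memory,
    // within the per-SM budgets (64KB and up) listed above.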
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index d8e5c53..ec83565 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -201,6 +201,37 @@
   pm.addNestedPass<func::FuncOp>(createSPIRVVectorToCooperativeOpsPass());
 }
 
+void addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline(
+    OpPassManager &pm) {
+  pm.addNestedPass<func::FuncOp>(createInsertDistributionInfoPass());
+  pm.addNestedPass<func::FuncOp>(createTileAndDistributeToWorkgroupsPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  addLinalgBufferizePasses(pm, gpuAllocationFunction);
+
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  // Tile and distribute to GPU invocations.
+  pm.addNestedPass<func::FuncOp>(createSPIRVTileAndPromotePass());
+  pm.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
+  pm.addNestedPass<func::FuncOp>(createGPUDistributeSharedMemoryCopy());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  pm.addNestedPass<func::FuncOp>(createRemoveSingleIterationLoopPass());
+
+  pm.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+
+  // pm.addNestedPass<func::FuncOp>(createGPUPipeliningPass());
+
+  addLoopMaterializationPasses(pm);
+}
+
 void addSPIRVTileAndDistributePassPipeline(OpPassManager &pm) {
   pm.addNestedPass<func::FuncOp>(createInsertDistributionInfoPass());
   pm.addNestedPass<func::FuncOp>(createTileAndDistributeToWorkgroupsPass());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index c65e1e3..dd06392 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -107,6 +107,10 @@
           SPIRVVectorizeToCooperativeOps:
         addSPIRVTileAndVectorizeToCooperativeOpsPassPipeline(nestedModulePM);
         break;
+      case IREE::Codegen::DispatchLoweringPassPipeline::
+          SPIRVVectorizeWithWorkgroupMemory:
+        addSPIRVTileAndVectorizeWithWorkgroupMemoryPassPipeline(nestedModulePM);
+        break;
       default:
         variantOp.emitOpError("Unsupported pipeline on GPU target.");
         return signalPassFailure();
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
index 5bbbb2e..256114a 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
@@ -15,6 +15,7 @@
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "llvm/Support/Debug.h"
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
new file mode 100644
index 0000000..d409a81
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -0,0 +1,284 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- SPIRVTileAndPromote.cpp --------------------------------------------===//
+//
+// This pass promotes Linalg ops with buffer semantics to use workgroup
+// memory and then tiles them to invocations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-tile-and-promote"
+
+namespace mlir {
+namespace iree_compiler {
+
+//====---------------------------------------------------------------------===//
+// Reduction tiling patterns
+//====---------------------------------------------------------------------===//
+
+static void populateTilingReductionPatterns(
+    RewritePatternSet &patterns, linalg::LinalgTransformationFilter filter) {
+  auto getTileSizeFn = [&](OpBuilder &builder, Operation *op) {
+    return getTileSizes(builder, op, 2);
+  };
+
+  auto tilingOptions = linalg::LinalgTilingOptions()
+                           .setLoopType(linalg::LinalgTilingLoopType::Loops)
+                           .setTileSizeComputationFunction(getTileSizeFn);
+  linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::MatmulOp>::insert(
+      patterns, tilingOptions, filter);
+}
+
+//===----------------------------------------------------------------------===//
+// Invocation tiling patterns
+//===----------------------------------------------------------------------===//
+
+static void populateTilingToInvocationPatterns(
+    RewritePatternSet &patterns, linalg::LinalgTransformationFilter filter) {
+  linalg::TileSizeComputationFunction getTileSizeFn = [&](OpBuilder &builder,
+                                                          Operation *op) {
+    return getTileSizes(builder, op, 1);
+  };
+
+  auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
+                                ArrayRef<Range> parallelLoopRanges) {
+    return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
+        builder, loc, parallelLoopRanges.size());
+  };
+  linalg::LinalgLoopDistributionOptions distributionOptions;
+  distributionOptions.procInfo = getThreadProcInfoFn;
+  distributionOptions.distributionMethod = {
+      {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+       linalg::DistributionMethod::Cyclic}};
+
+  auto tilingOptions = linalg::LinalgTilingOptions()
+                           .setLoopType(linalg::LinalgTilingLoopType::Loops)
+                           .setTileSizeComputationFunction(getTileSizeFn)
+                           .setDistributionOptions(distributionOptions);
+
+  linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::FillOp,
+                         linalg::GenericOp,
+                         linalg::MatmulOp>::insert(patterns, tilingOptions,
+                                                   filter);
+}
+
+//===----------------------------------------------------------------------===//
+// Promotion patterns
+//===----------------------------------------------------------------------===//
+
+static const char promoteLHSMarker[] = "promote_lhs";
+static const char promoteRHSMarker[] = "promote_rhs";
+static const char promoteBothMarker[] = "promote_lhs_and_rhs";
+
+LogicalResult copyToWorkgroupMemory(OpBuilder &builder, Value src, Value dst) {
+  Operation *copyOp = builder.create<memref::CopyOp>(src.getLoc(), src, dst);
+  setMarker(copyOp, getCopyToWorkgroupMemoryMarker());
+  return success();
+}
+
+static void populatePromotionPatterns(RewritePatternSet &patterns,
+                                      StringAttr replaceMarker) {
+  MLIRContext *context = patterns.getContext();
+  auto baseOptions =
+      linalg::LinalgPromotionOptions()
+          .setAllocationDeallocationFns(allocateWorkgroupMemory,
+                                        deallocateWorkgroupMemory)
+          .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
+          .setUseFullTileBuffers({false, false});
+  auto promoteLHSOptions = baseOptions.setOperandsToPromote({0});
+  auto promoteRHSOptions = baseOptions.setOperandsToPromote({1});
+  auto promoteBothOptions = baseOptions.setOperandsToPromote({0, 1});
+
+  linalg::LinalgTransformationFilter promoteLHSFilter(
+      {StringAttr::get(context, promoteLHSMarker)}, replaceMarker);
+  linalg::LinalgTransformationFilter promoteRHSFilter(
+      {StringAttr::get(context, promoteRHSMarker)}, replaceMarker);
+  linalg::LinalgTransformationFilter promoteBothFilter(
+      {StringAttr::get(context, promoteBothMarker)}, replaceMarker);
+
+  patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
+                  linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>(
+      patterns.getContext(), promoteLHSOptions, promoteLHSFilter);
+  patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
+                  linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>(
+      patterns.getContext(), promoteRHSOptions, promoteRHSFilter);
+  patterns.insert<linalg::LinalgPromotionPattern<linalg::MatmulOp>,
+                  linalg::LinalgPromotionPattern<linalg::BatchMatmulOp>>(
+      patterns.getContext(), promoteBothOptions, promoteBothFilter);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct SPIRVTileAndPromotePass final
+    : public SPIRVTileAndPromoteBase<SPIRVTileAndPromotePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<gpu::GPUDialect>();
+  }
+
+  void runOnOperation() override;
+};
+
+}  // namespace
+
+void SPIRVTileAndPromotePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  func::FuncOp funcOp = getOperation();
+  auto entryPointOp = getEntryPoint(funcOp);
+  if (!entryPointOp) return;
+
+  {  // Tile reduction dimensions.
+    RewritePatternSet tilingPatterns(context);
+    linalg::LinalgTransformationFilter filter(
+        ArrayRef<StringAttr>(),
+        StringAttr::get(context, getWorkgroupKTiledMarker()));
+    populateTilingReductionPatterns(tilingPatterns, filter);
+    if (failed(
+            applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)))) {
+      funcOp.emitOpError() << "failed tiling reduction";
+      return signalPassFailure();
+    }
+
+    RewritePatternSet patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      funcOp.emitOpError() << "failed canonicalization after tiling reduction";
+      return signalPassFailure();
+    }
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After tiling reduction dimensions ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
+  auto workgroupSize = llvm::to_vector<4>(llvm::map_range(
+      entryPointOp.workgroup_size().getValue(),
+      [&](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); }));
+  int64_t flatWorkgroupSize =
+      workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
+  auto subgroupSize =
+      getSPIRVTargetEnvAttr(funcOp).getResourceLimits().subgroup_size();
+
+  funcOp.walk([&](Operation *op) {
+    if (isa<linalg::FillOp, linalg::GenericOp>(op)) {
+      op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
+                  StringAttr::get(context, getWorkgroupMemoryMarker()));
+    } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp>(op)) {
+      auto lhsShape = op->getOperand(0).getType().cast<ShapedType>().getShape();
+      auto rhsShape = op->getOperand(1).getType().cast<ShapedType>().getShape();
+      bool canPromoteLHS =
+          canPerformVectorAccessUsingAllThreads(lhsShape, flatWorkgroupSize, 4);
+      bool canPromoteRHS =
+          canPerformVectorAccessUsingAllThreads(rhsShape, flatWorkgroupSize, 4);
+      StringAttr promoteMarker =
+          StringAttr::get(context, getWorkgroupMemoryMarker());
+      if (canPromoteLHS && canPromoteRHS) {
+        promoteMarker = StringAttr::get(context, promoteBothMarker);
+      } else if (canPromoteLHS) {
+        promoteMarker = StringAttr::get(context, promoteLHSMarker);
+      } else if (canPromoteRHS) {
+        promoteMarker = StringAttr::get(context, promoteRHSMarker);
+      }
+      op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker,
+                  promoteMarker);
+    }
+    return WalkResult::advance();
+  });
+
+  // Only promote to workgroup memory if there are multiple warps.
+  if (flatWorkgroupSize > subgroupSize.getInt()) {
+    RewritePatternSet promotionPatterns(&getContext());
+    auto replaceMarker = StringAttr::get(context, getWorkgroupMemoryMarker());
+    populatePromotionPatterns(promotionPatterns, replaceMarker);
+    if (failed(applyPatternsAndFoldGreedily(funcOp,
+                                            std::move(promotionPatterns)))) {
+      return signalPassFailure();
+    }
+
+    // Insert barriers before and after copies to workgroup memory, but skip
+    // the barrier between back-to-back copies to workgroup memory.
+    OpBuilder builder(&getContext());
+    funcOp.walk([&builder](memref::CopyOp copyOp) {
+      if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) {
+        Operation *prevOp = copyOp->getPrevNode();
+        if (!prevOp || !hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) {
+          builder.setInsertionPoint(copyOp);
+          builder.create<gpu::BarrierOp>(copyOp.getLoc());
+        }
+        Operation *nextOp = copyOp->getNextNode();
+        if (!nextOp || !hasMarker(nextOp, getCopyToWorkgroupMemoryMarker())) {
+          builder.setInsertionPointAfter(copyOp);
+          builder.create<gpu::BarrierOp>(copyOp.getLoc());
+        }
+      }
+    });
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After promotion ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
+  {  // Tile and distribute to invocations.
+    RewritePatternSet tilingPatterns(&getContext());
+    linalg::LinalgTransformationFilter filter(
+        {StringAttr::get(context, getWorkgroupMemoryMarker())}, llvm::None);
+    populateTilingToInvocationPatterns(tilingPatterns, filter);
+    if (failed(
+            applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns)))) {
+      funcOp.emitOpError() << "failed tiling and distributing to invocations";
+      return signalPassFailure();
+    }
+
+    RewritePatternSet patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    populateFoldAffineMinInDistributedLoopsPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+      // TODO(#4759): This does not converge after the max number of iterations.
+      // It indicates that some pattern upstream is generating ops even when the
+      // pattern failed to match. Not related to correctness, but would be good
+      // to figure out and fix.
+      // return signalPassFailure();
+    }
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After tiling to invocations ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVTileAndPromotePass() {
+  return std::make_unique<SPIRVTileAndPromotePass>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
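Note on the marker logic above: fills and generics simply get the
workgroup-memory marker, while matmul-like ops get the promote-LHS/RHS/both
marker when the corresponding operand tile(s) can be copied using width-4
vector accesses by all threads (see canPerformVectorAccessUsingAllThreads in
GPUUtils.cpp further below), falling back to the plain workgroup-memory marker
otherwise. Promotion itself only runs when the flat workgroup size exceeds the
subgroup size, i.e. when there is more than one warp to synchronize. The
barrier walk then emits a single gpu.barrier before the first and after the
last copy in each run of back-to-back copy_to_workgroup_memory copies; this is
the barrier/copy/copy/barrier sequence checked in tile_and_promote_matmul.mlir
below.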
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 4e0bc80..fd2ae6b 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -225,6 +225,7 @@
     // tensors and hoist them out of loop nests. So after it we have
     // loop-carried vectors, not loop-carried tensors anymore.
     linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
+    linalg::hoistRedundantVectorTransfers(funcOp);
 
     LLVM_DEBUG({
       llvm::dbgs() << "--- After hoisting vector transfers ---\n";
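The buffer-based hoisting added here complements the tensor-based hoisting on
the preceding line; presumably it is needed so that loop-carried vector
transfers are also hoisted once the promotion pipeline introduces memref-based
code before vectorization.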
diff --git a/iree/compiler/Codegen/SPIRV/Utils.cpp b/iree/compiler/Codegen/SPIRV/Utils.cpp
index 7b91b13..b3c1dc4 100644
--- a/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -12,6 +12,7 @@
 
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
 
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
diff --git a/iree/compiler/Codegen/SPIRV/Utils.h b/iree/compiler/Codegen/SPIRV/Utils.h
index c213d56..edcfe26 100644
--- a/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/iree/compiler/Codegen/SPIRV/Utils.h
@@ -20,8 +20,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-static constexpr int kNumGPUDims = 3;
-
 /// Given an operation, return the `spv.target_env` attribute.
 spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op);
 
diff --git a/iree/compiler/Codegen/SPIRV/test/BUILD b/iree/compiler/Codegen/SPIRV/test/BUILD
index 2fa72eb..7016e86 100644
--- a/iree/compiler/Codegen/SPIRV/test/BUILD
+++ b/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -32,10 +32,12 @@
             "create_fast_slow_path.mlir",
             "distribute_to_invocations.mlir",
             "pipeline_matmul_cooperative_ops.mlir",
+            "pipeline_matmul_promotion.mlir",
             "pipeline_matmul_vectorization.mlir",
             "tile_and_distribute.mlir",
             "tile_and_distribute_scatter.mlir",
             "tile_and_distribute_sort.mlir",
+            "tile_and_promote_matmul.mlir",
             "tile_and_vectorize_batch_matmul.mlir",
             "tile_and_vectorize_conv.mlir",
             "tile_and_vectorize_matmul.mlir",
@@ -47,10 +49,6 @@
             "vectorize_tensor_pad.mlir",
         ],
         include = ["*.mlir"],
-        # TODO(b/203528778) reenable
-        exclude = [
-            "promote_workgroup_memory.mlir",
-        ],
     ),
     tools = [
         "//iree/tools:iree-opt",
diff --git a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 2714f0b..9ee9691 100644
--- a/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -27,10 +27,12 @@
     "create_fast_slow_path.mlir"
     "distribute_to_invocations.mlir"
     "pipeline_matmul_cooperative_ops.mlir"
+    "pipeline_matmul_promotion.mlir"
     "pipeline_matmul_vectorization.mlir"
     "tile_and_distribute.mlir"
     "tile_and_distribute_scatter.mlir"
     "tile_and_distribute_sort.mlir"
+    "tile_and_promote_matmul.mlir"
     "tile_and_vectorize_batch_matmul.mlir"
     "tile_and_vectorize_conv.mlir"
     "tile_and_vectorize_matmul.mlir"
diff --git a/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
new file mode 100644
index 0000000..bf27da5
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
@@ -0,0 +1,62 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-codegen-linalg-to-spirv-pipeline))' %s | FileCheck %s
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>
+  ]>
+]>
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+hal.executable @matmul_128x256x64 {
+  hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+    spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU,
+      {max_compute_shared_memory_size = 49152 : i32,
+       max_compute_workgroup_invocations = 1024 : i32,
+       max_compute_workgroup_size = dense<[65535, 65535, 65535]> : vector<3xi32>,
+       subgroup_size = 32 : i32}>}> {
+    hal.executable.entry_point public @matmul_128x256x64 ordinal(0) layout(#executable_layout)
+    builtin.module {
+      func.func @matmul_128x256x64() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x64xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:64x256xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:128x256xf32>
+        %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:128x256xf32>
+        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [128, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x64xf32> -> tensor<128x64xf32>
+        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [64, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:64x256xf32> -> tensor<64x256xf32>
+        %6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x256xf32> -> tensor<128x256xf32>
+        %7 = linalg.init_tensor [128, 256] : tensor<128x256xf32>
+        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x256xf32>) -> tensor<128x256xf32>
+        %9 = linalg.matmul ins(%4, %5 : tensor<128x64xf32>, tensor<64x256xf32>) outs(%8 : tensor<128x256xf32>) -> tensor<128x256xf32>
+        %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+                ins(%9, %6 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%7 : tensor<128x256xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+          %11 = arith.divf %arg0, %arg1 : f32
+          linalg.yield %11 : f32
+        } -> tensor<128x256xf32>
+        flow.dispatch.tensor.store %10, %3, offsets = [0, 0], sizes = [128, 256], strides = [1, 1] : tensor<128x256xf32> -> !flow.dispatch.tensor<writeonly:128x256xf32>
+        return
+      }
+    }
+  }
+}
+
+// CHECK: spv.GlobalVariable @{{.+}} : !spv.ptr<!spv.struct<(!spv.array<1024 x vector<4xf32>, stride=16>)>, Workgroup>
+// CHECK: spv.GlobalVariable @{{.+}} : !spv.ptr<!spv.struct<(!spv.array<1024 x vector<4xf32>, stride=16>)>, Workgroup>
+
+// CHECK-LABEL: spv.func @matmul_128x256x64
+
+//           CHECK: spv.mlir.loop
+//           CHECK:   spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+//   CHECK-COUNT-4:   spv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
+//   CHECK-COUNT-4:   spv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
+//           CHECK:   spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+// CHECK-COUNT-512:   spv.GLSL.Fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
+//           CHECK:   spv.mlir.merge
+//  CHECK-COUNT-16:   spv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
+//  CHECK-COUNT-16:   spv.FDiv %{{.+}}, %{{.+}} : vector<4xf32>
+//  CHECK-COUNT-16:   spv.Store "StorageBuffer" %{{.+}}, %{{.+}} : vector<4xf32>
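The two workgroup globals checked at the top correspond to the promoted LHS
and RHS tiles. Each holds 1024 vector<4xf32> elements, i.e. 4096 f32 values
(16 KiB); for reference, a 128x32 f32 tile such as the one materialized in
tile_and_promote_matmul.mlir below is exactly 128 * 32 = 4096 floats.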
diff --git a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir b/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
deleted file mode 100644
index 252ceb3..0000000
--- a/iree/compiler/Codegen/SPIRV/test/promote_workgroup_memory.mlir
+++ /dev/null
@@ -1,135 +0,0 @@
-// TODO(antiagainst): Fix promotion to workgroup and enable the test.
-// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-distribute,iree-spirv-vectorize,canonicalize,cse))))' | FileCheck %s
-
-hal.executable private @matmul_promote_workgroup_memory  {
-  hal.interface @io {
-    hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
-    hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
-    hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
-  }
-  hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @matmul_promote_workgroup_memory interface(@io) {
-      workgroup_size = [16: index, 8: index, 1: index]
-    }
-    builtin.module {
-      func.func @matmul_promote_workgroup_memory() {
-        %c32 = arith.constant 32 : index
-        %c50 = arith.constant 50 : index
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<25x50xf32>
-        %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<50x75xf32>
-        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<25x75xf32>
-        %3 = hal.interface.workgroup.id[0] : index
-        %4 = hal.interface.workgroup.id[1] : index
-        scf.for %arg0 = %c0 to %c50 step %c32 {
-          %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%4]
-          %6 = affine.min affine_map<()[s0] -> (8, s0 * -8 + 25)>()[%4]
-          %7 = affine.min affine_map<(d0) -> (32, -d0 + 50)>(%arg0)
-          %8 = memref.subview %0[%5, %arg0] [%6, %7] [1, 1] : memref<25x50xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 50 + s0 + d1)>>
-          %9 = affine.min affine_map<(d0) -> (32, -d0 + 50)>(%arg0)
-          %10 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%3]
-          %11 = affine.min affine_map<()[s0] -> (16, s0 * -16 + 75)>()[%3]
-          %12 = memref.subview %1[%arg0, %10] [%9, %11] [1, 1] : memref<50x75xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>
-          %13 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%4]
-          %14 = affine.min affine_map<()[s0] -> (8, s0 * -8 + 25)>()[%4]
-          %15 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%3]
-          %16 = affine.min affine_map<()[s0] -> (16, s0 * -16 + 75)>()[%3]
-          %17 = memref.subview %2[%13, %15] [%14, %16] [1, 1] : memref<25x75xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>
-          linalg.matmul {__internal_linalg_transform__ = "workgroup", lowering_config = {tileSizes = [[8, 16, 32], [], [1, 1, 0]]}}
-            ins(%8, %12 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 50 + s0 + d1)>>, memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
-            outs(%17 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 75 + s0 + d1)>>)
-        }
-        return
-      }
-      hal.interface private @io  {
-        hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
-        hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
-        hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
-      }
-    }
-  }
-}
-
-// CHECK-LABEL: func @matmul_promote_workgroup_memory()
-//   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external
-//   CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external
-//   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external
-//   CHECK-DAG:   %[[ALLOC1:.+]] = memref.alloc() : memref<8x32xf32, 3>
-//   CHECK-DAG:   %[[ALLOC2:.+]] = memref.alloc() : memref<32x16xf32, 3>
-//       CHECK:   scf.for
-//       CHECK:     %[[ARG0SV:.+]] = memref.subview %[[ARG0]]
-//       CHECK:     %[[ARG1SV:.+]] = memref.subview %[[ARG1]]
-//       CHECK:     %[[RET0SV:.+]] = memref.subview %[[RET0]]
-//       CHECK:     %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC1]]
-//       CHECK:     %[[SUBVIEW2:.+]] = memref.subview %[[ALLOC2]]
-//       CHECK:     linalg.generic(%[[ARG0SV]], %[[SUBVIEW1]])
-//  CHECK-SAME:       "copy_to_workgroup_memory"
-//       CHECK:     linalg.generic(%[[ARG1SV]], %[[SUBVIEW2]])
-//  CHECK-SAME:       "copy_to_workgroup_memory"
-//       CHECK:     scf.for
-//       CHECK:       scf.for
-//   CHECK-DAG:         memref.subview %[[SUBVIEW1]]
-//   CHECK-DAG:         memref.subview %[[SUBVIEW2]]
-//   CHECK-DAG:         memref.subview %[[RET0SV]]
-
-// -----
-
-hal.executable private @conv_promote_workgroup_memory  {
-  hal.interface @io {
-    hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
-    hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
-    hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
-  }
-  hal.executable.variant @vulkan, target = <"vulkan-spirv", "vulkan-spirv-fb"> {
-    hal.executable.entry_point @conv_promote_workgroup_memory interface(@io) {
-      workgroup_size = [32: index, 4: index, 1: index]
-    }
-    builtin.module {
-      func.func @conv_promote_workgroup_memory() {
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<3x4x6x14xf32>
-        %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<2x15x14x6xf32>
-        %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<2x13x11x14xf32>
-        %3 = hal.interface.workgroup.id[0] : index
-        %4 = hal.interface.workgroup.id[1] : index
-        %5 = hal.interface.workgroup.id[2] : index
-        %6 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
-        %7 = affine.min affine_map<()[s0] -> (6, s0 * -4 + 15)>()[%4]
-        %8 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
-        %9 = affine.min affine_map<()[s0] -> (35, s0 * -32 + 14)>()[%3]
-        %10 = memref.subview %1[%5, %6, %8, 0] [1, %7, %9, 6] [1, 1, 1, 1] : memref<2x15x14x6xf32> to memref<1x?x?x6xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1260 + s0 + d1 * 84 + d2 * 6 + d3)>>
-        %11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%4]
-        %12 = affine.min affine_map<()[s0] -> (4, s0 * -4 + 13)>()[%4]
-        %13 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%3]
-        %14 = affine.min affine_map<()[s0] -> (32, s0 * -32 + 11)>()[%3]
-        %15 = memref.subview %2[%5, %11, %13, 0] [1, %12, %14, 14] [1, 1, 1, 1] : memref<2x13x11x14xf32> to memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>
-        linalg.conv_2d_nhwc_hwcf {__internal_linalg_transform__ = "workgroup", lowering_config = {tileSizes = [[0, 1, 4, 32], [], [0, 1, 1, 1]]}, dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
-          ins(%10, %0 : memref<1x?x?x6xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1260 + s0 + d1 * 84 + d2 * 6 + d3)>>, memref<3x4x6x14xf32>)
-          outs(%15 : memref<1x?x?x14xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2002 + s0 + d1 * 154 + d2 * 14 + d3)>>)
-        return
-      }
-      hal.interface private @io  {
-        hal.interface.binding @s0b0_ro_external, set=0, binding=0, type="StorageBuffer"
-        hal.interface.binding @s0b1_ro_external, set=0, binding=1, type="StorageBuffer"
-        hal.interface.binding @s0b2_xw_external, set=0, binding=2, type="StorageBuffer"
-      }
-    }
-  }
-}
-
-// CHECK-LABEL: func @conv_promote_workgroup_memory()
-//   CHECK-DAG:   %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external
-//   CHECK-DAG:   %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external
-//   CHECK-DAG:   %[[RET0:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external
-//   CHECK-DAG:   %[[ALLOC1:.+]] = memref.alloc() : memref<1x6x35x6xf32, 3>
-//       CHECK:   %[[ARG1SV:.+]] = memref.subview %[[ARG1]]
-//       CHECK:   %[[RET0SV:.+]] = memref.subview %[[RET0]]
-//       CHECK:   %[[SUBVIEW1:.+]] = memref.subview %[[ALLOC1]]
-//       CHECK:   linalg.generic(%[[ARG1SV]], %[[SUBVIEW1]])
-//  CHECK-SAME:      "copy_to_workgroup_memory"
-//       CHECK:   scf.for
-//       CHECK:     scf.for
-//       CHECK:       scf.for
-//   CHECK-DAG:         memref.subview %[[SUBVIEW1]]
-//   CHECK-DAG:         memref.subview %[[ARG0]]
-//   CHECK-DAG:         memref.subview %[[RET0SV]]
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
new file mode 100644
index 0000000..548c498
--- /dev/null
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_promote_matmul.mlir
@@ -0,0 +1,130 @@
+// RUN: iree-opt -split-input-file -mlir-print-local-scope -pass-pipeline='hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-and-promote))))' -cse %s | FileCheck %s
+
+#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>
+  ]>
+]>
+#config = #iree_codegen.lowering_config<tile_sizes = [[128, 128], [16, 4], [0, 0, 32]]>
+
+hal.executable @matmul_256x1024x128 {
+  hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
+    spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader], []>, NVIDIA:DiscreteGPU,
+      {max_compute_shared_memory_size = 49152 : i32,
+       max_compute_workgroup_invocations = 1024 : i32,
+       max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+       subgroup_size = 32 : i32}>}> {
+    hal.executable.entry_point public @matmul_256x1024x128 ordinal(0) layout(#executable_layout) {
+      translation_info = #iree_codegen.translation_info<SPIRVVectorizeWithWorkgroupMemory, workload_per_wg = [128, 128]>,
+      workgroup_size = [32 : index, 8 : index, 1 : index]
+    }
+    builtin.module {
+      func.func @matmul_256x1024x128() {
+        %c1024 = arith.constant 1024 : index
+        %c256 = arith.constant 256 : index
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<256x128xf32>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x1024xf32>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<256x1024xf32>
+        %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : memref<256x1024xf32>
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %workgroup_count_x = hal.interface.workgroup.count[0] : index
+        %workgroup_id_y = hal.interface.workgroup.id[1] : index
+        %workgroup_count_y = hal.interface.workgroup.count[1] : index
+        %4 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_y]
+        %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_y]
+        scf.for %arg0 = %4 to %c256 step %5 {
+          %6 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_id_x]
+          %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%workgroup_count_x]
+          scf.for %arg1 = %6 to %c1024 step %7 {
+            %8 = memref.subview %2[%arg0, %arg1] [128, 128] [1, 1] : memref<256x1024xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+            %9 = memref.subview %0[%arg0, 0] [128, 128] [1, 1] : memref<256x128xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>>
+            %10 = memref.subview %1[0, %arg1] [128, 128] [1, 1] : memref<128x1024xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+            %11 = memref.subview %3[%arg0, %arg1] [128, 128] [1, 1] : memref<256x1024xf32> to memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+            linalg.fill {lowering_config = #config}
+              ins(%cst : f32) outs(%11 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+            linalg.matmul {lowering_config = #config}
+              ins(%9, %10 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>>, memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+              outs(%11 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+            linalg.generic {
+              indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+              iterator_types = ["parallel", "parallel"]}
+              ins(%11, %8 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>, memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+              outs(%11 : memref<128x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>)
+              attrs =  {lowering_config = #config} {
+            ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+              %12 = arith.divf %arg2, %arg3 : f32
+              linalg.yield %12 : f32
+            }
+          }
+        }
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func @matmul_256x1024x128()
+
+//  CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+//  CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+
+//  CHECK-DAG: %[[MEM_A:.+]] = memref.alloc() : memref<128x32xf32, 3>
+//  CHECK-DAG: %[[MEM_B:.+]] = memref.alloc() : memref<32x128xf32, 3>
+
+//  CHECK-DAG: %[[BUFFER_A:.+]] = hal.interface.binding.subspan set(0) binding(0) {{.+}} : memref<256x128xf32>
+//  CHECK-DAG: %[[BUFFER_B:.+]] = hal.interface.binding.subspan set(0) binding(1) {{.+}} : memref<128x1024xf32>
+//  CHECK-DAG: %[[BUFFER_C:.+]] = hal.interface.binding.subspan set(0) binding(3) {{.+}} : memref<256x1024xf32>
+//  CHECK-DAG: %[[BUFFER_D:.+]] = hal.interface.binding.subspan set(0) binding(2) {{.+}} : memref<256x1024xf32>
+
+//      CHECK: scf.for
+//      CHECK:   scf.for
+//      CHECK:     %[[D:.+]] = memref.subview %[[BUFFER_D]]
+//      CHECK:     %[[A:.+]] = memref.subview %[[BUFFER_A]]
+//      CHECK:     %[[B:.+]] = memref.subview %[[BUFFER_B]]
+//      CHECK:     %[[C:.+]] = memref.subview %[[BUFFER_C]]
+//      CHECK:     %[[T_ID_X:.+]] = gpu.thread_id  x
+//      CHECK:     %[[T_DIM_X:.+]] = gpu.block_dim  x
+//      CHECK:     %[[T_ID_Y:.+]] = gpu.thread_id  y
+//      CHECK:     %[[T_DIM_Y:.+]] = gpu.block_dim  y
+//      CHECK:     %[[T_OFFSET_Y:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%[[T_ID_Y]]]
+//      CHECK:     %[[T_SIZE_Y:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%[[T_DIM_Y]]]
+
+//      CHECK:     scf.for %[[T_IV_Y:.+]] =
+//      CHECK:       scf.for %[[T_IV_X:.+]] =
+//      CHECK:         %[[VIEW_C:.+]] = memref.subview %[[C]][%[[T_IV_Y]], %[[T_IV_X]]] [16, 4] [1, 1]
+//      CHECK:         linalg.fill
+// CHECK-SAME:           outs(%[[VIEW_C]]
+
+//      CHECK:     scf.for %[[T_IV_Y:.+]] = %[[C0]] to %[[C128]] step %[[C32]] {
+//      CHECK:       %[[VIEW_A:.+]] = memref.subview %[[A]][0, %[[T_IV_Y]]] [128, 32]
+//      CHECK:       %[[VIEW_B:.+]] = memref.subview %[[B]][%[[T_IV_Y]], 0] [32, 128]
+
+//      CHECK:       gpu.barrier
+//      CHECK:       memref.copy %[[VIEW_A]], %[[MEM_A]]
+// CHECK-SAME:           __internal_linalg_transform__ = "copy_to_workgroup_memory"
+//      CHECK:       memref.copy %[[VIEW_B]], %[[MEM_B]]
+// CHECK-SAME:           __internal_linalg_transform__ = "copy_to_workgroup_memory"
+//      CHECK:       gpu.barrier
+
+//      CHECK:       scf.for %[[T_IV_Y:.+]] =
+//      CHECK:         scf.for %[[T_IV_X:.+]] =
+//      CHECK:           %[[VIEW_A:.+]] = memref.subview %[[MEM_A]][%[[T_IV_Y]], 0] [16, 32]
+//      CHECK:           %[[VIEW_B:.+]] = memref.subview %[[MEM_B]][0, %[[T_IV_X]]] [32, 4]
+//      CHECK:           %[[VIEW_C:.+]] = memref.subview %[[C]][%[[T_IV_Y]], %[[T_IV_X]]] [16, 4]
+//      CHECK:           linalg.matmul
+// CHECK-SAME:             ins(%[[VIEW_A]], %[[VIEW_B]]
+// CHECK-SAME:             outs(%[[VIEW_C]]
+
+//      CHECK:     scf.for %[[T_IV_Y:.+]] =
+//      CHECK:       scf.for %[[T_IV_X:.+]] =
+//      CHECK:         %[[VIEW_C:.+]] = memref.subview %[[C]][%[[T_IV_Y]], %[[T_IV_X]]]
+//      CHECK:         %[[VIEW_D:.+]] = memref.subview %[[D]][%[[T_IV_Y]], %[[T_IV_X]]]
+//      CHECK:         linalg.generic
+// CHECK-SAME:           ins(%[[VIEW_C]], %[[VIEW_D]]
+// CHECK-SAME:           outs(%[[VIEW_C]]
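The CHECK lines above follow from the lowering config and workgroup size in
the input: the [128, 128] workgroup tile is distributed to the 32x8
invocations as 16x4 per-thread tiles (8 threads * 16 = 128 rows, 32 threads *
4 = 128 columns, matching the s0 * 16 affine maps), and the reduction
dimension is tiled by 32, yielding the memref<128x32xf32, 3> and
memref<32x128xf32, 3> workgroup allocations that the barriers bracket.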
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index 83067fb..0279fcd 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -55,10 +55,10 @@
 //   CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
 
 //       CHECK:   scf.for %[[IV_Y:.+]] = %[[C0]] to %[[C2]] step %[[C1]]
+//       CHECK:     %[[LHS_TILE:.+]] = tensor.extract_slice %{{.+}}[%[[IV_Y]], 0] [1, 4]
+//       CHECK:     %[[LHS_VECTOR:.+]] = vector.transfer_read %[[LHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]]
 //       CHECK:     scf.for %[[IV_X:.+]] = %[[C0]] to %[[C128]] step %[[C4]] iter_args(%[[ACC_TILE:.+]] =
-//       CHECK:       %[[LHS_TILE:.+]] = tensor.extract_slice %{{.+}}[%[[IV_Y]], 0] [1, 4]
 //       CHECK:       %[[RHS_TILE:.+]] = tensor.extract_slice %{{.+}}[0, %[[IV_X]]] [4, 4]
-//       CHECK:       %[[LHS_VECTOR:.+]] = vector.transfer_read %[[LHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]]
 //       CHECK:       %[[RHS_0_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C0]], %[[C0]]], %[[PAD]]
 //       CHECK:       %[[RHS_1_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C1]], %[[C0]]], %[[PAD]]
 //       CHECK:       %[[RHS_2_VECTOR:.+]] = vector.transfer_read %[[RHS_TILE]][%[[C2]], %[[C0]]], %[[PAD]]
diff --git a/iree/compiler/Codegen/Utils/BUILD b/iree/compiler/Codegen/Utils/BUILD
index 19bddad..e7510ff 100644
--- a/iree/compiler/Codegen/Utils/BUILD
+++ b/iree/compiler/Codegen/Utils/BUILD
@@ -15,10 +15,12 @@
 cc_library(
     name = "Utils",
     srcs = [
+        "GPUUtils.cpp",
         "MarkerUtils.cpp",
         "Utils.cpp",
     ],
     hdrs = [
+        "GPUUtils.h",
         "MarkerUtils.h",
         "Utils.h",
     ],
@@ -29,6 +31,7 @@
         "//iree/compiler/Dialect/HAL/IR",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
diff --git a/iree/compiler/Codegen/Utils/CMakeLists.txt b/iree/compiler/Codegen/Utils/CMakeLists.txt
index de6f31c..ce45f4e 100644
--- a/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -14,14 +14,17 @@
   NAME
     Utils
   HDRS
+    "GPUUtils.h"
     "MarkerUtils.h"
     "Utils.h"
   SRCS
+    "GPUUtils.cpp"
     "MarkerUtils.cpp"
     "Utils.cpp"
   DEPS
     IREELinalgExtDialect
     LLVMSupport
+    MLIRGPUOps
     MLIRIR
     MLIRLinalg
     MLIRLinalgTransforms
diff --git a/iree/compiler/Codegen/Utils/GPUUtils.cpp b/iree/compiler/Codegen/Utils/GPUUtils.cpp
new file mode 100644
index 0000000..09efafd
--- /dev/null
+++ b/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -0,0 +1,118 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+
+#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+llvm::SmallVector<mlir::linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
+    mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
+    llvm::ArrayRef<int64_t> workgroupSize) {
+  assert(numDims <= kNumGPUDims);
+  llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
+  std::array<gpu::Dimension, kNumGPUDims> dimAttr{
+      gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
+  mlir::Type indexType = builder.getIndexType();
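+  // Note: procInfo is filled in reverse so that the x (fastest-varying)
+  // thread dimension maps to the innermost distributed loop.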
+  for (unsigned i = 0; i < numDims; ++i) {
+    procInfo[numDims - 1 - i] = {
+        builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]),
+        builder.create<mlir::arith::ConstantOp>(
+            loc, builder.getIndexAttr(workgroupSize[i]))};
+  }
+  return procInfo;
+}
+
+llvm::SmallVector<mlir::linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+    mlir::OpBuilder &builder, mlir::Location loc, unsigned numDims,
+    llvm::ArrayRef<int64_t> numSubgroups) {
+  assert(numDims <= kNumGPUDims);
+  llvm::SmallVector<mlir::linalg::ProcInfo, 2> procInfo(numDims);
+  std::array<gpu::Dimension, kNumGPUDims> dimAttr{
+      gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
+  mlir::Type indexType = builder.getIndexType();
+  for (unsigned i = 0; i < numDims; ++i) {
+    mlir::Value subgroupId =
+        builder.create<mlir::gpu::ThreadIdOp>(loc, indexType, dimAttr[i]);
+    if (i == 0) {
+      mlir::AffineExpr d0 = builder.getAffineDimExpr(0);
+      subgroupId = mlir::makeComposedAffineApply(
+          builder, loc, d0.floorDiv(builder.getAffineConstantExpr(kWarpSize)),
+          {subgroupId});
+    }
+    procInfo[numDims - 1 - i] = {
+        subgroupId, builder.create<mlir::arith::ConstantOp>(
+                        loc, builder.getIndexAttr(numSubgroups[i]))};
+  }
+  return procInfo;
+}
+
+std::array<int64_t, 3> getWorkgroupSize(mlir::func::FuncOp funcOp) {
+  std::array<int64_t, 3> workgroupSize;
+  auto entryPointOp = mlir::iree_compiler::getEntryPoint(funcOp);
+  llvm::Optional<mlir::ArrayAttr> workgroupSizeAttr =
+      entryPointOp.workgroup_size();
+  assert(workgroupSizeAttr.hasValue());
+  for (auto it : llvm::enumerate(workgroupSizeAttr.getValue())) {
+    workgroupSize[it.index()] =
+        it.value().cast<mlir::IntegerAttr>().getValue().getZExtValue();
+  }
+  return workgroupSize;
+}
+
+bool canPerformVectorAccessUsingAllThreads(ArrayRef<int64_t> shape,
+                                           int64_t threadCount,
+                                           int64_t vectorSize) {
+  // Verify that each dimension of the shape can be distributed across the
+  // available threads.
+  int64_t threadsAvailable = threadCount;
+  for (auto &dim : llvm::enumerate(llvm::reverse(shape))) {
+    int64_t numElementPerThread = dim.index() == 0 ? vectorSize : 1;
+    int64_t numThreads = dim.value() / numElementPerThread;
+    if (numThreads == 0) return false;
+    numThreads = std::min(numThreads, threadsAvailable);
+    if (threadsAvailable % numThreads != 0) return false;
+    threadsAvailable = threadsAvailable / numThreads;
+    if (threadsAvailable == 1) break;
+  }
+  return threadsAvailable == 1;
+}
+
+Optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
+                                        memref::SubViewOp subview,
+                                        ArrayRef<Value> sizeBounds,
+                                        DataLayout &) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
+  if (!funcOp) return llvm::None;
+
+  // The subview size bounds are expected to be constant; they specify the shape
+  // of the allocation.
+  SmallVector<int64_t, 2> shape;
+  for (Value bound : sizeBounds) {
+    APInt value;
+    if (!matchPattern(bound, m_ConstantInt(&value))) return llvm::None;
+    shape.push_back(value.getSExtValue());
+  }
+
+  builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
+  auto type = MemRefType::get(shape, subview.getType().getElementType(), {},
+                              gpu::GPUDialect::getWorkgroupAddressSpace());
+  Value buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+  return buffer;
+}
+
+LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/) {
+  return success();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
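To make the distribution check concrete, the following is a small standalone
C++ sketch (not part of the patch) that mirrors the loop in
canPerformVectorAccessUsingAllThreads and traces it on the tile shapes used in
the tests (32x8 workgroup, i.e. 256 threads, vector width 4):

// Standalone sketch of the distribution check above; the shapes and thread
// counts are taken from the tests in this commit.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

static bool distributableWithVectorAccess(const std::vector<int64_t> &shape,
                                          int64_t threadCount,
                                          int64_t vectorSize) {
  int64_t threadsAvailable = threadCount;
  // Walk the shape innermost-first; only the innermost dim is vectorized.
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t dim = shape[shape.size() - 1 - i];
    int64_t numElementPerThread = (i == 0) ? vectorSize : 1;
    int64_t numThreads = dim / numElementPerThread;
    if (numThreads == 0) return false;
    numThreads = std::min(numThreads, threadsAvailable);
    if (threadsAvailable % numThreads != 0) return false;
    threadsAvailable /= numThreads;
    if (threadsAvailable == 1) break;
  }
  return threadsAvailable == 1;
}

int main() {
  // 128x32 f32 LHS tile: 32/4 = 8 threads on the inner dim, then
  // 256/8 = 32 threads cover the 128 rows -> all 256 threads used.
  std::printf("%d\n", distributableWithVectorAccess({128, 32}, 256, 4));  // 1
  // A 128x3 tile cannot even form one vector<4> access: 3/4 == 0.
  std::printf("%d\n", distributableWithVectorAccess({128, 3}, 256, 4));   // 0
  return 0;
}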
diff --git a/iree/compiler/Codegen/Utils/GPUUtils.h b/iree/compiler/Codegen/Utils/GPUUtils.h
new file mode 100644
index 0000000..49593c3
--- /dev/null
+++ b/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -0,0 +1,53 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
+#define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
+
+#include "iree/compiler/Codegen/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static constexpr int32_t kNumGPUDims = 3;
+static constexpr int32_t kWarpSize = 32;
+
+/// Return the thread IDs and counts for distributing to GPU threads within
+/// a workgroup of the given `workgroupSize`.
+llvm::SmallVector<linalg::ProcInfo, 2> getGPUThreadIdsAndCounts(
+    OpBuilder &builder, Location loc, unsigned numDims,
+    llvm::ArrayRef<int64_t> workgroupSize);
+
+/// Compute the subgroup ID. CUDA doesn't have a subgroupId equivalent, so we
+/// compute the subgroup ID based on the thread ID.
+/// When tiling to warps we assume each warp is full and pick a workgroup
+/// size such that `workgroupSize.x % warpSize == 0`. This is why we can have
+/// warpId = { threadId.x / warpSize, threadId.y, threadId.z }.
+llvm::SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
+    OpBuilder &builder, Location loc, unsigned numDims,
+    llvm::ArrayRef<int64_t> numSubgroups);
+
+/// Return the workgroup size associated with the funcOp entry point.
+std::array<int64_t, 3> getWorkgroupSize(func::FuncOp funcOp);
+
+/// Return true if we can use all threads to perform vectorized load/store of
+/// the given `shape`.
+bool canPerformVectorAccessUsingAllThreads(ArrayRef<int64_t> shape,
+                                           int64_t threadCount,
+                                           int64_t vectorSize);
+
+/// Allocate GPU workgroup memory matching the given `subview`. The
+/// `sizeBounds` must be compile-time constants; they define the allocation's
+/// static shape.
+Optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
+                                        memref::SubViewOp subview,
+                                        ArrayRef<Value> sizeBounds,
+                                        DataLayout &);
+
+/// Deallocate the GPU workgroup memory backing `buffer`. This is a no-op
+/// since workgroup memory is not heap-allocated.
+LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value buffer);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
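
For context, below is a hedged sketch of how the two memory hooks declared
above typically plug into upstream Linalg promotion. getWorkgroupPromotionOptions
is a hypothetical helper name; the actual call site lives in the
tile-and-promote passes and may configure more options:

#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

namespace mlir {
namespace iree_compiler {

static linalg::LinalgPromotionOptions getWorkgroupPromotionOptions() {
  // allocateWorkgroupMemory creates a memref.alloc in the workgroup address
  // space; deallocation is a no-op since that memory is not heap-backed.
  return linalg::LinalgPromotionOptions().setAllocationDeallocationFns(
      allocateWorkgroupMemory, deallocateWorkgroupMemory);
}

}  // namespace iree_compiler
}  // namespace mlir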