[LLVMGPU] Use forall workgroup distribution in TileAndFuse pipeline (#18565)

This switches the TileAndFuse pipeline to use scf.forall distribution.
Using scf.forall distribution also requires some changes to the pass
ordering in the TileAndFuse pipeline, which is also handled by this PR:
1. The main difference is that PackToIntrinsics happens before workgroup
distribution. Otherwise, collapse_shape ops can end up at the end of the
workgroup forall, and an extra buffer is created.
2. Pack decomposition is now staged, with packs/unpacks at the function
boundaries being decomposed early before workgroup decomposition, and
the rest being decomposed after reduction tiling as before. This
prevents unpacks being fused into the workgroup forall and causing the
same problem as in (1).
3. `ConcretizeMmaShapes` now runs before workgroup tiling as well,
so the resulting collapse_shape on the multi_mma op result can be
propagated to the function boundary before any tiling. This is also to
avoid the same problem as in (1).

The lowering configs on the MMA path have also changed, since they now
need to account for inner tile sizes of packing.

depends on https://github.com/iree-org/iree/pull/18852

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 611a874..ca23b0c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -245,10 +245,8 @@
   }
 
   // Compute the M/N dimension tile size by multiplying subgroup information.
-  workgroupTileSizes[mDim] =
-      schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
-  workgroupTileSizes[nDim] =
-      schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
+  workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
+  workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;
 
   // Specify the subgroup tile sizes from the mma schedule. This is applied
   subgroupTileSizes[mDim] = schedule->mTileCount;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index e8c3de8..76b1af3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -19,6 +19,7 @@
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Util/Transforms/Passes.h"
 #include "iree/compiler/Utils/PassUtils.h"
@@ -190,18 +191,23 @@
 }
 
 static void tileAndDistributeToWorkgroup(
-    OpPassManager &funcPassManager,
+    OpPassManager &funcPassManager, bool useForall,
     std::optional<ConvertToDestinationPassingStylePassOptions>
         convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{}) {
-  funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
-      kNumMaxParallelDims,
-      linalg::DistributionMethod::CyclicNumProcsEqNumIters));
-  funcPassManager.addPass(createCSEPass());
-
-  if (convertToDpsOptions) {
+  if (useForall) {
     funcPassManager.addPass(
-        createConvertToDestinationPassingStylePass(*convertToDpsOptions));
+        createTileAndDistributeToWorkgroupsUsingForallOpPass());
+  } else {
+    funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
+        kNumMaxParallelDims,
+        linalg::DistributionMethod::CyclicNumProcsEqNumIters));
+    funcPassManager.addPass(createCSEPass());
+    if (convertToDpsOptions) {
+      funcPassManager.addPass(
+          createConvertToDestinationPassingStylePass(*convertToDpsOptions));
+    }
   }
+
   // TODO(#16421): Disable decomposition due to failure in bufferization.
   // funcPassManager.addPass(
   //     IREE::LinalgExt::createTileAndDecomposeAttentionPass());
@@ -212,7 +218,8 @@
 static void tileAndBufferize(OpPassManager &funcPassManager) {
   ConvertToDestinationPassingStylePassOptions options;
   options.useWARForCooperativeMatrixCodegen = true;
-  tileAndDistributeToWorkgroup(funcPassManager, options);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false,
+                               /*convertToDpsOptions=*/options);
   addBufferizePasses(funcPassManager);
 }
 
@@ -243,7 +250,7 @@
 //===---------------------------------------------------------------------===//
 
 void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
 
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCanonicalizerPass());
@@ -323,22 +330,45 @@
   funcPassManager.addPass(createCSEPass());
 }
 
+/// Control function for decomposing pack and unpack ops. Returns true if the
+/// op is a PackOp with a DispatchTensorLoadOp producer, or an UnPackOp with
+/// only DispatchTensorStoreOp consumers.
+LogicalResult isAtBoundary(Operation *op) {
+  if (isa<tensor::PackOp>(op)) {
+    if (isa_and_nonnull<IREE::Flow::DispatchTensorLoadOp>(
+            op->getOperand(0).getDefiningOp())) {
+      return success();
+    }
+  } else if (isa<tensor::UnPackOp>(op)) {
+    if (llvm::all_of(op->getUsers(), [](Operation *user) {
+          return isa<IREE::Flow::DispatchTensorStoreOp>(user);
+        })) {
+      return success();
+    }
+  }
+  return failure();
+}
+
 void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
                                    const GPUPipelineOptions &pipelineOptions) {
-  tileAndDistributeToWorkgroup(funcPassManager,
-                               /*convertToDpsOptions=*/std::nullopt);
-
   // Step 1. Promote matmul operands and pack to intrinsic shapes.
   funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
   funcPassManager.addPass(IREE::GPU::createPackToIntrinsicsPass());
+  // Decompose packs and unpacks that are at the function boundary.
+  funcPassManager.addPass(createDecomposeBoundaryPackUnPackOpsPass());
 
-  // Step 1.5. Expand result shapes of MultiMmaOps before reduction tiling.
+  // Step 1.5. Expand result shapes of MultiMmaOps before tiling, and
+  // propagate reshapes to the function boundary.
   {
     IREE::GPU::ConcretizeMmaShapesPassOptions options;
     options.concretizeInputs = false;
     options.concretizeResult = true;
     funcPassManager.addPass(IREE::GPU::createConcretizeMmaShapesPass());
   }
+  funcPassManager.addPass(createPropagateReshapesByExpansionPass());
+
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
+                               /*convertToDpsOptions=*/std::nullopt);
 
   // Step 2. Tile and fuse tileable ops to reduction loops.
   {
@@ -468,7 +498,7 @@
 //===---------------------------------------------------------------------===//
 
 void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
 
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCanonicalizerPass());
@@ -505,7 +535,7 @@
 
 void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
                                   const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
 
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCanonicalizerPass());
@@ -709,7 +739,7 @@
 
 void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
                                  const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
 
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCanonicalizerPass());
@@ -814,7 +844,7 @@
 void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
                                         const GPUPipelineOptions &options,
                                         bool usePadToModelSharedMemcpy) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
 
   ReorderWorkgroupsStrategy reorderStrategy =
       getReorderWorkgroupsStrategy(options.reorderStrategy);
@@ -914,7 +944,7 @@
 }
 
 void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
   funcPassManager.addPass(createRematerializeParallelOpsPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createGPUTileReductionPass());
@@ -958,7 +988,7 @@
 }
 
 void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
 
@@ -994,7 +1024,8 @@
                                const GPUPipelineOptions &options) {
   ConvertToDestinationPassingStylePassOptions dpsOptions;
   dpsOptions.useWARForCooperativeMatrixCodegen = true;
-  tileAndDistributeToWorkgroup(funcPassManager, dpsOptions);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false,
+                               /*convertToDpsOptions=*/dpsOptions);
   if (options.enableUkernels) {
     funcPassManager.addPass(createGPULowerToUKernelsPass());
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 53952e9..b98e85a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -38,7 +38,7 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 0, 0, 4]
 //  CHECK-SAME:     subgroup = [0, 0, 4, 1, 0]
-//  CHECK-SAME:     workgroup = [1, 1, 64, 64, 0]
+//  CHECK-SAME:     workgroup = [1, 1, 4, 4, 0]
 
 // -----
 
@@ -63,7 +63,7 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 2]
 //  CHECK-SAME:     subgroup = [4, 4, 0]
-//  CHECK-SAME:     workgroup = [128, 128, 0]
+//  CHECK-SAME:     workgroup = [8, 8, 0]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 0dc8b0f..912acf3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -50,18 +50,20 @@
 //   CHECK-DAG:   %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
 //   CHECK-DAG:   memref.alloc() : memref<64x8xf16, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x8xf16, #gpu.address_space<workgroup>>
-//       CHECK:   %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1280 step %c4 {{.*}} -> (vector<8x4xf32>)
-//       CHECK:     gpu.barrier
-//   CHECK-DAG:     %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<2xf16>
-//   CHECK-DAG:     vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC:[A-Za-z0-9]+]]
-//   CHECK-DAG:     %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<2xf16>
-//   CHECK-DAG:     vector.transfer_write %[[RHS_RD]], %[[RHS_ALLOC:[A-Za-z0-9]+]]
-//       CHECK:     gpu.barrier
-//   CHECK-DAG:     %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<8x4xf16>
-//   CHECK-DAG:     %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf16>
-//       CHECK:     %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]]
-//       CHECK:     scf.yield %[[MM]]
-//       CHECK:   vector.transfer_write %[[LOOP]], %[[B2]]
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1280 step %c4 {{.*}} -> (vector<8x4xf32>)
+//       CHECK:       gpu.barrier
+//   CHECK-DAG:       %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<2xf16>
+//   CHECK-DAG:       vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC:[A-Za-z0-9]+]]
+//   CHECK-DAG:       %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<2xf16>
+//   CHECK-DAG:       vector.transfer_write %[[RHS_RD]], %[[RHS_ALLOC:[A-Za-z0-9]+]]
+//       CHECK:       gpu.barrier
+//   CHECK-DAG:       %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<8x4xf16>
+//   CHECK-DAG:       %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf16>
+//       CHECK:       %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]]
+//       CHECK:       scf.yield %[[MM]]
+//       CHECK:     vector.transfer_write %[[LOOP]], %[[B2]]
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -71,7 +73,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [64, 64, 0],
+  workgroup = [4, 4, 0],
   reduction = [0, 0, 2],
   subgroup = [2, 2],
   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
@@ -112,21 +114,23 @@
 //   CHECK-DAG:   %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
 //   CHECK-DAG:   memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
-//       CHECK:   %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
-//       CHECK:     gpu.barrier
-//   CHECK-DAG:     %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
-//   CHECK-DAG:     vector.transfer_write %[[LHS_RD]]
-//   CHECK-DAG:     %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
-//   CHECK-DAG:     vector.transfer_write %[[RHS_RD]]
-//       CHECK:     gpu.barrier
-//   CHECK-DAG:     vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<2x1x2x4xf16>
-//   CHECK-DAG:     vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<2x1x2x4xf16>
-//   CHECK-DAG:     vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16>
-//   CHECK-DAG:     vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16>
-// CHECK-COUNT-4:   amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
-//       CHECK:     scf.yield
-//       CHECK:   %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32>
-//       CHECK:   vector.transfer_write %[[LOOP_T]], %[[B2]]
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
+//       CHECK:       gpu.barrier
+//   CHECK-DAG:       %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+//   CHECK-DAG:       vector.transfer_write %[[LHS_RD]]
+//   CHECK-DAG:       %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+//   CHECK-DAG:       vector.transfer_write %[[RHS_RD]]
+//       CHECK:       gpu.barrier
+//   CHECK-DAG:       vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<2x1x2x4xf16>
+//   CHECK-DAG:       vector.transfer_read {{.*}} #gpu.address_space<workgroup>>, vector<2x1x2x4xf16>
+//   CHECK-DAG:       vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16>
+//   CHECK-DAG:       vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16>
+// CHECK-COUNT-4:     amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
+//       CHECK:       scf.yield
+//       CHECK:     %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32>
+//       CHECK:     vector.transfer_write %[[LOOP_T]], %[[B2]]
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -136,7 +140,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [1, 64, 64, 0],
+  workgroup = [1, 4, 4, 0],
   reduction = [0, 0, 0, 2],
   subgroup = [1, 2, 2],
   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
@@ -154,11 +158,11 @@
         %cst = arith.constant 0.000000e+00 : f32
         %c0 = arith.constant 0 : index
         %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xf16>>
-        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x16x16x1280xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<11520x1280xf16>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x256x1280xf32>>
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 34, 34, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>> -> tensor<2x34x34x1280xf16>
-        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xf16>> -> tensor<3x3x1280x1280xf16>
-        %5 = tensor.empty() : tensor<2x16x16x1280xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [11520, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<11520x1280xf16>> -> tensor<11520x1280xf16>
+        %5 = tensor.empty() : tensor<2x256x1280xf32>
         %6 = tensor.empty() : tensor<2x256x11520xf16>
         %7 = iree_linalg_ext.im2col
             strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3]
@@ -166,15 +170,13 @@
             batch_pos = [0] m_pos = [1, 2] k_pos = [3]
           ins(%3 : tensor<2x34x34x1280xf16>)
           outs(%6 : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16>
-        %collapsed = tensor.collapse_shape %4 [[0, 1, 2], [3]] : tensor<3x3x1280x1280xf16> into tensor<11520x1280xf16>
-        %collapsed_0 = tensor.collapse_shape %5 [[0], [1, 2], [3]] : tensor<2x16x16x1280xf32> into tensor<2x256x1280xf32>
-        %8 = linalg.fill ins(%cst : f32) outs(%collapsed_0 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32>
+        %8 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32>
         %9 = linalg.generic {
           indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                            affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
           iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-          ins(%7, %collapsed : tensor<2x256x11520xf16>, tensor<11520x1280xf16>)
+          ins(%7, %4 : tensor<2x256x11520xf16>, tensor<11520x1280xf16>)
           outs(%8 : tensor<2x256x1280xf32>) attrs =  {lowering_config = #config} {
         ^bb0(%in: f16, %in_1: f16, %out: f32):
           %10 = arith.extf %in : f16 to f32
@@ -183,8 +185,7 @@
           %13 = arith.addf %12, %out : f32
           linalg.yield %13 : f32
         } -> tensor<2x256x1280xf32>
-        %expanded = tensor.expand_shape %9 [[0], [1, 2], [3]] output_shape [2, 16, 16, 1280] : tensor<2x256x1280xf32> into tensor<2x16x16x1280xf32>
-        flow.dispatch.tensor.store %expanded, %2, offsets = [0, 0, 0, 0], sizes = [2, 16, 16, 1280], strides = [1, 1, 1, 1] : tensor<2x16x16x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x16x16x1280xf32>>
+        flow.dispatch.tensor.store %9, %2, offsets = [0, 0, 0], sizes = [2, 256, 1280], strides = [1, 1, 1] : tensor<2x256x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x256x1280xf32>>
         return
       }
     }
@@ -200,22 +201,24 @@
 //     CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[C720:.+]] = arith.constant 720 : index
 //     CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//         CHECK:   %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
-//         CHECK:     gpu.barrier
-//     CHECK-DAG:     %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
-//     CHECK-DAG:     vector.transfer_write %[[LHS_RD]]
-//     CHECK-DAG:     %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
-//     CHECK-DAG:     vector.transfer_write %[[RHS_RD]]
-//         CHECK:     gpu.barrier
-//     CHECK-DAG:     %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
-//     CHECK-DAG:     %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
-//     CHECK-DAG:     %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
-//     CHECK-DAG:     vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
-//     CHECK-DAG:     vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
-// CHECK-COUNT-4:     amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
-//         CHECK:   %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
-//         CHECK:   %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
-//         CHECK:   vector.transfer_write %[[EXTRACT]], %[[B2]]
+//         CHECK:   scf.forall ({{.*}}) in (2, 4, 20) {
+//         CHECK:     %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>)
+//         CHECK:       gpu.barrier
+//     CHECK-DAG:       %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+//     CHECK-DAG:       vector.transfer_write %[[LHS_RD]]
+//     CHECK-DAG:       %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+//     CHECK-DAG:       vector.transfer_write %[[RHS_RD]]
+//         CHECK:       gpu.barrier
+//     CHECK-DAG:       %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16>
+//     CHECK-DAG:       %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16>
+//     CHECK-DAG:       %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16>
+//     CHECK-DAG:       vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16>
+//     CHECK-DAG:       vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16>
+// CHECK-COUNT-4:       amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
+//         CHECK:     %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32>
+//         CHECK:     %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32>
+//         CHECK:     vector.transfer_write %[[EXTRACT]], %[[B2]]
+//         CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -225,7 +228,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [1, 4, 16, 256, 0],
+  workgroup = [1, 4, 16, 16, 0],
   reduction = [0, 0, 0, 0, 2],
   subgroup = [1, 4, 1, 4, 0],
   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
@@ -287,6 +290,7 @@
 //      CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //      CHECK-DAG:   %[[C720:.+]] = arith.constant 720 : index
 //      CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//         CHECK:   scf.forall ({{.*}}) in (2, 4, 1, 5) {
 //          CHECK:   %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x4x1x4x4x1xf32>)
 //          CHECK:     gpu.barrier
 //      CHECK-DAG:     %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
@@ -303,6 +307,7 @@
 //          CHECK:   %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 2, 4, 3, 5] : vector<1x4x1x4x4x1xf32> to vector<1x4x1x4x4x1xf32>
 //          CHECK:   %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<4x1x4x4x1xf32> from vector<1x4x1x4x4x1xf32>
 //          CHECK:   vector.transfer_write %[[EXTRACT]], %[[B2]]
+//         CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<z:1>, #iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -312,7 +317,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [64, 64, 0],
+  workgroup = [4, 4, 0],
   reduction = [0, 0, 2],
   subgroup = [2, 2],
   mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>,
@@ -353,21 +358,23 @@
 //   CHECK-DAG:   %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
 //   CHECK-DAG:   memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
-//       CHECK:   %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf32>)
-//       CHECK:     gpu.barrier
-//   CHECK-DAG:     vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
-//   CHECK-DAG:     vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
-//   CHECK-DAG:     vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
-//   CHECK-DAG:     vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
-//       CHECK:     gpu.barrier
-//   CHECK-DAG:     vector.transfer_read {{.*}} vector<2x1x2x16xf16>
-//   CHECK-DAG:     vector.transfer_read {{.*}} vector<2x1x2x16xf16>
-//   CHECK-DAG:     vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16>
-//   CHECK-DAG:     vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16>
-// CHECK-COUNT-8:   amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
-//       CHECK:     scf.yield
-//       CHECK:   %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 3, 1, 4] : vector<2x2x8x1x1xf32> to vector<2x8x1x2x1xf32>
-//       CHECK:   vector.transfer_write %[[LOOP_T]], %[[B2]]
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf32>)
+//       CHECK:       gpu.barrier
+//   CHECK-DAG:       vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+//   CHECK-DAG:       vector.transfer_read %[[B0]]{{.*}} vector<8xf16>
+//   CHECK-DAG:       vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+//   CHECK-DAG:       vector.transfer_read %[[B1]]{{.*}} vector<8xf16>
+//       CHECK:       gpu.barrier
+//   CHECK-DAG:       vector.transfer_read {{.*}} vector<2x1x2x16xf16>
+//   CHECK-DAG:       vector.transfer_read {{.*}} vector<2x1x2x16xf16>
+//   CHECK-DAG:       vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16>
+//   CHECK-DAG:       vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16>
+// CHECK-COUNT-8:     amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
+//       CHECK:       scf.yield
+//       CHECK:     %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 3, 1, 4] : vector<2x2x8x1x1xf32> to vector<2x8x1x2x1xf32>
+//       CHECK:     vector.transfer_write %[[LOOP_T]], %[[B2]]
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -377,7 +384,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [64, 64, 0],
+  workgroup = [4, 4, 0],
   reduction = [0, 0, 2],
   subgroup = [2, 2],
   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
@@ -419,9 +426,11 @@
 // CHECK-LABEL: func @matmul_transpose_b_mfma_16x16x4
 //   CHECK-DAG:   memref.alloc() : memref<64x10xf32, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x10xf32, #gpu.address_space<workgroup>>
-//       CHECK:   scf.for %{{.*}} = %c0 to %c320 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
-// CHECK-COUNT-8:   amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32
-//       CHECK:     scf.yield
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     scf.for %{{.*}} = %c0 to %c320 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
+// CHECK-COUNT-8:     amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32
+//       CHECK:       scf.yield
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -431,7 +440,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [64, 64, 0],
+  workgroup = [4, 4, 0],
   reduction = [0, 0, 2],
   subgroup = [2, 2],
   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>,
@@ -473,9 +482,11 @@
 // CHECK-LABEL: func @matmul_transpose_b_mfma_16x16x32_f8
 //   CHECK-DAG:   memref.alloc() : memref<64x72xf8E4M3FNUZ, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x72xf8E4M3FNUZ, #gpu.address_space<workgroup>>
-//       CHECK:   scf.for %{{.*}} = %c0 to %c40 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
-// CHECK-COUNT-8:   amdgpu.mfma {{.*}}blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32
-//       CHECK:     scf.yield
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     scf.for %{{.*}} = %c0 to %c40 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
+// CHECK-COUNT-8:     amdgpu.mfma {{.*}}blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32
+//       CHECK:       scf.yield
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -485,7 +496,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [64, 64, 0],
+  workgroup = [2, 2, 0],
   reduction = [0, 0, 2],
   subgroup = [1, 1],
   mma_kind = #iree_gpu.mma_layout<MFMA_I32_32x32x16_I8>,
@@ -527,9 +538,11 @@
 // CHECK-LABEL: func @matmul_transpose_b_mfma_32x32x16_i8
 //   CHECK-DAG:   memref.alloc() : memref<64x40xi8, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x40xi8, #gpu.address_space<workgroup>>
-//       CHECK:   scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>)
-// CHECK-COUNT-2:   amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32
-//       CHECK:     scf.yield
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>)
+// CHECK-COUNT-2:     amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32
+//       CHECK:       scf.yield
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -539,7 +552,7 @@
   #hal.pipeline.binding<storage_buffer>
 ]>
 #config = #iree_gpu.lowering_config<{
-  workgroup = [64, 64, 0],
+  workgroup = [4, 4, 0],
   reduction = [0, 0, 2],
   subgroup = [2, 2],
   mma_kind = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>,
@@ -581,9 +594,11 @@
 // CHECK-LABEL: func @matmul_transpose_b_wmma_f16_16x16x16_f16
 //   CHECK-DAG:   memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
 //   CHECK-DAG:   memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
-//       CHECK:   scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>)
-// CHECK-COUNT-8:   amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
-//       CHECK:     scf.yield
+//       CHECK:   scf.forall ({{.*}}) in (32, 160) {
+//       CHECK:     scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>)
+// CHECK-COUNT-8:     amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
+//       CHECK:       scf.yield
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -639,12 +654,14 @@
 // the producer's (convolution's) distributed scf.forall loop.
 // CHECK-LABEL: func @conv_nchw_fused
 //       CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<1x1x1x1xf32, #gpu.address_space<private>>
-//       CHECK:   scf.for %{{.*}} = %c0 to %c64 step %c1
-//       CHECK:     linalg.conv_2d_nchw_fchw
-//  CHECK-SAME:       outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space<private>>)
-//       CHECK:   arith.addf
-//       CHECK:   arith.cmpf
-//       CHECK:   arith.select
+//       CHECK:   scf.forall ({{.*}}) in (64, 14, 7) {
+//       CHECK:     scf.for %{{.*}} = %c0 to %c64 step %c1
+//       CHECK:       linalg.conv_2d_nchw_fchw
+//  CHECK-SAME:         outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space<private>>)
+//       CHECK:     arith.addf
+//       CHECK:     arith.cmpf
+//       CHECK:     arith.select
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -715,11 +732,13 @@
 //       CHECK:   %[[LINID0:.+]] = affine.apply #[[$MAP]]()[%[[IDX]], %[[IDY]], %[[IDZ]]]
 //       CHECK:   %[[IDS:.+]]:2 = affine.delinearize_index %[[LINID0:.+]] into (%c4, %c8) : index, index
 //       CHECK:   %[[LINID1:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#0, %[[IDS]]#1]
-//       CHECK:   scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>)
-//       CHECK:     scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32
-//       CHECK:       %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32>
-//       CHECK:       vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space<workgroup>>
-//       CHECK:     vector.contract
+//       CHECK:   scf.forall ({{.*}}) in (32, 98) {
+//       CHECK:     scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>)
+//       CHECK:       scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32
+//       CHECK:         %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32>
+//       CHECK:         vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space<workgroup>>
+//       CHECK:       vector.contract
+//       CHECK:   } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
 
 // -----
 
@@ -736,7 +755,7 @@
   mma_kind = #iree_gpu.mma_layout<WMMA_I32_16x16x16_I8>,
   reduction = [0, 0, 4],
   subgroup = [2, 4, 0],
-  workgroup = [64, 128, 0],
+  workgroup = [4, 8, 0],
   promote_operands = [0, 1]
 }>
 
@@ -1012,7 +1031,6 @@
 //   CHECK-DAG:   %[[RHS_ALLOC:.+]] = memref.alloc() : memref<4x130xf32, #gpu.address_space<workgroup>>
 //       CHECK:   %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1000 step %c4 {{.*}} -> (vector<1x4xf32>)
 //       CHECK:     gpu.barrier
-
 //       CHECK:     scf.for %{{.*}} = %{{.*}} to %c1 step %c32
 //       CHECK:       %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<4xf32>
 //  CHECK-NEXT:       vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC]]
@@ -1069,6 +1087,7 @@
 
 // Verify that the write does not get hoisted out of the single threaded
 // for loop.
-//       CHECK:     vector.transfer_write %{{.*}}, %[[B2]]{{.*}} memref<10x1xf32, #hal.descriptor_type<storage_buffer>>
-//  CHECK-NEXT:   }
+//       CHECK:       vector.transfer_write %{{.*}}, %[[B2]]{{.*}} memref<10x1xf32, #hal.descriptor_type<storage_buffer>>
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   } {mapping = [#iree_codegen.workgroup_mapping<x>]}
 //  CHECK-NEXT:   return
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index 5cc0b70..d57d163 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -248,11 +248,11 @@
 //       CHECK:         %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Z]], %[[IV_Y]], 0] [1, 16, 32]
 //       CHECK:         scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] {
 //       CHECK:           %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][%[[IV_Z]], 0, %[[IV_X]]] [1, 32, 16]
-//     CHECK-DAG:         %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]]
-//     CHECK-DAG:         %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]]
-//     CHECK-DAG:         %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]]
-//     CHECK-DAG:         %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]]
-//     CHECK-DAG:         %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
+//   CHECK-DAG:           %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]]
+//   CHECK-DAG:           %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]]
+//   CHECK-DAG:           %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]]
+//   CHECK-DAG:           %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]]
+//   CHECK-DAG:           %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]]
 //       CHECK:           %[[CT0:.+]] = vector.contract
 //  CHECK-SAME:             %[[READ0]], %[[READ2]], %[[READ4]] : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
 //       CHECK:           %[[CT1:.+]] = vector.contract