[LLVMGPU] Switch LLVMGPUVectorDistribute to use iree_gpu.lowering_config (#18651)

Most of the changes in this patch are test updates. The actual code changes are:

- The LLVMGPUVectorDistribute pipeline now uses iree_gpu.lowering_config (see
  the example below)
- GPUApplyTilingLevel no longer fuses tensor.pad at the reduction and subgroup
  tiling levels
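
For reference, the shape of the lowering config change (tile values taken from the
updated pipeline_vector_distribute_gfx940.mlir test in this patch):

  // Old form: a single flat tile-size list on the op.
  #config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 128]]>
  // New form: named workgroup and reduction tile-size lists.
  #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
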
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
index 000fa81..4b029d0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp
@@ -135,6 +135,14 @@
             bool isDestinationOperand)
         -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
       Operation *owner = originalProducer.getOwner();
+      if (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
+          tilingLevel == IREE::GPU::TilingLevel::Subgroup) {
+        // Do not fuse pad in reduction and subgroup tiling.
+        if (isa<tensor::PadOp>(owner)) {
+          return std::nullopt;
+        }
+      }
+
       bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
       bool shouldFuse = false;
       if (auto tilingOwner = dyn_cast<TilingInterface>(owner)) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 224ef19..5599f47 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -115,6 +115,16 @@
   return target.getArch().starts_with("gfx");
 }
 
+static bool needsLoweringConfigPropagation(
+    IREE::Codegen::DispatchLoweringPassPipeline pipeline) {
+  using Pipeline = IREE::Codegen::DispatchLoweringPassPipeline;
+  // Pipelines that do not need propagation of lowering config.
+  Pipeline supportedPipelines[] = {Pipeline::LLVMGPUTileAndFuse,
+                                   Pipeline::LLVMGPUVectorDistribute,
+                                   Pipeline::LLVMGPUPadAndVectorDistribute};
+  return !llvm::is_contained(supportedPipelines, pipeline);
+}
+
 //====---------------------------------------------------------------------===//
 // Matmul Configuration Helpers
 //====---------------------------------------------------------------------===//
@@ -339,6 +349,7 @@
       schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
 
   SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
+  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
   for (int64_t batch : convolutionDims->batch) {
     workgroupTileSizes[batch] = 1;
@@ -351,7 +362,7 @@
     workgroupTileSizes[oc] = 1;
   }
   for (int64_t ic : llvm::drop_end(convolutionDims->inputChannel)) {
-    workgroupTileSizes[ic] = 1;
+    reductionTileSizes[ic] = 1;
   }
   // Compute the M/N dimension tile size by multiplying subgroup information.
   workgroupTileSizes[mDim] =
@@ -359,25 +370,32 @@
   workgroupTileSizes[nDim] =
       schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
 
-  // Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
-  workgroupTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
 
   // Tile all filter loop dimensions to 1.
   for (int64_t filterDim : convolutionDims->filterLoop) {
-    workgroupTileSizes[filterDim] = 1;
+    reductionTileSizes[filterDim] = 1;
   }
 
-  TileSizesListType tileSizes;
-  tileSizes.push_back(workgroupTileSizes);
+  MLIRContext *context = op.getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
 
   // Attach the MMA schedule as an attribute to the entry point export function
   // for later access in the pipeline.
-  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
       context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
       schedule->nWarpCount);
-  SmallVector<NamedAttribute, 1> attrs;
-  attrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr);
+  pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
+                             scheduleAttr);
 
   // Prefetch shared memory if requested.
   if (clLLVMGPUEnablePrefetch) {
@@ -385,17 +403,17 @@
         context, /*prefetchSharedMemory=*/true,
         /*no_reduce_shared_memory_bank_conflicts=*/false,
         /*reorder_workgroups_strategy=*/std::nullopt);
-    attrs.emplace_back(
+    pipelineAttrs.emplace_back(
         StringAttr::get(context,
                         IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()),
         pipelineOptions);
   }
 
-  auto configDict = DictionaryAttr::get(context, attrs);
+  auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
 
   return setOpConfigAndEntryPointFnTranslation(
-      entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUVectorDistribute,
-      workgroupSize, targetSubgroupSize, configDict);
+      entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
+      workgroupSize, targetSubgroupSize, pipelineConfig);
 }
 
 [[maybe_unused]] static void
@@ -573,6 +591,7 @@
       schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
 
   SmallVector<int64_t> workgroupTileSizes(op.getNumLoops(), 0);
+  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
   for (int64_t batch : contractionDims->batch) {
     workgroupTileSizes[batch] = 1;
@@ -587,7 +606,7 @@
     workgroupTileSizes[n] = 1;
   }
   for (int64_t k : llvm::drop_end(contractionDims->k)) {
-    workgroupTileSizes[k] = 1;
+    reductionTileSizes[k] = 1;
   }
 
   // Compute the M/N dimension tile size by multiplying subgroup information.
@@ -596,23 +615,32 @@
   workgroupTileSizes[nDim] =
       schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
 
-  // Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
-  workgroupTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize;
 
   LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(),
                                        *contractionDims, workgroupTileSizes));
+  LLVM_DEBUG(debugPrintContractionInfo("Reduction tile sizes", op.getNumLoops(),
+                                       *contractionDims, reductionTileSizes));
 
-  TileSizesListType tileSizes;
-  tileSizes.push_back(workgroupTileSizes);
+  MLIRContext *context = op.getContext();
+  Builder b(context);
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
 
   // Attach the MMA schedule as an attribute to the entry point export function
   // for later access in the pipeline.
-  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
       context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
       schedule->nWarpCount);
-  SmallVector<NamedAttribute, 1> attrs;
-  attrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr);
+  pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
+                             scheduleAttr);
 
   // Prefetch shared memory if requested.
   if (clLLVMGPUEnablePrefetch) {
@@ -620,17 +648,17 @@
         context, /*prefetchSharedMemory=*/true,
         /*no_reduce_shared_memory_bank_conflicts=*/false,
         /*reorder_workgroups_strategy=*/std::nullopt);
-    attrs.emplace_back(
+    pipelineAttrs.emplace_back(
         StringAttr::get(context,
                         IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()),
         pipelineOptions);
   }
 
-  auto configDict = DictionaryAttr::get(context, attrs);
+  auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
 
-  return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
-                                               pipeline, workgroupSize,
-                                               targetSubgroupSize, configDict);
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, op, loweringConfig, pipeline, workgroupSize,
+      targetSubgroupSize, pipelineConfig);
 }
 
 static LogicalResult
@@ -712,8 +740,6 @@
 
   LDBG("Attention Vector Distribution Config");
 
-  auto pipeline = CodeGenPipeline::LLVMGPUVectorDistribute;
-
   // Infer if Q, K and V are transposed to help generate better schedule.
   bool transposedQ =
       k1Dim != llvm::cast<AffineDimExpr>(op.getQueryMap().getResults().back())
@@ -765,6 +791,7 @@
       schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
 
   SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
+  SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
   for (int64_t batch : opInfo.getBatchDims()) {
     workgroupTileSizes[batch] = 1;
@@ -780,7 +807,7 @@
     workgroupTileSizes[n] = 1;
   }
   for (int64_t k2 : llvm::drop_end(opInfo.getK2Dims())) {
-    workgroupTileSizes[k2] = 1;
+    reductionTileSizes[k2] = 1;
   }
 
   // Compute the M/N dimension tile size by multiplying subgroup information.
@@ -789,29 +816,36 @@
   workgroupTileSizes[nDim] =
       schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
 
-  // Follow the LLVMGPU convention of keeping all of the tile sizes in one list.
-  workgroupTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize;
+  reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize;
 
-  TileSizesListType tileSizes;
-  tileSizes.push_back(workgroupTileSizes);
+  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 2> attrs;
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+  attrs.emplace_back(StringAttr::get(context, "reduction"),
+                     b.getI64ArrayAttr(reductionTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
 
   // Attach the MMA schedule as an attribute to the entry point export function
   // for later access in the pipeline.
-  MLIRContext *context = op.getContext();
+  SmallVector<NamedAttribute, 1> pipelineAttrs;
   auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
       context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount,
       schedule->nWarpCount);
-  SmallVector<NamedAttribute, 1> attrs;
-  attrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr);
-  auto configDict = DictionaryAttr::get(context, attrs);
+  pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"),
+                             scheduleAttr);
 
   // TODO: We do not turn prefetching on even when requested by the prefetching
 // flag because there is a shared memory allocation between the two matmuls, which
   // the prefetching pass cannot understand.
 
-  return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
-                                               pipeline, workgroupSize,
-                                               targetSubgroupSize, configDict);
+  auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
+
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
+      workgroupSize, targetSubgroupSize, pipelineConfig);
 }
 
 static LogicalResult
@@ -2108,10 +2142,9 @@
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
   if (IREE::Codegen::TranslationInfoAttr translationInfo =
           getTranslationInfo(funcOp)) {
-    // Currently ROCDL requires propagation of user lowering configs for
-    // all pipelines except TileAndFuse.
-    if (translationInfo.getDispatchLoweringPassPipeline() !=
-        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+    // Some ROCDL pipelines still require propagation of user lowering configs.
+    if (needsLoweringConfigPropagation(
+            translationInfo.getDispatchLoweringPassPipeline())) {
       for (auto op : computeOps) {
         if (getLoweringConfig(op)) {
           propagateLoweringConfig(op, computeOps);
@@ -2165,10 +2198,9 @@
 
   if (IREE::Codegen::TranslationInfoAttr translationInfo =
           getTranslationInfo(funcOp)) {
-    // Currently ROCDL requires propagation of user lowering configs for
-    // all pipelines except TileAndFuse.
-    if (translationInfo.getDispatchLoweringPassPipeline() ==
-        IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
+    // Some ROCDL pipelines still require propagation of user lowering configs.
+    if (!needsLoweringConfigPropagation(
+            translationInfo.getDispatchLoweringPassPipeline())) {
       return success();
     }
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index a2320c9..c676abd 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -816,12 +816,14 @@
     funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
   }
 
-  // Problem specific (reduction) tiling.
+  // Tile to reduction loops.
   {
-    GPUTensorTileToSerialLoopsPassOptions tensorTileToSerialLoopsPassOptions;
-    tensorTileToSerialLoopsPassOptions.coalesceLoops = true;
-    funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass(
-        tensorTileToSerialLoopsPassOptions));
+    GPUApplyTilingLevelPassOptions options;
+    options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
+    funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
+    funcPassManager.addPass(affine::createLoopCoalescingPass());
+    funcPassManager.addPass(createCanonicalizerPass());
+    funcPassManager.addPass(createCSEPass());
   }
 
   if (usePadToModelSharedMemcpy) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 5789cef..843befa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -20,7 +20,8 @@
         [
             "annotate_kernel_for_translation.mlir",
             "config_tile_and_fuse.mlir",
-            "config_vector_distribute.mlir",
+            "config_vector_distribute_gfx1100.mlir",
+            "config_vector_distribute_gfx940.mlir",
             "config_user_vector_distribute.mlir",
             "lowering_scalar_dispatch.mlir",
             "pipeline_tile_and_fuse.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index e843564..b9b45b0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -17,7 +17,8 @@
     "annotate_kernel_for_translation.mlir"
     "config_tile_and_fuse.mlir"
     "config_user_vector_distribute.mlir"
-    "config_vector_distribute.mlir"
+    "config_vector_distribute_gfx1100.mlir"
+    "config_vector_distribute_gfx940.mlir"
     "lowering_scalar_dispatch.mlir"
     "pipeline_tile_and_fuse.mlir"
     "pipeline_vector_distribute_gfx1100.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
index e5a7bf1..44eeb64 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir
@@ -16,6 +16,7 @@
 // OPT-IN:       #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
 // OPT-IN-SAME:    gpu_pipeline_options = #iree_gpu.pipeline_options<no_reduce_shared_memory_bank_conflicts = true>
 // OPT-IN-SAME:    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32]}>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
@@ -58,11 +59,12 @@
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>> -> tensor<2048x1280xf16>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
         %5 = tensor.empty() : tensor<2048x10240xf32>
-        %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[128, 128, 32]]>} ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
         %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                               affine_map<(d0, d1, d2) -> (d1, d2)>,
                                               affine_map<(d0, d1, d2) -> (d0, d1)>],
-                             iterator_types = ["parallel", "parallel", "reduction"]}
+                             iterator_types = ["parallel", "parallel", "reduction"],
+                             lowering_config = #config}
           ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
           outs(%6 : tensor<2048x10240xf32>) {
         ^bb0(%in: f16, %in_0: f16, %out: f32):
@@ -90,6 +92,7 @@
 // OPT-IN:       #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
 // OPT-IN-SAME:    gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = Transpose>
 // OPT-IN-SAME:    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32]}>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
@@ -133,11 +136,12 @@
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>> -> tensor<2048x1280xf16>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
         %5 = tensor.empty() : tensor<2048x10240xf32>
-        %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[128, 128, 32]]>} ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
         %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                               affine_map<(d0, d1, d2) -> (d1, d2)>,
                                               affine_map<(d0, d1, d2) -> (d0, d1)>],
-                             iterator_types = ["parallel", "parallel", "reduction"]}
+                             iterator_types = ["parallel", "parallel", "reduction"],
+                             lowering_config = #config}
           ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
           outs(%6 : tensor<2048x10240xf32>) {
         ^bb0(%in: f16, %in_0: f16, %out: f32):
@@ -160,6 +164,7 @@
 // OPT-OUT:       #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64
 // OPT-OUT-SAME:    gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = None>
 // OPT-OUT-SAME:    mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
+#config = #iree_gpu.lowering_config<{workgroup = [128, 128, 0], reduction = [0, 0, 32]}>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
@@ -192,11 +197,12 @@
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x1280xf16>> -> tensor<2048x1280xf16>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
         %5 = tensor.empty() : tensor<2048x10240xf32>
-        %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[128, 128, 32]]>} ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
         %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                               affine_map<(d0, d1, d2) -> (d1, d2)>,
                                               affine_map<(d0, d1, d2) -> (d0, d1)>],
-                             iterator_types = ["parallel", "parallel", "reduction"]}
+                             iterator_types = ["parallel", "parallel", "reduction"],
+                             lowering_config = #config}
           ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
           outs(%6 : tensor<2048x10240xf32>) {
         ^bb0(%in: f16, %in_0: f16, %out: f32):
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
new file mode 100644
index 0000000..9d45ea0
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \
+// RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA
+
+// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
+// to be migrated to the rocdl heuristics, but for now is just physically
+// located here.
+
+// WMMA:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
+// WMMA-SAME: mma_schedule = #iree_gpu.mma_schedule
+// WMMA-SAME:   intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>
+// WMMA-SAME:   subgroup_m_count = 2, subgroup_n_count = 2
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+func.func @wmma_matmul_1024x1024x1024() {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1024x1024xf32>>
+  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>> -> tensor<1024x1024xf16>
+  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>> -> tensor<1024x1024xf16>
+  %5 = tensor.empty() : tensor<1024x1024xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  %7 = linalg.matmul ins(%3, %4 : tensor<1024x1024xf16>, tensor<1024x1024xf16>) outs(%6 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : tensor<1024x1024xf32> -> !flow.dispatch.tensor<writeonly:tensor<1024x1024xf32>>
+  return
+}
+
+// WMMA-LABEL: func.func @wmma_matmul_1024x1024x1024()
+// WMMA: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
+// WMMA-SAME:                           reduction =  [0, 0, 64]
+// WMMA-SAME:                           workgroup =  [64, 128, 0]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
similarity index 87%
rename from compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
rename to compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
index 227ae1e..c7f5b44 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
@@ -1,13 +1,10 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution \
 // RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \
-// RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA
 
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
 // to be migrated to the rocdl heuristics, but for now is just physically
 // located here.
 
-// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[1, 1, 64, 64, 128]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
 // CHECK-SAME:   intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
@@ -42,11 +39,12 @@
 }
 
 // CHECK-LABEL: func.func @expanded_matmul_transpose_b()
-// CHECK: linalg.generic {{.*}}lowering_config = #[[$TILE_SIZES]]
+// CHECK: linalg.generic {{.*}}lowering_config =  #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 0, 0, 128]
+// CHECK-SAME:                           workgroup =  [1, 1, 64, 64, 0]
 
 // -----
 
-// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[1, 1, 64, 128, 1, 1, 32]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
 // CHECK-SAME:   intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
@@ -73,7 +71,9 @@
 }
 
 // CHECK-LABEL: func.func @conv_nhwc()
-// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}} lowering_config = #[[$TILE_SIZES]]
+// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}} lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 0, 0, 1, 1, 32]
+// CHECK-SAME:                           workgroup =  [1, 1, 64, 128, 0, 0, 0]
 
 // -----
 
@@ -111,7 +111,6 @@
 
 // -----
 
-// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[64, 128, 64]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
 // CHECK-SAME:   intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
@@ -138,11 +137,12 @@
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024()
-// CHECK: linalg.matmul {{.*}}lowering_config = #[[$TILE_SIZES]]
+// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 64]
+// CHECK-SAME:                           workgroup =  [64, 128, 0]
 
 // -----
 
-// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[1, 1, 1, 32, 32, 1, 1, 1, 32]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
 // CHECK-SAME:   intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
@@ -153,7 +153,8 @@
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 1, 32, 0, 1, 1, 1, 0]]>
+#config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 0, 1, 1, 1, 32],
+                                     workgroup = [1, 1, 1, 32, 32, 0, 0, 0, 0]}>
 #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d8)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d4, d8)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
@@ -187,42 +188,12 @@
 }
 
 // CHECK-LABEL: func.func @conv_nchwc()
-// CHECK: linalg.generic {{.*}}lowering_config = #[[$TILE_SIZES]]
+// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 0, 0, 0, 1, 1, 1, 32]
+// CHECK-SAME:                           workgroup =  [1, 1, 1, 32, 32, 0, 0, 0, 0]
 
 // -----
 
-// WMMA:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[64, 128, 64]{{\]}}
-// WMMA:      #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// WMMA-SAME: mma_schedule = #iree_gpu.mma_schedule
-// WMMA-SAME:   intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>
-// WMMA-SAME:   subgroup_m_count = 2, subgroup_n_count = 2
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-func.func @wmma_matmul_1024x1024x1024() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>>
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1024x1024xf32>>
-  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>> -> tensor<1024x1024xf16>
-  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1024x1024xf16>> -> tensor<1024x1024xf16>
-  %5 = tensor.empty() : tensor<1024x1024xf32>
-  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-  %7 = linalg.matmul ins(%3, %4 : tensor<1024x1024xf16>, tensor<1024x1024xf16>) outs(%6 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-  flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1024, 1024], strides = [1, 1] : tensor<1024x1024xf32> -> !flow.dispatch.tensor<writeonly:tensor<1024x1024xf32>>
-  return
-}
-
-// WMMA-LABEL: func.func @wmma_matmul_1024x1024x1024()
-// WMMA: linalg.matmul {{.*}}lowering_config = #[[$TILE_SIZES]]
-
-// -----
-
-// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 16, 16, 16]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
 // CHECK-SAME:   intrinsic =  #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
@@ -249,11 +220,12 @@
 }
 // CHECK-LABEL: func.func @unaligned_mk_batch_matmul()
 // CHECK:         linalg.batch_matmul
-// CHECK-SAME:      lowering_config = #[[$TILE_SIZES]]
+// CHECK-SAME:      lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 0, 16]
+// CHECK-SAME:                           workgroup =  [1, 16, 16, 0]
 
 // -----
 
-// CHECK:      #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 16, 128, 128]{{\]}}
 // CHECK:      #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
 // CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule
 // CHECK-SAME:   intrinsic =  #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
@@ -280,7 +252,9 @@
 }
 // CHECK-LABEL: func.func @unaligned_m_batch_matmul_64x72x1280x1280()
 // CHECK:         linalg.batch_matmul
-// CHECK-SAME:      lowering_config = #[[$TILE_SIZES]]
+// CHECK-SAME:      lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 0, 128]
+// CHECK-SAME:                           workgroup =  [1, 16, 128, 0]
 
 // -----
 
@@ -342,7 +316,6 @@
 
 // -----
 
-// CHECK:       #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[1, 64, 0, 64, 64]{{\]}}
 // CHECK:       #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME:  subgroup_m_count = 2, subgroup_n_count = 1
 // CHECK-NOT:   prefetch_shared_memory = true
@@ -376,9 +349,12 @@
   return
 }
 
+// CHECK:                #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 0, 64, 0]
+// CHECK-SAME:                           workgroup =  [1, 64, 0, 0, 64]
+
 // -----
 
-// CHECK:       #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes =  {{\[}}[32, 0, 16, 32]{{\]}}
 // CHECK:       #iree_codegen.translation_info<LLVMGPUVectorDistribute
 // CHECK-SAME:  subgroup_m_count = 2, subgroup_n_count = 1
 // CHECK-NOT:   prefetch_shared_memory = true
@@ -414,3 +390,7 @@
   flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : tensor<1024x512xf16> -> !flow.dispatch.tensor<writeonly:tensor<1024x512xf16>>
   return
 }
+
+// CHECK:                #iree_gpu.lowering_config
+// CHECK-SAME:                           reduction =  [0, 0, 16, 0]
+// CHECK-SAME:                           workgroup =  [32, 0, 0, 32]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
index 505c7b6..58bfb00 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
@@ -3,7 +3,7 @@
 // RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
 // RUN:   %s | FileCheck %s
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -28,7 +28,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %5 = tensor.empty() : tensor<256x256xf32>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
       %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
       return
@@ -50,7 +50,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<WMMA_F16_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -75,7 +75,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %5 = tensor.empty() : tensor<256x256xf16>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+      %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
       %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
       return
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 6ebf8ab..6ba9ba8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -8,7 +8,7 @@
 // RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
 // RUN:   %s | FileCheck %s --check-prefix=MEMORY
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -33,7 +33,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %5 = tensor.empty() : tensor<256x256xf32>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
       %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
       return
@@ -54,7 +54,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -79,7 +79,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
       %5 = tensor.empty() : tensor<256x256xf16>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+      %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
       %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
       return
@@ -98,7 +98,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 64, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 64, 0], reduction = [0, 0, 0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -129,7 +129,7 @@
           : !flow.dispatch.tensor<readonly:tensor<10x64x2048xf16>> -> tensor<10x64x2048xf16>
 
         %5 = tensor.empty() : tensor<2x10x64x64xf16>
-        %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16>
         %7 = linalg.generic {
           indexing_maps = [
             affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
@@ -156,7 +156,7 @@
 //          CHECK: func @expanded_matmul_transpose_b
 // This has more than 2 iterations, so prefetching is enabled for this case. Due to
 // prefetching, we have one iteration peeled off, so the upper bound is 2048 - 128 = 1920.
-//          CHECK:   scf.for {{.*}} = %c0 to %c15 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
+//          CHECK:   scf.for {{.*}} = %c0 to %c1920 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
 //          CHECK:     arith.extf %[[ARG]] {{.*}} : vector<4x1x1x1x4x1xf16> to vector<4x1x1x1x4x1xf32>
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 //          CHECK:     %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x4x1xf32> to vector<4x1x1x1x4x1xf16>
@@ -168,7 +168,7 @@
 
 // Basic f8, f8 -> f32 matmul.
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 256]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -193,7 +193,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf8E4M3FNUZ>> -> tensor<256x256xf8E4M3FNUZ>
       %5 = tensor.empty() : tensor<256x256xf32>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
       %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
       return
@@ -214,7 +214,7 @@
 
 // Basic i8, i8 -> i32 matmul.
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 256]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -239,7 +239,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
       %5 = tensor.empty() : tensor<256x256xi32>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
+      %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
       %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
       return
@@ -260,7 +260,7 @@
 
 // Basic i8, i8 -> i32 matmul_transpose_b.
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 256]]>
+#config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -285,7 +285,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
       %5 = tensor.empty() : tensor<256x256xi32>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
+      %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
       %7 = linalg.matmul_transpose_b {lowering_config = #config} ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xi32>>
       return
@@ -304,7 +304,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 128, 1, 1, 32]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 1, 1, 32]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -329,7 +329,7 @@
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 258, 514, 768], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x258x514x768xf16>> -> tensor<2x258x514x768xf16>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 768, 256], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x768x256xf16>> -> tensor<3x3x768x256xf16>
         %5 = tensor.empty() : tensor<2x256x512x256xf32>
-        %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32>
         %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>, lowering_config = #config} ins(%3, %4 : tensor<2x258x514x768xf16>, tensor<3x3x768x256xf16>) outs(%6 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32>
         flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 256, 512, 256], strides = [1, 1, 1, 1] : tensor<2x256x512x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x256x512x256xf32>>
         return
@@ -347,7 +347,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 1, 64, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 1, 64, 0], reduction = [0, 0, 0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>
 
 #pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [
@@ -380,7 +380,7 @@
         %7 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 1024, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1024x1280xf16>> -> tensor<2x1024x1280xf16>
         %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [20, 64, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x64x1280xf16>> -> tensor<20x64x1280xf16>
         %9 = tensor.empty() : tensor<2x1024x20x64xf16>
-        %10 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%9 : tensor<2x1024x20x64xf16>) -> tensor<2x1024x20x64xf16>
+        %10 = linalg.fill ins(%cst : f16) outs(%9 : tensor<2x1024x20x64xf16>) -> tensor<2x1024x20x64xf16>
         %11 = linalg.generic {
           indexing_maps = [#map, #map1, #map2],
           iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
@@ -403,7 +403,7 @@
 //    CHECK-LABEL: func.func @generic_2x1024x20x64x1280_f16
 // This has more than 2 iterations, so prefetching is enabled for this case. Due to
 // prefetching, we have one iteration peeled off, so the upper bound is 1280 - 128 = 1152.
-//          CHECK:   scf.for {{.*}} = %c0 to %c9 step %c1 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf16>)
+//          CHECK:   scf.for {{.*}} = %c0 to %c1152 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf16>)
 // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
 // along the K dimension. So in total 32 mfma ops.
 // CHECK-COUNT-32:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
@@ -413,7 +413,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 16, 16, 16]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16]}>
 #translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -438,7 +438,7 @@
       %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 968, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<64x968x1281xf16>
       %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [64, 1281, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<64x1281x1281xf16>
       %5 = tensor.empty() : tensor<64x968x1281xf16>
-      %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16>
+      %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16>
       %7 = linalg.batch_matmul {lowering_config = #config} ins(%3, %4 : tensor<64x968x1281xf16>, tensor<64x1281x1281xf16>) outs(%6 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16>
       flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [64, 968, 1281], strides = [1, 1, 1] : tensor<64x968x1281xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
       return
@@ -462,7 +462,7 @@
 // CHECK:         %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK:         vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
 // CHECK:         vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
-// CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c80 step %c1 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
+// CHECK:         %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
 // CHECK-DAG:       %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
 // CHECK-DAG:       %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK:           %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
@@ -492,7 +492,7 @@
 // NOTE: This test is not exhaustive of all possible ways the above condition is breaking,
 //       but rather is an example of a matmul shape from a model that broke our compilation heuristic.
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 16, 128, 128]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 128, 0], reduction = [0, 0, 0, 128]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false>, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>
 
 #pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
@@ -521,7 +521,7 @@
         %9 = flow.dispatch.tensor.load %6, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x160x1536xf16>> -> tensor<2x160x1536xf16>
         %10 = flow.dispatch.tensor.load %7, offsets = [0, 0, 0], sizes = [2, 1536, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1536x1536xf16>> -> tensor<2x1536x1536xf16>
         %11 = tensor.empty() : tensor<2x160x1536xf16>
-        %12 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%11 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
+        %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
         %13 = linalg.batch_matmul {lowering_config = #config} ins(%9, %10 : tensor<2x160x1536xf16>, tensor<2x1536x1536xf16>) outs(%12 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16>
         flow.dispatch.tensor.store %13, %8, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : tensor<2x160x1536xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x160x1536xf16>>
         return
@@ -536,14 +536,14 @@
 // CHECK-DAG:     %[[RHS_SHARED_SUB:.+]] =  memref.subview %[[RHS_SHARED]][0, 0] [128, 128] [1, 1]
 // CHECK-DAG:     %[[LHS_SHARED:.+]] = memref.alloc() : memref<16x132xf16, #gpu.address_space<workgroup>>
 // CHECK-DAG:     %[[LHS_SHARED_SUB:.+]] =  memref.subview %[[LHS_SHARED]][0, 0] [16, 128] [1, 1]
-// CHECK:   scf.for {{.*}} = %c0 to %c11 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x2x1x1x4x1xf16>)
+// CHECK:   scf.for {{.*}} = %c0 to %c1408 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<1x2x1x1x4x1xf16>)
 // CHECK-COUNT-16:     amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK:     scf.yield
 // CHECK-COUNT-16:   amdgpu.mfma
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 0, 64, 64]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -609,7 +609,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0, 64, 64]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 64], reduction = [0, 0, 0, 0, 64, 0]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -639,7 +639,7 @@
         %7 = tensor.empty() : tensor<64x4608x24x128xf16>
         %8 = tensor.empty() : tensor<24x64x4608x128xf16>
         %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
-        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
+        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
         ^bb0(%in: f16, %out: f16):
           linalg.yield %in : f16
         } -> tensor<64x4608x24x128xf16>
@@ -651,7 +651,7 @@
 }
 
 // CHECK-LABEL: func.func @attention_multiple_m_transpose()
-// CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
+// CHECK: scf.for %{{.*}} = %c0 to %c4608 step %c64
 // CHECK-SAME: -> (vector<2x1x1xf32>, vector<2x1x1xf32>, vector<2x8x1x1x1x4xf32>)
 // CHECK-COUNT-96:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
 // CHECK: scf.yield
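The attention tests get the same loop rewrite: the K2 loop now steps by the reduction tile over the full extent instead of counting tiles. A tiny sketch with this test's values (not part of the patch); no prefetch option is set in the translation_info here, so no iteration is peeled and the trip count matches the old counted form.

```python
K2, k2_tile = 4608, 64
assert K2 // k2_tile == 72                 # same trip count as the old "%c0 to %c72 step %c1"
assert len(range(0, K2, k2_tile)) == 72    # new form: "%c0 to %c4608 step %c64"
```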
@@ -664,7 +664,7 @@
 
 // -----
 
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128, 0, 32, 64]]>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 128, 0, 0, 64], reduction = [0, 0, 0, 0, 32, 0]}>
 #translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>
 
 #pipeline_layout = #hal.pipeline.layout<bindings = [
@@ -694,7 +694,7 @@
         %7 = tensor.empty() : tensor<64x4608x24x128xf16>
         %8 = tensor.empty() : tensor<24x64x4608x128xf16>
         %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
-        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
+        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
         ^bb0(%in: f16, %out: f16):
           linalg.yield %in : f16
         } -> tensor<64x4608x24x128xf16>
@@ -706,7 +706,7 @@
 }
 
 // CHECK-LABEL: func.func @attention_mfma_32x32x8()
-// CHECK: scf.for %{{.*}} = %c0 to %c144 step %c1
+// CHECK: scf.for %{{.*}} = %c0 to %c4608 step %c32
 // CHECK-SAME: -> (vector<1x1x1xf32>, vector<1x1x1xf32>, vector<1x4x1x4x1x4xf32>)
 // CHECK-COUNT-32:  amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<16xf32>
 // CHECK: scf.yield
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index c1e8907..c6a2994 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -106,20 +106,80 @@
 # Describes how to construct compilation info for the testcase.
 @dataclasses.dataclass
 class CompilationInfo:
-    # Lowering Config
-    tile_sizes: typing.List[typing.List[int]]
-    # Translation Info
-    dispatch_lowering_pass_pipeline: str
-    software_pipeline_depth: int
-    mma_schedule: typing.Optional[MMASchedule]
     # Compilation info
     workgroup_size: typing.List[int]
-    subgroup_size: Optional[int] = None
+    subgroup_size: Optional[int]
+    # Translation info
+    dispatch_lowering_pass_pipeline: str
 
     # Returns the workgroup size as a string
     def workgroup_size_str(self):
         return "workgroup_size = [" + ", ".join(map(str, self.workgroup_size)) + "]"
 
+    def get_compilation_info_attr(self) -> str:
+        ...
+
+
+@dataclasses.dataclass
+class IREEGPUCompilationInfo(CompilationInfo):
+    # Lowering Config
+    workgroup_tile: list[int]
+    reduction_tile: list[int]
+    # Translation Info
+    mma_schedule: Optional[MMASchedule]
+
+    def get_compilation_info_attr(self) -> str:
+        requested_pipeline = self.dispatch_lowering_pass_pipeline
+        compiler_pipeline = requested_pipeline
+
+        mma_schedule = ""
+        if self.mma_schedule is not None:
+            mma_schedule = "{}".format(self.mma_schedule)
+        subgroup_size_str = ""
+        if self.subgroup_size is not None:
+            subgroup_size_str = f"subgroup_size = {self.subgroup_size}"
+
+        return (
+            "#iree_codegen.compilation_info<\n"
+            f"  lowering_config = #iree_gpu.lowering_config<{{"
+            f"  workgroup = {self.workgroup_tile}, "
+            f"  reduction = {self.reduction_tile} }}>,\n"
+            f"  translation_info = <{compiler_pipeline} {self.workgroup_size_str()}\n"
+            f"  {subgroup_size_str},\n"
+            f"  {{ {mma_schedule} }}>>\n"
+        )
+
+
+@dataclasses.dataclass
+class LegacyCompilationInfo(CompilationInfo):
+    # Lowering Config
+    tile_sizes: typing.List[typing.List[int]]
+    # Translation Info
+    software_pipeline_depth: int
+
+    def get_compilation_info_attr(self) -> str:
+        requested_pipeline = self.dispatch_lowering_pass_pipeline
+        compiler_pipeline = requested_pipeline
+        if requested_pipeline == "SPIRVVectorizeMali":
+            compiler_pipeline = "SPIRVBaseVectorize"
+        elif requested_pipeline == "SPIRVCooperativeMatrixVectorize":
+            compiler_pipeline = "SPIRVCooperativeMatrixVectorize"
+        elif requested_pipeline == "SPIRVVectorizeNVIDIA":
+            # TODO: change to test SPIRVMatmulPromoteVectorize too
+            compiler_pipeline = "SPIRVBaseVectorize"
+
+        subgroup_size_str = ""
+        if self.subgroup_size is not None:
+            subgroup_size_str = f"subgroup_size = {self.subgroup_size}"
+
+        return (
+            "#iree_codegen.compilation_info<\n"
+            f"  lowering_config = #iree_codegen.lowering_config<tile_sizes = {self.tile_sizes}>,\n"
+            f"  translation_info = <{compiler_pipeline} {self.workgroup_size_str()}\n"
+            f"  {subgroup_size_str},\n"
+            f"  {{ pipeline_depth = {self.software_pipeline_depth}, store_stage = 1}}>>"
+        )
+
 
 # Returns the list of TestShape's to use for the collection of shapes
 # identified by shapes_id.
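For orientation, here is a rough example of how the new dataclass is meant to be used, assuming the classes above are in scope; the field values are illustrative, the MMA schedule is elided by passing None, and the rendered attribute is shown with whitespace condensed rather than as exact generator output.

```python
info = IREEGPUCompilationInfo(
    workgroup_size=[256, 1, 1],
    subgroup_size=64,
    dispatch_lowering_pass_pipeline="LLVMGPUVectorDistribute",
    workgroup_tile=[128, 128, 0],
    reduction_tile=[0, 0, 32],
    mma_schedule=None,
)
print(info.get_compilation_info_attr())
# Roughly:
#   #iree_codegen.compilation_info<
#     lowering_config = #iree_gpu.lowering_config<{ workgroup = [128, 128, 0],
#                                                   reduction = [0, 0, 32] }>,
#     translation_info = <LLVMGPUVectorDistribute workgroup_size = [256, 1, 1]
#     subgroup_size = 64, { }>>
```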
@@ -356,14 +416,15 @@
         else:
             raise NotImplementedError("unhandled intrinsic case")
 
-        workgroup_tile = [[wg_tile_m, wg_tile_n, wg_tile_k]]
+        workgroup_tile = [wg_tile_m, wg_tile_n, 0]
+        reduction_tile = [0, 0, wg_tile_k]
         workgroup_size = [schedule.n_count * subgroup_size, schedule.m_count, 1]
         infos.append(
-            CompilationInfo(
-                tile_sizes=workgroup_tile,
+            IREEGPUCompilationInfo(
+                workgroup_tile=workgroup_tile,
+                reduction_tile=reduction_tile,
                 dispatch_lowering_pass_pipeline="LLVMGPUVectorDistribute",
                 workgroup_size=workgroup_size,
-                software_pipeline_depth=0,
                 mma_schedule=schedule,
                 subgroup_size=subgroup_size,
             )
@@ -385,6 +446,7 @@
         return get_rocm_test_compilation_infos(compilation_info_id, lhs_rhs_type)
 
     software_pipeline_depth = 0
+    tile_workgroup_size_pairs = []
     if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt:
         tile_workgroup_size_pairs = [
             TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]),
@@ -438,12 +500,12 @@
     compilation_infos = []
     for tile_workgroup_size_pair in tile_workgroup_size_pairs:
         compilation_infos.append(
-            CompilationInfo(
+            LegacyCompilationInfo(
                 tile_sizes=tile_workgroup_size_pair.tile_size,
                 dispatch_lowering_pass_pipeline=compilation_info_id.value,
                 workgroup_size=tile_workgroup_size_pair.workgroup_size,
+                subgroup_size=None,
                 software_pipeline_depth=software_pipeline_depth,
-                mma_schedule=None,
             )
         )
     return compilation_infos
@@ -496,7 +558,7 @@
 def get_castback_from_arg_op(target_type: MatrixElemTypeId):
     if target_type == MatrixElemTypeId.F8E4M3FNUZ:
         return "arith.truncf"
-    return ValueError(f"Unhandled castback type of {t}")
+    return ValueError(f"Unhandled castback type of {target_type}")
 
 
 # Describes the fully resolved shape dimensions of all 3 input matrices,
@@ -559,13 +621,7 @@
 
     info = ""
     if compilation_info:
-        tile_sizes = list(itertools.chain(*compilation_info.tile_sizes))
-        tile_workgroup_key = (
-            "_".join([str(a) for a in tile_sizes])
-            + "_"
-            + "_".join([str(a) for a in compilation_info.workgroup_size])
-        )
-        info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}"
+        info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}"
 
     matmul_kind = "matmul_accumulate" if accumulate else "matmul"
     return (
@@ -592,7 +648,7 @@
     shape: TestShape,
     transpose_rhs: bool,
     dynamicity: Dynamicity,
-    compilation_info: typing.Optional[CompilationInfo] = None,
+    compilation_info: Optional[CompilationInfo] = None,
 ):
     shapes = generate_shapes(shape, transpose_rhs, dynamicity)
     func_name = generate_function_name(
@@ -619,32 +675,7 @@
     func_definition = ""
     compilation_info_attr = ""
     if compilation_info:
-        requested_pipeline = compilation_info.dispatch_lowering_pass_pipeline
-        compiler_pipeline = requested_pipeline
-        if requested_pipeline == "SPIRVVectorizeMali":
-            compiler_pipeline = "SPIRVBaseVectorize"
-        elif requested_pipeline == "SPIRVCooperativeMatrixVectorize":
-            compiler_pipeline = "SPIRVCooperativeMatrixVectorize"
-        elif requested_pipeline == "SPIRVVectorizeNVIDIA":
-            # TODO: change to test SPIRVMatmulPromoteVectorize too
-            compiler_pipeline = "SPIRVBaseVectorize"
-
-        mma_schedule = ""
-        if compilation_info.mma_schedule is not None:
-            mma_schedule = ", {}".format(compilation_info.mma_schedule)
-        subgroup_size_str = ""
-        if compilation_info.subgroup_size is not None:
-            subgroup_size_str = f"subgroup_size = {compilation_info.subgroup_size}"
-
-        compilation_info_string = (
-            f"#compilation{generate_function.compilation_index} = "
-            "#iree_codegen.compilation_info<\n"
-            f"  lowering_config = #iree_codegen.lowering_config<tile_sizes = {compilation_info.tile_sizes}>,\n"
-            f"  translation_info = <{compiler_pipeline} {compilation_info.workgroup_size_str()}\n"
-            f"  {subgroup_size_str},\n"
-            f"  {{ pipeline_depth = {compilation_info.software_pipeline_depth}, "
-            f"  store_stage = 1{mma_schedule} }}>>\n"
-        )
+        compilation_info_string = f"#compilation{generate_function.compilation_index} = {compilation_info.get_compilation_info_attr()}"
         compilation_info_attr = (
             f"{{compilation_info = #compilation{generate_function.compilation_index}}} "
         )