[Codegen][GPU] Use arithmetic intensity to guide gemm size categorization - step 2 (#21691)

This is a follow-up to https://github.com/iree-org/iree/pull/21638 and the
final step of the heuristic implementation described in #21506.

Since both peakMemoryBandwidth and peakPerf are available, the heuristic
now uses them to decide the compute/memory cutoff point, which in turn
determines the small- and large-gemm cutoffs.
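
For reference, the cutoff math reduces to a roofline-style comparison: a
gemm's arithmetic intensity, 2*M*N*K / (M*N + N*K + M*K), is compared
against 0.05x and 5x of peakPerf / peakMemoryBandwidth. Below is a minimal
standalone C++ sketch of that classification, mirroring the logic added in
ConfigUtils.cpp; the chip numbers (1300 TFLOPS, 5.3 TB/s) and helper names
are illustrative only, not read from any target attribute.

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the GemmSize enum added in GPUHeuristics.h.
enum class GemmSize { SmallGemm, MediumGemm, LargeGemm };

// Arithmetic intensity of C[MxN] = A[MxK] * B[KxN]:
// 2*M*N*K ops over M*N + N*K + M*K elements touched.
static float arithmeticIntensity(int64_t m, int64_t n, int64_t k) {
  return (2.0f * m * n * k) / static_cast<float>(m * n + n * k + m * k);
}

static GemmSize classify(float intensity, float peakTflops,
                         float memBandwidthTbps) {
  // Roofline-style balance point: below it the gemm is memory bound.
  float computeMemoryCutoff = peakTflops / memBandwidthTbps;
  // Empirical scaling factors from this PR: 0.05x and 5x of the cutoff.
  float smallCutoff = 0.05f * computeMemoryCutoff;
  float largeCutoff = 5.0f * computeMemoryCutoff;
  if (intensity <= smallCutoff)
    return GemmSize::SmallGemm;
  if (intensity >= largeCutoff)
    return GemmSize::LargeGemm;
  return GemmSize::MediumGemm;
}

int main() {
  // Hypothetical chip numbers for illustration: 1300 TFLOPS and 5.3 TB/s
  // give a compute/memory cutoff of ~245, so small < ~12 and large > ~1226.
  float peakTflops = 1300.0f, bandwidthTbps = 5.3f;
  float ai = arithmeticIntensity(1024, 1024, 1024); // ~682.7
  GemmSize size = classify(ai, peakTflops, bandwidthTbps);
  std::printf("AI = %.1f -> %s\n", ai,
              size == GemmSize::SmallGemm   ? "small"
              : size == GemmSize::LargeGemm ? "large"
                                            : "medium");
  return 0;
}
```

With these illustrative numbers a 1024x1024x1024 matmul lands in the medium
bucket, which selects the balanced seed branch added in ConfigUtils.cpp.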

Once merged, this PR improves heuristic performance by 8%, measured as the
performance geo-mean over 478 convolutions.

---------

Signed-off-by: jerryyin <zhuoryin@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index d9fe58c..8617327 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -492,20 +492,14 @@
     return bestMNTileCountPerSubgroup;
   }
 
-  int64_t mSize = ShapedType::getNumElements(problem.mSizes);
-  int64_t nSize = ShapedType::getNumElements(problem.nSizes);
-  int64_t kSize = ShapedType::getNumElements(problem.kSizes);
-  float arithmeticIntensity =
-      (2.0f * mSize * nSize * kSize) /
-      static_cast<float>(mSize * nSize + nSize * kSize + mSize * kSize);
-
-  // TODO(jerryyin): compute arithmetic intensity bound based on the information
-  // from the target chip.
-  if (arithmeticIntensity <= 10.0f) {
-    LDBG() << "Arithmetic intensity is too low, " << arithmeticIntensity
-           << ", skipping adjustment of seeds for workgroup count.";
+  if (problem.gemmSize == GemmSize::NotSet ||
+      problem.gemmSize == GemmSize::SmallGemm) {
+    LDBG() << "Arithmetic intensity is too low, "
+           << "skipping adjustment of seeds for workgroup count.";
     return bestMNTileCountPerSubgroup;
   }
+  int64_t mSize = ShapedType::getNumElements(problem.mSizes);
+  int64_t nSize = ShapedType::getNumElements(problem.nSizes);
   auto computeWorkgroupCount = [&] {
     // Compute the number of workgroups needed to cover the problem size.
     // This number tends to be lower than actual workgroup count, since:
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
index e45fca0..43d696c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -9,6 +9,8 @@
 
 namespace mlir::iree_compiler {
 
+enum class GemmSize { NotSet, SmallGemm, MediumGemm, LargeGemm };
+
 /// Struct containing information about a matmul's shape and type.
 struct GPUMatmulShapeType {
   SmallVector<int64_t, 2> mSizes;
@@ -18,6 +20,7 @@
   Type aType;
   Type bType;
   Type cType;
+  GemmSize gemmSize = GemmSize::NotSet;
 
   GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c)
       : mSizes({m}), nSizes({n}), kSizes({k}), batchSizes({}), aType(a),
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3ac1751..14d1615 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Support/InterleavedRange.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -122,6 +123,119 @@
       workgroupSize, targetSubgroupSize, pipelineConfig);
 }
 
+static std::optional<ComputeBitwidths> getComputeBitwidthForType(Type type) {
+  return llvm::TypeSwitch<Type, std::optional<ComputeBitwidths>>(type)
+      .Case<FloatType>(
+          [](FloatType floatType) -> std::optional<ComputeBitwidths> {
+            switch (floatType.getIntOrFloatBitWidth()) {
+            case 64:
+              return ComputeBitwidths::FP64;
+            case 32:
+              return ComputeBitwidths::FP32;
+            case 16:
+              return ComputeBitwidths::FP16;
+            case 8:
+              return ComputeBitwidths::FP8;
+            case 6:
+              return ComputeBitwidths::FP6;
+            case 4:
+              return ComputeBitwidths::FP4;
+            default:
+              return std::nullopt;
+            }
+          })
+      .Case<IntegerType>(
+          [](IntegerType intType) -> std::optional<ComputeBitwidths> {
+            switch (intType.getWidth()) {
+            case 64:
+              return ComputeBitwidths::Int64;
+            case 32:
+              return ComputeBitwidths::Int32;
+            case 16:
+              return ComputeBitwidths::Int16;
+            case 8:
+              return ComputeBitwidths::Int8;
+            default:
+              return std::nullopt;
+            }
+          })
+      .Default([](Type) { return std::nullopt; });
+}
+
+namespace {
+struct GemmCutoff {
+  float smallGemmCutoff;
+  float largeGemmCutoff;
+};
+} // namespace
+
+/// Function to compute small and large gemm cutoffs for arithmetic intensity
+/// based on the target's peak performance and memory bandwidth.
+static GemmCutoff computeGemmCutoffsForAI(IREE::GPU::TargetAttr target,
+                                          Type computeType) {
+  float smallGemmCutoff = 1.0f;
+  float largeGemmCutoff = 1000.0f;
+  if (!target.getChip()) {
+    LDBG() << "Target chip is not specified, using default gemm cutoffs: "
+           << smallGemmCutoff << ", " << largeGemmCutoff;
+    return {smallGemmCutoff, largeGemmCutoff};
+  }
+
+  TargetChipAttr chip = target.getChip();
+  DictionaryAttr peakPerfTlopsAttr = chip.getPerfTflops();
+  llvm::DenseMap<ComputeBitwidths, float> peakPerfTflops;
+  for (NamedAttribute namedAttr : peakPerfTlopsAttr) {
+    StringRef bitwidthStr = namedAttr.getName().strref();
+    FloatAttr floatAttr = dyn_cast<FloatAttr>(namedAttr.getValue());
+    if (!floatAttr) {
+      continue;
+    }
+
+    std::optional<ComputeBitwidths> bitwidth =
+        symbolizeComputeBitwidths(bitwidthStr);
+    if (!bitwidth) {
+      continue;
+    }
+
+    peakPerfTflops[*bitwidth] = floatAttr.getValue().convertToFloat();
+  }
+
+  bool peakPerfTflopsFound = false;
+  auto computeBitwidth = getComputeBitwidthForType(computeType);
+  if (computeBitwidth) {
+    peakPerfTflopsFound = peakPerfTflops.contains(computeBitwidth.value());
+  }
+  bool memoryBandwidthFound = chip.getMemoryBandwidthTbps() != nullptr;
+  if (!peakPerfTflopsFound || !memoryBandwidthFound) {
+    LDBG() << "Target chip does not have peak performance or memory bandwidth "
+              "information, using default gemm cutoffs: "
+           << smallGemmCutoff << ", " << largeGemmCutoff;
+    return {smallGemmCutoff, largeGemmCutoff};
+  }
+
+  // TODO: Attempt to use the number of elements loaded per second instead of
+  // Tbps, and adopt it if the perf uplift transfers better across different
+  // data types.
+  FloatAttr memoryBandwidthTbpsAttr = chip.getMemoryBandwidthTbps();
+  float memoryBandwidthTbps =
+      memoryBandwidthTbpsAttr.getValue().convertToFloat();
+
+  float perfTflops = peakPerfTflops[computeBitwidth.value()];
+  float computeMemoryCutoff = perfTflops / memoryBandwidthTbps;
+  LDBG() << "Target chip peak performance: " << perfTflops << " TFlops for "
+         << stringifyComputeBitwidths(computeBitwidth.value())
+         << " bitwidth, memory bandwidth: " << memoryBandwidthTbps
+         << " Tbps, compute-memory cutoff: " << computeMemoryCutoff;
+
+  // The constants below were determined empirically, following the approach in
+  // https://github.com/iree-org/iree/discussions/21506.
+  smallGemmCutoff = 0.05f * computeMemoryCutoff;
+  largeGemmCutoff = 5.0f * computeMemoryCutoff;
+  LDBG() << "Target chip small gemm cutoff: " << smallGemmCutoff
+         << ", large gemm cutoff: " << largeGemmCutoff;
+  return {smallGemmCutoff, largeGemmCutoff};
+}
+
 /// Given a target and a matmul problem, try to find an MMA schedule for the
 /// problem based on the available mma intrinsics.
 static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
@@ -169,25 +283,43 @@
          "expected the same aType and bType.");
   int64_t inBitWidth = problem.aType.getIntOrFloatBitWidth();
 
+  GemmCutoff gemmCutoffs = computeGemmCutoffsForAI(target, problem.aType);
+
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
   // See https://github.com/iree-org/iree/issues/16341 for details.
   int64_t mSize = ShapedType::getNumElements(problem.mSizes);
   int64_t nSize = ShapedType::getNumElements(problem.nSizes);
-  if (mSize * nSize <= 512 * 512) {
-    // For matmuls with small M*N size, we want to distribute M*N onto more
-    // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
-    // and a larger bestKTileCountPerSubgroup.
-    seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
-             /*bestMNTileCountPerSubgroup=*/4,
-             /*bestKTileCountPerSubgroup=*/8,
-             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
-  } else {
-    seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
-             /*bestMNTileCountPerSubgroup=*/16,
+  int64_t kSize = ShapedType::getNumElements(problem.kSizes);
+  int64_t computeIntensity = (2 * mSize * nSize * kSize) /
+                             (mSize * nSize + nSize * kSize + mSize * kSize);
+
+  if (computeIntensity <= gemmCutoffs.smallGemmCutoff) {
+    // For matmuls with small arithmetic intensity, use small
+    // bestMNTileCountPerSubgroup and large bestKTileCountPerSubgroup.
+    problem.gemmSize = GemmSize::SmallGemm;
+    seeds = {/*bestSubgroupCountPerWorkgroup=*/2,
+             /*bestMNTileCountPerSubgroup=*/2,
              /*bestKTileCountPerSubgroup=*/4,
+             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
+  } else if (computeIntensity >= gemmCutoffs.largeGemmCutoff) {
+    // For matmuls with large arithmetic intensity, use large
+    // bestMNTileCountPerSubgroup and small bestKTileCountPerSubgroup to
+    // amortize launch/memory costs and maximize throughput.
+    problem.gemmSize = GemmSize::LargeGemm;
+    seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
+             /*bestMNTileCountPerSubgroup=*/8,
+             /*bestKTileCountPerSubgroup=*/2,
              /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 /
                  inBitWidth};
+  } else {
+    // Choose balanced tile shapes. Empirically, medium-AI workloads can favor
+    // either small or large tiles depending on kernel details.
+    problem.gemmSize = GemmSize::MediumGemm;
+    seeds = {/*bestSubgroupCountPerWorkgroup=*/8,
+             /*bestMNTileCountPerSubgroup=*/4,
+             /*bestKTileCountPerSubgroup=*/4,
+             /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
   }
   int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 4012aa2..c33039b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1060,9 +1060,10 @@
   if (intrinsics.empty())
     return failure();
 
-  // Note that the following heuristic seeds are just placeholder values.
-  // We need to clean it up and make it adjusting to different targets.
-  // See https://github.com/iree-org/iree/issues/16341 for details.
+  // TODO: Replace the below with the algorithm described in
+  // https://github.com/iree-org/iree/discussions/21506.
+  // This is already implemented for the TileAndFuse pipeline and should be
+  // ported here once its perf results are verified.
   GPUMMAHeuristicSeeds seeds{/*bestSubgroupCountPerWorkgroup=*/4,
                              /*bestMNTileCountPerSubgroup=*/8,
                              /*bestKTileCountPerSubgroup=*/2};
@@ -1287,9 +1288,10 @@
 
   GPUMMAHeuristicSeeds seeds;
 
-  // Note that the following heuristic seeds are just placeholder values.
-  // We need to clean it up and make it adjusting to different targets.
-  // See https://github.com/iree-org/iree/issues/16341 for details.
+  // TODO: Replace the below with the algorithm described in
+  // https://github.com/iree-org/iree/discussions/21506.
+  // This is already implemented for the TileAndFuse pipeline and should be
+  // ported here once its perf results are verified.
   if (problem.mSizes[0] * problem.nSizes[0] <= clGPUMatmulCThreshold) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
index 35eb91b..4c77bad 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir
@@ -23,7 +23,7 @@
 }
 
 // CHECK-LABEL: func.func @nhwc_conv_mfma
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
@@ -32,12 +32,12 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 
 //  GFX942-SAME:    reduction = [0, 0, 0, 0, 8]
-//  GFX942-SAME:    subgroup = [1, 2, 2, 1, 0]
-//  GFX942-SAME:    workgroup = [1, 2, 32, 64, 0]
+//  GFX942-SAME:    subgroup = [1, 4, 1, 1, 0]
+//  GFX942-SAME:    workgroup = [1, 4, 32, 64, 0]
 
 //  MI300X-SAME:    reduction = [0, 0, 0, 0, 8]
 //  MI300X-SAME:    subgroup = [1, 1, 1, 1, 0]
-//  MI300X-SAME:    workgroup = [1, 1, 16, 64, 0]}>
+//  MI300X-SAME:    workgroup = [1, 1, 32, 64, 0]}>
 
 // -----
 
@@ -57,7 +57,7 @@
 }
 
 // CHECK-LABEL: func.func @nchw_conv_mfma
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
@@ -66,12 +66,12 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME:     subgroup = [1, 2, 2, 1, 0]
-// GFX942-SAME:     workgroup = [1, 64, 2, 32, 0]
+// GFX942-SAME:     subgroup = [1, 1, 4, 1, 0]
+// GFX942-SAME:     workgroup = [1, 64, 4, 32, 0]
 
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 8]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 32, 1, 32, 0]
+// MI300X-SAME:     workgroup = [1, 64, 1, 32, 0]
 
 // -----
 
@@ -91,7 +91,7 @@
 }
 
 // CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
@@ -101,14 +101,14 @@
 // GFX942-SAME:     padding = [2, 1, 32, 64, 32]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME:     subgroup = [2, 1, 2, 1, 0]
+// GFX942-SAME:     subgroup = [2, 1, 1, 1, 0]
 // GFX942-SAME:     workgroup = [2, 1, 32, 64, 0]
 
-// MI300X-SAME:     padding = [1, 1, 16, 64, 32]
+// MI300X-SAME:     padding = [1, 1, 32, 64, 32]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 8]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 1, 16, 64, 0]
+// MI300X-SAME:     workgroup = [1, 1, 32, 64, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [2, 1, 32, 64, 0, 0, 0]
 
@@ -130,26 +130,26 @@
 }
 
 // CHECK-LABEL: func.func @nchw_conv_unaligned_mfma
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
 //       CHECK:   linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
 //  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
 
-// GFX942-SAME:     padding = [1, 64, 2, 32, 32]
+// GFX942-SAME:     padding = [1, 64, 4, 32, 32]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 8]
-// GFX942-SAME:     subgroup = [1, 2, 2, 1, 0]
-// GFX942-SAME:     workgroup = [1, 64, 2, 32, 0]
+// GFX942-SAME:     subgroup = [1, 1, 4, 1, 0]
+// GFX942-SAME:     workgroup = [1, 64, 4, 32, 0]
 
-// MI300X-SAME:     padding = [1, 32, 1, 32, 32]
+// MI300X-SAME:     padding = [1, 64, 1, 32, 32]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 8]
 // MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
-// MI300X-SAME:     workgroup = [1, 32, 1, 32, 0]
+// MI300X-SAME:     workgroup = [1, 64, 1, 32, 0]
 
-// PAD-CONV-GFX942:     padding_conv = [1, 64, 2, 32, 0, 0, 0]
+// PAD-CONV-GFX942:     padding_conv = [1, 64, 4, 32, 0, 0, 0]
 
 // -----
 
@@ -169,7 +169,7 @@
 }
 
 // CHECK-LABEL: func.func @conv_nhwc_fhwc_unaligned_channel
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
@@ -179,13 +179,13 @@
 // GFX942-SAME:     padding = [1, 8, 32, 32, 32]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
 // GFX942-SAME:     reduction = [0, 0, 0, 0, 2]
-// GFX942-SAME:     subgroup = [1, 8, 1, 1, 0]
+// GFX942-SAME:     subgroup = [1, 4, 1, 1, 0]
 // GFX942-SAME:     workgroup = [1, 8, 32, 32, 0]
 
 // MI300X-SAME:     padding = [1, 4, 32, 32, 32]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
 // MI300X-SAME:     reduction = [0, 0, 0, 0, 2]
-// MI300X-SAME:     subgroup = [1, 4, 1, 1, 0]
+// MI300X-SAME:     subgroup = [1, 2, 1, 1, 0]
 // MI300X-SAME:     workgroup = [1, 4, 32, 32, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [1, 8, 32, 32, 0, 0, 32]
@@ -240,25 +240,25 @@
 }
 
 // CHECK-LABEL: func.func @group_conv_unaligned
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
 //       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // GFX942-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
-// GFX942-SAME:     padding = [1, 32, 1, 64, 32]
+// GFX942-SAME:     padding = [1, 32, 1, 64, 64]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
-// GFX942-SAME:     reduction = [0, 0, 0, 0, 2]
-// GFX942-SAME:     subgroup = [1, 2, 0, 1, 0]
+// GFX942-SAME:     reduction = [0, 0, 0, 0, 4]
+// GFX942-SAME:     subgroup = [1, 1, 0, 1, 0]
 // GFX942-SAME:     workgroup = [1, 32, 1, 64, 0]
 
-// MI300X-SAME:     padding = [1, 32, 1, 32, 32]
+// MI300X-SAME:     padding = [1, 32, 1, 64, 64]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
-// MI300X-SAME:     reduction = [0, 0, 0, 0, 2]
+// MI300X-SAME:     reduction = [0, 0, 0, 0, 4]
 // MI300X-SAME:     subgroup = [1, 1, 0, 1, 0]
-// MI300X-SAME:     workgroup = [1, 32, 1, 32, 0]
+// MI300X-SAME:     workgroup = [1, 32, 1, 64, 0]
 
-// PAD-CONV-GFX942:     padding_conv = [1, 32, 1, 64, 0, 0, 32]
+// PAD-CONV-GFX942:     padding_conv = [1, 32, 1, 64, 0, 0, 64]
 
 // -----
 
@@ -280,22 +280,22 @@
 }
 
 // CHECK-LABEL: func.func @conv_nhwc_filter_5x1_unaligned
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
 //  CHECK-SAME:   use_igemm_convolution = true
 
 //       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 // GFX942-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_BF16>
-// GFX942-SAME:     padding = [2, 2, 32, 64, 32]
+// GFX942-SAME:     padding = [2, 2, 32, 64, 64]
 // GFX942-SAME:     promote_operands = [0, 1, 2]
-// GFX942-SAME:     reduction = [0, 0, 0, 0, 2]
-// GFX942-SAME:     subgroup = [2, 2, 2, 1, 0]
+// GFX942-SAME:     reduction = [0, 0, 0, 0, 4]
+// GFX942-SAME:     subgroup = [2, 2, 1, 1, 0]
 // GFX942-SAME:     workgroup = [2, 2, 32, 64, 0]
 
-// MI300X-SAME:     padding = [1, 1, 32, 64, 32]
+// MI300X-SAME:     padding = [1, 1, 32, 64, 64]
 // MI300X-SAME:     promote_operands = [0, 1, 2]
-// MI300X-SAME:     reduction = [0, 0, 0, 0, 2]
-// MI300X-SAME:     subgroup = [1, 1, 2, 1, 0]
+// MI300X-SAME:     reduction = [0, 0, 0, 0, 4]
+// MI300X-SAME:     subgroup = [1, 1, 1, 1, 0]
 // MI300X-SAME:     workgroup = [1, 1, 32, 64, 0]
 
 // PAD-CONV-GFX942:     padding_conv = [2, 2, 32, 64, 0, 0]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 4a4a14f..5101daf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -37,7 +37,7 @@
 }
 
 // CHECK-LABEL: func.func @expanded_matmul_transpose_b
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>
 
 // Verify that the fill does not have the lowering config propagated to it.
@@ -47,8 +47,8 @@
 //  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 0, 0, 4]
-//  CHECK-SAME:     subgroup = [1, 1, 4, 1, 0]
-//  CHECK-SAME:     workgroup = [1, 1, 64, 64, 0]
+//  CHECK-SAME:     subgroup = [1, 2, 2, 1, 0]
+//  CHECK-SAME:     workgroup = [1, 2, 64, 64, 0]
 
 //        LATE:  LLVMGPUVectorDistribute
 
@@ -77,7 +77,7 @@
 }
 
 // CHECK-LABEL: func.func @multi_dim_mma_schedule
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>
 
 //       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
@@ -85,7 +85,7 @@
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 0, 0, 4, 1]
 //  CHECK-SAME:     subgroup = [2, 2, 1, 1, 0, 0]
-//  CHECK-SAME:     workgroup = [2, 2, 32, 32, 0, 0]
+//  CHECK-SAME:     workgroup = [2, 4, 32, 32, 0, 0]
 
 //        LATE:  LLVMGPUVectorDistribute
 
@@ -140,7 +140,7 @@
 }
 
 // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>
 
 // Verify that the fill does not have the lowering config propagated to it.
@@ -149,9 +149,9 @@
 //       CHECK:   linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 //  CHECK-SAME:     mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
 //  CHECK-SAME:     promote_operands = [0, 1]
-//  CHECK-SAME:     reduction = [0, 0, 2]
-//  CHECK-SAME:     subgroup = [4, 4, 0]
-//  CHECK-SAME:     workgroup = [128, 128, 0]
+//  CHECK-SAME:     reduction = [0, 0, 4]
+//  CHECK-SAME:     subgroup = [2, 2, 0]
+//  CHECK-SAME:     workgroup = [128, 64, 0]
 
 //        LATE:  LLVMGPUVectorDistribute
 
@@ -386,12 +386,12 @@
 }
 
 // CHECK-LABEL: func.func @aligned_dynamic_matmul_with_two_reduce_dim
-// CHECK-SAME:  {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64
+// CHECK-SAME:  {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
 // CHECK:       linalg.generic
 // CHECK-SAME:  {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
 // CHECK-SAME:  promote_operands = [0, 1]
 // CHECK-SAME:  reduction = [0, 1, 0, 4],
-// CHECK-SAME:  subgroup = [2, 0, 1, 0],
+// CHECK-SAME:  subgroup = [1, 0, 1, 0],
 // CHECK-SAME:  workgroup = [64, 0, 16, 0]}
 
 // -----
@@ -439,13 +439,13 @@
 // schedule with nTileSize of 16 while in reality it should be 8.
 
 // LATE-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check
-// LATE-SAME:    #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+// LATE-SAME:    #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 // LATE-SAME:    {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}
 //      LATE:    linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config
 //  LATE-SAME:     padding = [1, 16, 512, 4]
 //  LATE-SAME:     promote_operands = [0, 1, 2]
 //  LATE-SAME:     reduction = [0, 0, 0, 1]
-//  LATE-SAME:     subgroup = [0, 1, 8, 0]
+//  LATE-SAME:     subgroup = [0, 1, 4, 0]
 //  LATE-SAME:     workgroup = [1, 16, 512, 0]
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
index ae150b4..120661d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
@@ -25,14 +25,14 @@
 }
 
 // CHECK-LABEL: func.func @scaled_matmul
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>
 //       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 //  CHECK-SAME:     mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 8, 1]
-//  CHECK-SAME:     subgroup = [4, 4, 0, 0]
-//  CHECK-SAME:     workgroup = [128, 128, 0, 0]
+//  CHECK-SAME:     subgroup = [2, 2, 0, 0]
+//  CHECK-SAME:     workgroup = [128, 64, 0, 0]
 
 // -----
 
@@ -58,14 +58,14 @@
 }
 
 // CHECK-LABEL: func.func @scaled_matmul_with_batch
-//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
 //  CHECK-SAME:   #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>
 //       CHECK:   linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
 //  CHECK-SAME:     mma_kind = #iree_gpu.scaled_mma_layout<intrinsic = MFMA_SCALE_F32_16x16x128_B32, lhs_elem_type = f4E2M1FN, rhs_elem_type = f4E2M1FN, acc_elem_type = f32>
 //  CHECK-SAME:     promote_operands = [0, 1]
 //  CHECK-SAME:     reduction = [0, 0, 0, 8, 1]
-//  CHECK-SAME:     subgroup = [0, 4, 4, 0, 0]
-//  CHECK-SAME:     workgroup = [1, 128, 128, 0, 0]
+//  CHECK-SAME:     subgroup = [0, 2, 2, 0, 0]
+//  CHECK-SAME:     workgroup = [1, 128, 64, 0, 0]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir
index ca7a600..3c0e602 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matmul.mlir
@@ -69,11 +69,11 @@
 !TC = tensor<4096x4096xf32>
 !DTC = !iree_tensor_ext.dispatch.tensor<readwrite:tensor<4096x4096xf32>>
 //      CHECK:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse
-// CHECK-SAME:   workgroup_size = [256, 1, 1] subgroup_size = 64
+// CHECK-SAME:   workgroup_size = [128, 1, 1] subgroup_size = 64
 // CHECK-SAME:   {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>
 func.func @matmul_4096_1_4096(%arg0: !TA, %arg1: !TB, %arg2: !TC, %arg3: !DTC) {
   //      CHECK: #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>,
-  // CHECK-SAME: padding = [64, 128, 4], promote_operands = [0, 1, 2], reduction = [0, 0, 1], subgroup = [2, 4, 0], workgroup = [64, 128, 0]}
+  // CHECK-SAME: padding = [32, 32, 4], promote_operands = [0, 1, 2], reduction = [0, 0, 1], subgroup = [1, 2, 0], workgroup = [32, 32, 0]}
   %0 = linalg.matmul ins(%arg0, %arg1 : !TA, !TB) outs(%arg2 : !TC) -> !TC
   iree_tensor_ext.dispatch.tensor.store %0, %arg3, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : !TC -> !DTC
   return