[vulkan] Update default RDNA GPU subgroup size to 32 (#18207)

getPreferredSubgroupSize() now always returns the smallest supported subgroup
size, so SPIR-V/Vulkan codegen for AMD RDNA GPUs defaults to wave32, matching
the HIP backend. The pickLargest plumbing is removed from
getPreferredSubgroupSize(), getGPUSubgroupSize(), and
createConvertVectorReductionToGPUPass(), and the affected SPIR-V lit tests are
updated for the new tile and workgroup sizes.

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
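For illustration only, a minimal standalone sketch of the old versus new selection
behavior. It uses a hypothetical subgroupSizeChoices vector standing in for the
target's wgp subgroup_size_choices attribute, not the actual TableGen-generated API:

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the target's supported subgroup sizes.
// RDNA targets advertise both wave32 and wave64.
static const std::vector<int> subgroupSizeChoices = {32, 64};

// Old behavior: callers passed pickLargest, and the SPIR-V/Vulkan path asked
// for the largest choice (64 on RDNA).
static int getPreferredSubgroupSizeOld(bool pickLargest) {
  if (pickLargest)
    return *std::max_element(subgroupSizeChoices.begin(), subgroupSizeChoices.end());
  return *std::min_element(subgroupSizeChoices.begin(), subgroupSizeChoices.end());
}

// New behavior: always prefer the smallest choice (32 on RDNA), matching HIP.
static int getPreferredSubgroupSizeNew() {
  return *std::min_element(subgroupSizeChoices.begin(), subgroupSizeChoices.end());
}

int main() {
  assert(getPreferredSubgroupSizeOld(/*pickLargest=*/true) == 64); // previous Vulkan default
  assert(getPreferredSubgroupSizeNew() == 32);                     // new default for all callers
  std::printf("old (Vulkan) = %d, new = %d\n",
              getPreferredSubgroupSizeOld(true), getPreferredSubgroupSizeNew());
  return 0;
}

On an RDNA target advertising {32, 64}, the Vulkan path previously picked 64;
every caller now gets 32.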
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
index aa5d22b..6030dd9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
@@ -98,8 +98,7 @@
// Distributes vector ops to all threads/warps in a GPU workgroup.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertVectorReductionToGPUPass(bool expandSubgroupReduction = true,
- bool pickLargestSubroupSize = false);
+createConvertVectorReductionToGPUPass(bool expandSubgroupReduction = true);
enum class ReorderWorkgroupsStrategy { None, Swizzle, Transpose };
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
index 20506a0..b1a6560 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
@@ -193,10 +193,8 @@
struct VectorReductionToGPUPass final
: impl::VectorReductionToGPUPassBase<VectorReductionToGPUPass> {
- VectorReductionToGPUPass(bool expandSubgroupReduction,
- bool pickLargestSubroupSize)
- : expandSubgroupReduction(expandSubgroupReduction),
- pickLargestSubroupSize(pickLargestSubroupSize) {}
+ VectorReductionToGPUPass(bool expandSubgroupReduction)
+ : expandSubgroupReduction(expandSubgroupReduction) {}
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
@@ -258,8 +256,7 @@
// 4. Distribute transfer write operations and propagate vector
// distribution.
{
- std::optional<int> subgroupSize =
- getGPUSubgroupSize(funcOp, pickLargestSubroupSize);
+ std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
if (!subgroupSize) {
funcOp->emitOpError("missing subgroup size");
return signalPassFailure();
@@ -316,16 +313,13 @@
private:
bool expandSubgroupReduction;
- bool pickLargestSubroupSize;
};
} // namespace
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertVectorReductionToGPUPass(bool expandSubgroupReduction,
- bool pickLargestSubroupSize) {
- return std::make_unique<VectorReductionToGPUPass>(expandSubgroupReduction,
- pickLargestSubroupSize);
+createConvertVectorReductionToGPUPass(bool expandSubgroupReduction) {
+ return std::make_unique<VectorReductionToGPUPass>(expandSubgroupReduction);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 4c453a4..c3514d6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -374,17 +374,14 @@
return *llvm::max_element(getWgp().getSubgroupSizeChoices().asArrayRef());
}
// Returns the preferred subgroup size. If the target supports multiple
- // subgroup sizes, pickLargest controls whether to return the largest one.
+ // subgroup sizes, picks the smallest one.
//
// AMD RDNA GPUs supports multiple subgroup sizes and the preferred one
// differ given the API--HIP prefers 32 while Vulkan prefers 64.
- // TODO: We should be able to force Vulkan side to use 32 consistently
- // too with subgroup size control; it might have perf implications though.
- int getPreferredSubgroupSize(bool pickLargest=false) const {
- if (pickLargest) {
- return getMaxSubgroupSize();
- }
- return getMinSubgroupSize();
+ // We force the Vulkan side to use 32 to be consistent with the HIP backend;
+ // this might have performance implications.
+ int getPreferredSubgroupSize() const {
+ return *llvm::min_element(getWgp().getSubgroupSizeChoices().asArrayRef());
}
// Hardware feature related APIs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 7fce1e4..f71f616 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -889,7 +889,7 @@
// vector -> simt gpu + vector
funcPassManager.addPass(createConvertVectorReductionToGPUPass(
- /*expandSubgroupReduction=*/true, /*pickLargestSubgroupSize=*/false));
+ /*expandSubgroupReduction=*/true));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
index ee74582..8ae0148 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -37,7 +37,7 @@
AMDCoopMatrixSoftwarePipelineStoreStage)))
return success();
- int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ int subgroupSize = target.getPreferredSubgroupSize();
const std::array<int64_t, 2> workgroupXY = {subgroupSize / 2, 8};
std::array<int64_t, 3> threadMNK;
auto inputType =
@@ -67,7 +67,7 @@
LogicalResult setAMDCodeGenConfig(IREE::GPU::TargetAttr target,
Operation *rootOp) {
- int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ int subgroupSize = target.getPreferredSubgroupSize();
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp)) {
if (isMatmulOrBatchMatmul(linalgOp))
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index a226f79..5683659 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -702,7 +702,7 @@
llvm::dbgs() << ")\n";
});
- int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ int subgroupSize = target.getPreferredSubgroupSize();
const int maxBytes = target.getWgp().getMaxWorkgroupMemoryBytes();
// We want a 2-stage pipeline without multi-buffering if the depth is 0 to
@@ -908,7 +908,7 @@
// AMD RDNA architectures supports both wave32 and wave64 modes. Prefer to use
// wave32 mode for better performance.
- int64_t subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/false);
+ int64_t subgroupSize = target.getPreferredSubgroupSize();
// Infer if lhs or rhs is transposed to help generate better schedule.
SmallVector<AffineMap> maps = op.getIndexingMapsArray();
@@ -999,7 +999,7 @@
static LogicalResult setFftOpConfig(IREE::GPU::TargetAttr target,
IREE::LinalgExt::FftOp op) {
LLVM_DEBUG(llvm::dbgs() << "trying to deduce config as fft...\n");
- int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ int subgroupSize = target.getPreferredSubgroupSize();
auto pipeline = CodeGenPipeline::SPIRVBaseDistribute;
std::array<int64_t, 3> workgroupSize = {subgroupSize, 1, 1};
@@ -1121,7 +1121,7 @@
if (!foundSingleReductionOutput)
return failure();
- int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ int subgroupSize = target.getPreferredSubgroupSize();
// Tile all the parallel dimension to 1.
SmallVector<unsigned> partitionedLoops =
@@ -1281,7 +1281,7 @@
funcOp, op, TileSizesListType{}, pipeline, workgroupSize);
}
- int subgroupSize = target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ int subgroupSize = target.getPreferredSubgroupSize();
const unsigned loopDepth = partitionedLoops.back() + 1;
// Configurations we need to decide.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 23e4fe5..4b577c0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -594,7 +594,7 @@
// Handle vector reduction operations specifically.
funcPassManager.addPass(createConvertVectorReductionToGPUPass(
- /*expandSubgroupReduction=*/false, /*pickLargestSubgroupSize=*/true));
+ /*expandSubgroupReduction=*/false));
// Perform normal vector unrolling and lowering transformations. This breaks
// vectors down to native machine size.
addSPIRVVectorLoweringPasses(funcPassManager);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
index ee78950..8a6007d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp
@@ -199,10 +199,7 @@
spirv::ScopeAttr::get(context, spirv::Scope::Subgroup)));
}
- // This is mostly to match RDNA behavior on Vulkan--RDNA supports either 32 or
- // 64 as subgroup sizes; the default subgroup size is 64.
- const int preferredSubgroupSize =
- target.getPreferredSubgroupSize(/*pickLargest=*/true);
+ const int preferredSubgroupSize = target.getPreferredSubgroupSize();
return spirv::ResourceLimitsAttr::get(
context, wgp.getMaxWorkgroupMemoryBytes(),
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
index 296bf8a..48943df 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -185,8 +185,7 @@
SmallVector<int64_t> &workgroupSize = maybeWorkgroupSize.value();
int64_t totalThreads = workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
- std::optional<int> subgroupSize =
- getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+ std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
if (!subgroupSize) {
funcOp.emitError("failed to query subgroup size");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 0161a4a..23dd233 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -366,8 +366,7 @@
// Then tile and distribute to subgroups.
{
- std::optional<int> subgroupSize =
- getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+ std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
if (!subgroupSize) {
funcOp.emitError("failed to query subgroup size");
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
index 12ccd8f..eb92f44 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
@@ -56,8 +56,7 @@
LLVM_DEBUG(llvm::dbgs() << "target: " << target << "\n");
auto funcOp = op->getParentOfType<mlir::FunctionOpInterface>();
- std::optional<int> subgroupSize =
- getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+ std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
if (!subgroupSize)
return funcOp->emitError("failed to query subgroup size");
const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
@@ -169,8 +168,7 @@
LLVM_DEBUG(llvm::dbgs() << "target: " << target << "\n");
auto funcOp = op->getParentOfType<mlir::FunctionOpInterface>();
- std::optional<int> subgroupSize =
- getGPUSubgroupSize(funcOp, /*pickLargest=*/true);
+ std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
if (!subgroupSize)
return funcOp->emitError("failed to query subgroup size");
const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir
index 6f408be..2fa2c4d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_conv.mlir
@@ -31,8 +31,8 @@
return
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 64, 64], [1, 1, 8, 8], [0, 0, 0, 0, 1, 1, 8], [0, 1, 0, 0]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [8, 8, 1]>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 4, 4, 64], [1, 2, 2, 8], [0, 0, 0, 0, 1, 1, 8], [0, 1, 0, 0]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize workgroup_size = [8, 2, 2]>
// CHECK: func.func @nhwc_conv_pointwise_2x64x64x320()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.conv_2d_nhwc_hwcf
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir
index c0d891b..b986068 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul.mlir
@@ -20,8 +20,8 @@
return
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 512, 8, 16]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [2, 64, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256, 8, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [2, 32, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
// CHECK: func.func @batch_matmul_f32_16x4096x40x4096()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.batch_matmul
@@ -53,7 +53,7 @@
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128, 32]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 16, 1], {pipeline_depth = 2 : i64, store_stage = 0 : i64}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 2 : i64, store_stage = 0 : i64}>
// CHECK: func.func @matmul_f16_64x640x320()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.matmul
@@ -82,8 +82,8 @@
return
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256, 16, 32]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [4, 32, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 128, 16, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [4, 16, 1], {pipeline_depth = 2 : i64, store_stage = 0 : i64}>
// CHECK: func.func @batch_matmul_f32_16x4096x40x4096()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.batch_matmul
@@ -120,8 +120,8 @@
return
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 256, 32]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 128, 32]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
// CHECK: func.func @batch_matmul_f16_1x4096x4096x512()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.batch_matmul
@@ -184,8 +184,8 @@
return
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128, 1, 16]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 1, 16]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
// CHECK: func.func @matmul_multi_reduce_i4xf32xf32()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
index c0ac97c..c07fe95 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
@@ -190,6 +190,6 @@
return
}
-// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [32, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize workgroup_size = [16, 8, 1], {pipeline_depth = 1 : i64, store_stage = 1 : i64}>
// CHECK-LABEL: func.func @matmul_256x1024x8
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
index eeaa8fe..eca6f4a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
@@ -31,7 +31,7 @@
// CHECK-SAME: AMD,
// CHECK-SAME: #spirv.resource_limits<max_compute_shared_memory_size = 65536,
// CHECK-SAME: max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32],
-// CHECK-SAME: subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64,
+// CHECK-SAME: min_subgroup_size = 32, max_subgroup_size = 64,
// CHECK-SAME: cooperative_matrix_properties_khr = [
// CHECK-SAME: #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>,
// CHECK-SAME: #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index dd3ca2a..61ea846 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -998,15 +998,14 @@
return getCLGPUTarget(op->getContext());
}
-std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func,
- bool pickLargest) {
+std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func) {
// First try to see if there is a subgroup size chosen in the CodeGen pipeline
// configuration.
if (std::optional<int64_t> subgroupSize = getSubgroupSize(func))
return subgroupSize.value();
// Then try to find the subgroup size from the target description.
if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func))
- return target.getPreferredSubgroupSize(pickLargest);
+ return target.getPreferredSubgroupSize();
return std::nullopt;
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index 943d6a3..b34209a 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -158,8 +158,7 @@
/// Returns the GPU subgroup size chosen for the current CodeGen pipeline if
/// exists; otherwise returns the subgroup size from the GPU target description.
/// Returns std::nullopt if none found.
-std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func,
- bool pickLargest);
+std::optional<int> getGPUSubgroupSize(mlir::FunctionOpInterface func);
} // namespace mlir::iree_compiler
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index b4d371d..6cb2702 100644
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -388,7 +388,7 @@
elif compilation_info_id == CompilationInfoId.SPIRVCooperativeMatrixVectorize:
tile_workgroup_size_pairs = [
TileWorkgroupSizePair(
- [[64, 64], [16, 64], [0, 0, 16], [16, 16, 16]], [64, 4, 1]
+ [[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]], [64, 2, 1]
)
]
elif compilation_info_id == CompilationInfoId.SPIRVVectorizeNVIDIA: