Plumb through support for controlling subgroup size in CodeGen (#11388)
This commit plumbs through support for controlling subgroup size (e.g., via
VK_EXT_subgroup_size_control for Vulkan) in CodeGen:
- Added a new subgroup size attribute to HAL ExecutableExportOp, mirroring
the existing workgroup size attribute there.
- Extended various helper functions to optionally set subgroup size.
- Defined VK_EXT_subgroup_size_control and updated various target
triples with min/max subgroup sizes.
This commit only threads support through the CodeGen flow; it is not yet
wired up for the runtime to use. That will happen later.
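
For illustration, a minimal sketch of how a backend would record and later query the subgroup size on the entry point's hal.executable.export op, using the helpers introduced in this diff (the surrounding pass setup, includes, and namespace qualifiers are assumed):

```cpp
LogicalResult configure(func::FuncOp entryPointFn) {
  // Record a 128x2x1 workgroup and a 64-wide subgroup on the export op
  // backing this entry point; passing llvm::None instead would leave the
  // subgroup_size attribute unset.
  SmallVector<int64_t> workgroupSize = {128, 2, 1};
  if (failed(setDispatchConfig(entryPointFn, workgroupSize,
                               /*subgroupSize=*/64)))
    return failure();

  // Later consumers (e.g., SPIR-V conversion) read it back via the getter.
  FailureOr<IREE::HAL::ExecutableExportOp> exportOp =
      getEntryPoint(entryPointFn);
  if (failed(exportOp)) return failure();
  llvm::Optional<int64_t> subgroupSize = getSubgroupSize(*exportOp);
  (void)subgroupSize;
  return success();
}
```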
diff --git a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp
index 7406981..16cfcaf 100644
--- a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp
@@ -23,7 +23,10 @@
if (failed(setTranslationInfo(entryPointFn, info))) return failure();
SmallVector<int64_t> workgroupSize = compilationInfo.getWorkgroupSizeVals();
- if (failed(setWorkgroupSize(entryPointFn, workgroupSize))) return failure();
+ llvm::Optional<int64_t> subgroupSize = compilationInfo.getSubgroupSize();
+ if (failed(setDispatchConfig(entryPointFn, workgroupSize, subgroupSize))) {
+ return failure();
+ }
setLoweringConfig(computeOp, compilationInfo.getLoweringConfig());
eraseCompilationInfo(computeOp);
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
index 30f3d16..ff8d6bd 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.cpp
@@ -190,15 +190,17 @@
CompilationInfoAttr CompilationInfoAttr::get(
MLIRContext *context, LoweringConfigAttr configAttr,
- TranslationInfoAttr translationInfo, ArrayRef<int64_t> workgroupSize) {
+ TranslationInfoAttr translationInfo, ArrayRef<int64_t> workgroupSize,
+ llvm::Optional<int64_t> subgroupSize) {
ArrayAttr workgroupSizeAttr = getI64IntegerArrayAttr(context, workgroupSize);
- return get(context, configAttr, translationInfo, workgroupSizeAttr);
+ return get(context, configAttr, translationInfo, workgroupSizeAttr,
+ subgroupSize);
}
LogicalResult CompilationInfoAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
LoweringConfigAttr loweringConfig, TranslationInfoAttr translationInfo,
- ArrayAttr workgroupSize) {
+ ArrayAttr workgroupSize, llvm::Optional<int64_t> subgroupSize) {
if (!loweringConfig) {
return emitError() << "missing lowering config";
}
@@ -263,13 +265,27 @@
return {};
}
-LogicalResult setWorkgroupSize(func::FuncOp entryPoint,
- ArrayRef<int64_t> workgroupSize) {
+llvm::Optional<int64_t> getSubgroupSize(
+ IREE::HAL::ExecutableExportOp exportOp) {
+ if (IntegerAttr attr = exportOp.getSubgroupSizeAttr()) {
+ return attr.getValue().getSExtValue();
+ }
+ return {};
+}
+
+LogicalResult setDispatchConfig(func::FuncOp entryPoint,
+ ArrayRef<int64_t> workgroupSize,
+ llvm::Optional<int64_t> subgroupSize) {
FailureOr<IREE::HAL::ExecutableExportOp> exportOp = getEntryPoint(entryPoint);
if (failed(exportOp)) return failure();
- if (workgroupSize.empty()) return success();
- auto attr = getIndexIntegerArrayAttr(exportOp->getContext(), workgroupSize);
- exportOp->setWorkgroupSizeAttr(attr);
+ MLIRContext *context = exportOp->getContext();
+ if (!workgroupSize.empty()) {
+ auto attr = getIndexIntegerArrayAttr(context, workgroupSize);
+ exportOp->setWorkgroupSizeAttr(attr);
+ }
+ if (subgroupSize) {
+ exportOp->setSubgroupSizeAttr(Builder(context).getIndexAttr(*subgroupSize));
+ }
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
index 9bf110c..c08efbd 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.h
@@ -64,11 +64,15 @@
/// Returns the workgroup size specified on the `exportOp`.
SmallVector<int64_t> getWorkgroupSize(IREE::HAL::ExecutableExportOp exportOp);
-/// Sets and overwrites the dispatch workgroup size for the given entry point
-/// function. Returns failure if the given entry point is not exported via
+/// Returns the subgroup size specified on the `exportOp`.
+llvm::Optional<int64_t> getSubgroupSize(IREE::HAL::ExecutableExportOp exportOp);
+
+/// Sets and overwrites the dispatch workgroup/subgroup size for the given entry
+/// point function. Returns failure if the given entry point is not exported via
/// hal.executable.export.
-LogicalResult setWorkgroupSize(func::FuncOp entryPoint,
- ArrayRef<int64_t> workgroupSize);
+LogicalResult setDispatchConfig(func::FuncOp entryPoint,
+ ArrayRef<int64_t> workgroupSize,
+ llvm::Optional<int64_t> subgroupSize);
/// Sets and overwrites the translation info for the given entry point.
/// Returns failure if the given entry point is not exported via
@@ -133,12 +137,15 @@
inline LogicalResult setOpConfigAndEntryPointFnTranslation(
func::FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
- ArrayRef<int64_t> workgroupSize = {}, unsigned softwarePipelineDepth = 0,
+ ArrayRef<int64_t> workgroupSize = {},
+ llvm::Optional<int64_t> subgroupSize = {},
+ unsigned softwarePipelineDepth = 0,
unsigned softwarePipelineStoreStage = 1) {
MLIRContext *context = entryPointFn.getContext();
auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes);
setLoweringConfig(op, config);
- if (failed(setWorkgroupSize(entryPointFn, workgroupSize))) return failure();
+ if (failed(setDispatchConfig(entryPointFn, workgroupSize, subgroupSize)))
+ return failure();
return setTranslationInfo(entryPointFn, passPipeline, softwarePipelineDepth,
softwarePipelineStoreStage);
}
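
One caller-facing consequence, visible in the LLVMGPU and SPIR-V updates below: `subgroupSize` now sits between `workgroupSize` and `softwarePipelineDepth`, so call sites that passed the pipeline depth positionally must thread an explicit `llvm::None` through. A sketch with assumed local variables:

```cpp
return setOpConfigAndEntryPointFnTranslation(
    entryPointFn, op, tileSizes, pipeline, workgroupSize,
    /*subgroupSize=*/llvm::None,
    /*softwarePipelineDepth=*/pipelineDepth,
    /*softwarePipelineStoreStage=*/storeStage);
```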
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
index 51b291b..03d7fed 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
@@ -218,19 +218,24 @@
AttrParameter<"LoweringConfigAttr", "">:$loweringConfig,
AttrParameter<"TranslationInfoAttr", "">:$translationInfo,
DefaultValuedParameter<"ArrayAttr", "ArrayAttr::get($_ctxt, {})",
- "The workgroup size to use during translation.">:$workgroupSize
+ "The workgroup size to use during translation.">:$workgroupSize,
+ OptionalParameter<"llvm::Optional<int64_t>",
+ "The subgroup size to use during translation.">:$subgroupSize
);
let assemblyFormat = [{
`<` `lowering_config` `=` $loweringConfig `,` `translation_info` `=` $translationInfo
- (`,` `workgroup_size` `=` $workgroupSize^)? `>`
+ (`,` `workgroup_size` `=` $workgroupSize^)?
+ (`,` `subgroup_size` `=` $subgroupSize^)? `>`
}];
// The builder is exposed externally for the auto-tuner to generate the attributes.
let builders = [
AttrBuilder<(ins "LoweringConfigAttr":$configAttr,
"TranslationInfoAttr":$translationInfo,
- "ArrayRef<int64_t>":$workgroupSize)>,
+ "ArrayRef<int64_t>":$workgroupSize,
+ "llvm::Optional<int64_t>":$subgroupSize
+ )>,
];
let extraClassDeclaration = [{
SmallVector<int64_t> getWorkgroupSizeVals();
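
The extended builder can then be used, e.g., by the auto-tuner, to attach a subgroup size to a compilation_info attribute. A hypothetical sketch assuming `configAttr` and `translationInfo` are already built; it round-trips through the assembly format shown in the lowering_config_attr.mlir test below:

```cpp
auto compilationInfo = IREE::Codegen::CompilationInfoAttr::get(
    context, configAttr, translationInfo,
    /*workgroupSize=*/{16, 4, 1},
    /*subgroupSize=*/32);
```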
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir b/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir
index 8d9b696..ff9a552 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/test/lowering_config_attr.mlir
@@ -40,3 +40,17 @@
}
}
// CHECK: #compilation = #iree_codegen.compilation_info<lowering_config = <tile_sizes = []>, translation_info = <CPUDefault>>
+
+// -----
+
+module {
+ func.func @test() attributes {
+ compilation_info = #iree_codegen.compilation_info<
+ lowering_config = <tile_sizes = []>,
+ translation_info = <CPUDefault>,
+ workgroup_size = [16, 4, 1],
+ subgroup_size = 32>} {
+ return
+ }
+}
+// CHECK: #compilation = #iree_codegen.compilation_info<lowering_config = <tile_sizes = []>, translation_info = <CPUDefault>, workgroup_size = [16, 4, 1], subgroup_size = 32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 6ab3c9c..6fc9a33 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -171,9 +171,9 @@
tileSizes.emplace_back(
std::move(workgroupTileSizes)); // Workgroup level.
- return setOpConfigAndEntryPointFnTranslation(entryPoint, op, tileSizes,
- pipeline, workgroupSize,
- softwarePipelineDepth);
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, op, tileSizes, pipeline, workgroupSize,
+ /*subgroupSize=*/llvm::None, softwarePipelineDepth);
};
// Infer the MxN size of the matmul based on operands and indexing maps.
auto lhsShape =
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index f852273..547fc66 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -335,11 +335,14 @@
"expected workgroup_size attribute to be set for SPIR-V lowering");
return signalPassFailure();
}
+ Optional<int64_t> subgroupSize = getSubgroupSize(exportOp);
auto workgroupSize32 = llvm::to_vector<4>(llvm::map_range(
workgroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
+ Optional<int> subgroupSize32;
+ if (subgroupSize) subgroupSize32 = *subgroupSize;
funcOp->setAttr(
spirv::getEntryPointABIAttrName(),
- spirv::getEntryPointABIAttr(context, workgroupSize32, llvm::None));
+ spirv::getEntryPointABIAttr(context, workgroupSize32, subgroupSize32));
}
spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(moduleOp);
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 1661520..bae4e49 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -666,7 +666,7 @@
return setOpConfigAndEntryPointFnTranslation(
op->getParentOfType<func::FuncOp>(), op, tileSizes,
CodeGenPipeline::SPIRVMatmulPromoteVectorize, workgroupSize,
- pipelineDepth, storeStage);
+ /*subgroupSize=*/llvm::None, pipelineDepth, storeStage);
}
return setOpConfigAndEntryPointFnTranslation(
@@ -834,9 +834,9 @@
return v.getType().cast<ShapedType>().getElementType();
};
- spirv::ResourceLimitsAttr resourceLimits = targetEnv.getResourceLimits();
+ spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
Optional<CooperativeMatrixSize> coopMatSize = getCooperativeMatrixSize(
- resourceLimits, numSubgroupsPerWorkgroup, numMNTilesPerSubgroup,
+ limits, numSubgroupsPerWorkgroup, numMNTilesPerSubgroup,
getElementType(lhs), getElementType(rhs), getElementType(init), dimM,
dimN, dimK);
if (!coopMatSize) return success();
@@ -845,7 +845,7 @@
SPIRVCooperativeMatrixVectorize;
std::array<int64_t, 3> workgroupSize{
- coopMatSize->nWarpCount * resourceLimits.getSubgroupSize(),
+ coopMatSize->nWarpCount * limits.getSubgroupSize(),
coopMatSize->mWarpCount, 1};
SmallVector<int64_t> vectorSizes(kIndex + 1, 0);
@@ -880,9 +880,11 @@
tileSizes.push_back(reductionTileSizes);
tileSizes.push_back(vectorSizes);
+ Optional<int64_t> subgroupSize = limits.getSubgroupSize();
+
return setOpConfigAndEntryPointFnTranslation(
op->getParentOfType<func::FuncOp>(), op, tileSizes, pipeline,
- workgroupSize);
+ workgroupSize, subgroupSize);
}
} // namespace detail
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
index 04da9fd..49ba3f7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_amd_matmul_cooperative_ops.mlir
@@ -71,6 +71,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 128], [32, 64], [0, 0, 32], [16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @matmul_256x1024x128_div_add
+// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]
// CHECK: func.func @matmul_256x1024x128_div_add()
@@ -139,6 +140,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 128], [1, 32, 64], [0, 0, 0, 32], [1, 16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @batch_matmul_16x128x256x512_div
+// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]
// CHECK: func.func @batch_matmul_16x128x256x512_div()
@@ -205,6 +207,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 128], [1, 32, 64], [0, 0, 0, 32], [1, 16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @generic_batch_matmul_32x8x512x64
+// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]
// CHECK: func.func @generic_batch_matmul_32x8x512x64()
@@ -265,6 +268,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 128], [1, 32, 64], [0, 0, 0, 16], [1, 16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @batch_matmul_16x1024x1024x80
+// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]
// CHECK: func.func @batch_matmul_16x1024x1024x80()
@@ -325,4 +329,6 @@
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize pipeline_depth = 1>
// CHECK-LABEL: hal.executable.export public @matmul_256x1024x8
+// CHECK-NOT: subgroup_size =
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-NOT: subgroup_size =
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
index 97441b8..16a0962 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
@@ -77,6 +77,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64], [32, 32], [0, 0, 32], [16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @matmul_256x1024x128_div_add
+// CHECK-SAME: subgroup_size = 32 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 2 : index, 1 : index]
// CHECK: func.func @matmul_256x1024x128_div_add()
@@ -151,6 +152,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 64], [1, 32, 32], [0, 0, 0, 32], [1, 16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @batch_matmul_16x128x256x512_div
+// CHECK-SAME: subgroup_size = 32 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 2 : index, 1 : index]
// CHECK: func.func @batch_matmul_16x128x256x512_div()
@@ -223,6 +225,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 64], [1, 32, 32], [0, 0, 0, 32], [1, 16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @generic_batch_matmul_32x8x512x64
+// CHECK-SAME: subgroup_size = 32 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 2 : index, 1 : index]
// CHECK: func.func @generic_batch_matmul_32x8x512x64()
@@ -289,6 +292,7 @@
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 64], [1, 32, 32], [0, 0, 0, 16], [1, 16, 16, 16]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVCooperativeMatrixVectorize>
//CHECK-LABEL: hal.executable.export public @batch_matmul_16x1024x1024x80
+// CHECK-SAME: subgroup_size = 32 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 2 : index, 1 : index]
// CHECK: func.func @batch_matmul_16x1024x1024x80()
@@ -355,4 +359,6 @@
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVMatmulPromoteVectorize pipeline_depth = 1>
// CHECK-LABEL: hal.executable.export public @matmul_256x1024x8
+// CHECK-NOT: subgroup_size =
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-NOT: subgroup_size =
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
index acafa68..c73503d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_user.mlir
@@ -3,7 +3,7 @@
#compilation = #iree_codegen.compilation_info<
lowering_config = <tile_sizes = [[128, 256], [16, 16]]>,
translation_info = <SPIRVBaseVectorize>,
- workgroup_size = [16, 8, 1]>
+ workgroup_size = [16, 8, 1], subgroup_size = 64>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
@@ -47,6 +47,7 @@
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[128, 256], [16, 16]{{\]}}>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize>
// CHECK: hal.executable.export public @matmul_128x1024x256
+// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK-SAME: workgroup_size = [16 : index, 8 : index, 1 : index]
// CHECK: linalg.matmul
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
index 3ebe73e..31f885d 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
@@ -235,7 +235,9 @@
linkedTargetBuilder.create<IREE::HAL::ExecutableExportOp>(
exportOp.getLoc(), exportOp.getSymNameAttr(),
builder.getIndexAttr(nextEntryPointOrdinal++),
- exportOp.getLayout(), ArrayAttr{}, IntegerAttr{});
+ exportOp.getLayout(), /*workgroup_size=*/ArrayAttr{},
+ /*subgroup_size=*/IntegerAttr{},
+ /*workgroup_local_memory=*/IntegerAttr{});
newExportOp->setDialectAttrs(exportOp->getDialectAttrs());
// Add to replacement table for fixing up dispatch calls referencing
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 55eee98..c9b65f4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -412,6 +412,8 @@
let constBuilderCall = "$_builder.getIndexArrayAttr($0)";
}
+def HAL_SubgroupSizeAttr : Util_IndexAttrBase<"size_t">;
+
// A bitmask defining which queues an operation is allowed to execute on.
// The selection is wrapped to the total number of available queues, so 0b0101
// would enable queues 0 and 2 if there were four queues or queue 0 if there
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 953b3e8..3a9012b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1589,6 +1589,7 @@
OptionalAttr<HAL_OrdinalAttr>:$ordinal,
HAL_PipelineLayoutAttr:$layout,
OptionalAttr<HAL_WorkgroupSizeAttr>:$workgroup_size,
+ OptionalAttr<HAL_SubgroupSizeAttr>:$subgroup_size,
OptionalAttr<IndexAttr>:$workgroup_local_memory
);
@@ -1600,10 +1601,11 @@
"::mlir::IntegerAttr":$ordinal,
"IREE::HAL::PipelineLayoutAttr":$layout,
"::mlir::ArrayAttr":$workgroup_size,
+ "::mlir::IntegerAttr":$subgroup_size,
"::mlir::IntegerAttr":$workgroup_local_memory
), [{
build($_builder, $_state, nullptr, sym_name, ordinal, layout,
- workgroup_size, workgroup_local_memory);
+ workgroup_size, subgroup_size, workgroup_local_memory);
}]>,
];
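
Callers constructing export ops directly now supply the extra attribute; a sketch (builder, loc, and layoutAttr assumed) matching the LinkingUtils call site above and the MaterializeInterfaces one below, where a null IntegerAttr leaves the subgroup size unspecified:

```cpp
auto exportOp = targetBuilder.create<IREE::HAL::ExecutableExportOp>(
    loc, targetBuilder.getStringAttr("entry0"),
    /*ordinal=*/targetBuilder.getIndexAttr(0), layoutAttr,
    /*workgroup_size=*/ArrayAttr{},
    /*subgroup_size=*/IntegerAttr{},
    /*workgroup_local_memory=*/IntegerAttr{});
```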
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
index f834c88..31da799 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
@@ -34,6 +34,7 @@
// CHECK: hal.executable.variant public @backend, target = #executable_target_format
hal.executable.variant @backend, target = #executable_target_format {
// CHECK-DAG: hal.executable.export public @entry0 ordinal(0) layout(#pipeline_layout) attributes {
+ // CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index]
hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
@@ -41,6 +42,7 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>) attributes {
+ subgroup_size = 64 : index,
workgroup_size = [4 : index, 1 : index, 1 : index]
} {
^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index):
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index e52067d..3c9e9d9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -259,7 +259,8 @@
exportOp.getLoc(),
targetBuilder.getStringAttr(exportOp.getFunctionRef()),
targetBuilder.getIndexAttr(ordinal), layoutAttr, ArrayAttr{},
- IntegerAttr{});
+ /*subgroup_size=*/IntegerAttr{},
+ /*workgroup_local_memory=*/IntegerAttr{});
// Clone the workgroup count calculation function.
if (!exportOp.getWorkgroupCount().empty()) {
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
index c6360cb..ab6f357 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
@@ -73,6 +73,13 @@
"::mlir::iree_compiler::IREE::Vulkan::SubgroupFeatureAttr":$subgroupFeatures,
"int":$subgroupSize,
+  // VK_EXT_subgroup_size_control properties.
+ //
+ // This corresponds to the `VkPhysicalDeviceSubgroupSizeControlProperties` structure:
+ // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPhysicalDeviceSubgroupSizeControlPropertiesEXT.html
+ OptionalParameter<"::llvm::Optional<int>">:$minSubgroupSize,
+ OptionalParameter<"::llvm::Optional<int>">:$maxSubgroupSize,
+
// VK_KHR_16bit_storage features.
//
// This corresponds to the `VkPhysicalDevice16BitStorageFeatures` structure:
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
index 844a258..d29b66c 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
@@ -90,13 +90,15 @@
def VK_KHR_spirv_1_4 : I32EnumAttrCase<"VK_KHR_spirv_1_4", 3>;
def VK_KHR_storage_buffer_storage_class : I32EnumAttrCase<"VK_KHR_storage_buffer_storage_class", 4>;
def VK_KHR_variable_pointers: I32EnumAttrCase<"VK_KHR_variable_pointers", 5>;
-def VK_NV_cooperative_matrix : I32EnumAttrCase<"VK_NV_cooperative_matrix", 6>;
+def VK_EXT_subgroup_size_control : I32EnumAttrCase<"VK_EXT_subgroup_size_control", 6>;
+def VK_NV_cooperative_matrix : I32EnumAttrCase<"VK_NV_cooperative_matrix", 7>;
def VK_ExtensionAttr :
VK_EnumAttr<"Extension", "supported Vulkan extension", [
VK_KHR_16bit_storage, VK_KHR_8bit_storage, VK_KHR_shader_float16_int8,
VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class,
- VK_KHR_variable_pointers, VK_NV_cooperative_matrix
+ VK_KHR_variable_pointers, VK_EXT_subgroup_size_control,
+ VK_NV_cooperative_matrix
]>;
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
index c65078f..4335026 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
@@ -69,6 +69,9 @@
case Extension::VK_KHR_variable_pointers:
extensions.push_back(spirv::Extension::SPV_KHR_variable_pointers);
break;
+ case Extension::VK_EXT_subgroup_size_control:
+      // Reflected via min/max subgroup sizes in the resource limits; there
+      // is no corresponding SPIR-V extension to push here.
+ break;
case Extension::VK_NV_cooperative_matrix:
extensions.push_back(spirv::Extension::SPV_NV_cooperative_matrix);
break;
@@ -171,7 +174,7 @@
context, vkCapabilities.getMaxComputeSharedMemorySize(),
vkCapabilities.getMaxComputeWorkGroupInvocations(),
builder.getI64ArrayAttr(sizes), vkCapabilities.getSubgroupSize(),
- /*min_subgroup_size=*/llvm::None, /*max_subgroup_size=*/llvm::None,
+ vkCapabilities.getMinSubgroupSize(), vkCapabilities.getMaxSubgroupSize(),
ArrayAttr::get(context, spvAttrs));
}
} // anonymous namespace
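
With the capabilities plumbed through, downstream code can query the supported range off the converted #spirv.resource_limits. A sketch assuming the usual ODS-generated accessors for the min_subgroup_size/max_subgroup_size parameters:

```cpp
spirv::ResourceLimitsAttr limits = targetEnv.getResourceLimits();
llvm::Optional<int> minSize = limits.getMinSubgroupSize();
llvm::Optional<int> maxSize = limits.getMaxSubgroupSize();
if (minSize && maxSize && *minSize != *maxSize) {
  // The device supports a range of widths; CodeGen may pick one and
  // record it on the export op via setDispatchConfig.
}
```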
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
index 3b129e6..4a7ca9b 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
@@ -154,10 +154,6 @@
}
return;
}
- case TargetTripleArch::AMD_RDNAv3: {
- extensions.push_back(Extension::VK_NV_cooperative_matrix);
- break;
- }
default:
break;
}
@@ -181,16 +177,18 @@
}
// Desktop GPUs typically support all extensions we care.
- const std::array<Extension, 6> desktop = {
+ const std::array<Extension, 7> desktop = {
Extension::VK_KHR_16bit_storage,
Extension::VK_KHR_8bit_storage,
Extension::VK_KHR_shader_float16_int8,
Extension::VK_KHR_spirv_1_4,
Extension::VK_KHR_storage_buffer_storage_class,
- Extension::VK_KHR_variable_pointers};
+ Extension::VK_KHR_variable_pointers,
+ Extension::VK_EXT_subgroup_size_control};
extensions.append(desktop.begin(), desktop.end());
- if (getVendor(triple) == spirv::Vendor::NVIDIA) {
+ if (getVendor(triple) == spirv::Vendor::NVIDIA ||
+ triple.getArch() == TargetTripleArch::AMD_RDNAv3) {
extensions.push_back(Extension::VK_NV_cooperative_matrix);
}
}
@@ -212,6 +210,7 @@
int subgroupSize = 32;
SubgroupFeature subgroupFeatures = SubgroupFeature::Basic;
+ Optional<int> minSubgroupSize, maxSubgroupSize;
bool shaderFloat16 = false, shaderFloat64 = false;
bool shaderInt8 = false, shaderInt16 = false, shaderInt64 = false;
@@ -228,6 +227,15 @@
Builder builder(context);
switch (triple.getArch()) {
+ case TargetTripleArch::AMD_RDNAv3: {
+ auto f16t = builder.getF16Type();
+ auto scope = ScopeNVAttr::get(context, ScopeNV::Subgroup);
+ coopmatCases.push_back(CooperativeMatrixPropertiesNVAttr::get(
+ context,
+ /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/f16t,
+ /*bType=*/f16t, /*cType=*/f16t, /*resultType=*/f16t, scope));
+ }
+ LLVM_FALLTHROUGH;
case TargetTripleArch::AMD_RDNAv1:
case TargetTripleArch::AMD_RDNAv2:
// Example: https://vulkan.gpuinfo.org/displayreport.php?id=10906
@@ -235,7 +243,7 @@
maxComputeWorkGroupInvocations = 1024;
maxComputeWorkGroupSize = {1024, 1024, 1024};
- subgroupSize = 64;
+ subgroupSize = 64, minSubgroupSize = 32, maxSubgroupSize = 64;
subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
SubgroupFeature::Shuffle |
@@ -252,34 +260,6 @@
variablePointers = variablePointersStorageBuffer = true;
break;
- case TargetTripleArch::AMD_RDNAv3: {
- maxComputeSharedMemorySize = 65536;
- maxComputeWorkGroupInvocations = 1024;
- maxComputeWorkGroupSize = {1024, 1024, 1024};
-
- subgroupSize = 64;
- subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
- SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
- SubgroupFeature::Shuffle |
- SubgroupFeature::ShuffleRelative |
- SubgroupFeature::Clustered | SubgroupFeature::Quad;
-
- shaderFloat16 = shaderFloat64 = true;
- shaderInt8 = shaderInt16 = shaderInt64 = true;
-
- storageBuffer16BitAccess = storagePushConstant16 = true;
- uniformAndStorageBuffer16BitAccess = true;
- storageBuffer8BitAccess = true, storagePushConstant8 = true;
- uniformAndStorageBuffer8BitAccess = true;
-
- variablePointers = variablePointersStorageBuffer = true;
- auto f16t = builder.getF16Type();
- auto scope = ScopeNVAttr::get(context, ScopeNV::Subgroup);
- coopmatCases.push_back(CooperativeMatrixPropertiesNVAttr::get(
- context,
- /*mSize=*/16, /*nSize=*/16, /*kSize=*/16, /*aType=*/f16t,
- /*bType=*/f16t, /*cType=*/f16t, /*resultType=*/f16t, scope));
- } break;
case TargetTripleArch::Apple_M1:
// Example: https://vulkan.gpuinfo.org/displayreport.php?id=14673
maxComputeSharedMemorySize = 32768;
@@ -349,7 +329,7 @@
maxComputeWorkGroupInvocations = 1024;
maxComputeWorkGroupSize = {1024, 1024, 64};
- subgroupSize = 32;
+ subgroupSize = 32, minSubgroupSize = 32, maxSubgroupSize = 32;
subgroupFeatures = SubgroupFeature::Basic | SubgroupFeature::Vote |
SubgroupFeature::Arithmetic | SubgroupFeature::Ballot |
SubgroupFeature::Shuffle |
@@ -428,7 +408,8 @@
getBoolAttr(shaderFloat64), getBoolAttr(shaderInt16),
getBoolAttr(shaderInt64),
SubgroupFeatureAttr::get(context, subgroupFeatures), subgroupSize,
- getBoolAttr(storageBuffer16BitAccess), getBoolAttr(storagePushConstant16),
+ minSubgroupSize, maxSubgroupSize, getBoolAttr(storageBuffer16BitAccess),
+ getBoolAttr(storagePushConstant16),
getBoolAttr(uniformAndStorageBuffer16BitAccess),
getBoolAttr(storageBuffer8BitAccess), getBoolAttr(storagePushConstant8),
getBoolAttr(uniformAndStorageBuffer8BitAccess),
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
index 472fa8c..523d548 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
@@ -1,22 +1,48 @@
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s --check-prefix=DEFAULT
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=adreno-a650-android30 %s | FileCheck %s --check-prefix=ADRENO
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=valhall-unknown-android31 %s | FileCheck %s --check-prefix=MALI
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-t4-linux %s | FileCheck %s --check-prefix=TURINGT4
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna1-5700xt-windows %s | FileCheck %s --check-prefix=AMD5700XT
-// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna3-6900xtx-windows %s | FileCheck %s --check-prefix=AMD6900XTX
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=valhall-unknown-android31 %s | FileCheck %s --check-prefix=VALHALL
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-t4-linux %s | FileCheck %s --check-prefix=TURING
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna1-5700xt-windows %s | FileCheck %s --check-prefix=RDNA1
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=rdna3-6900xtx-windows %s | FileCheck %s --check-prefix=RDNA3
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-transformation-pipeline{serialize-executables=false})' --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=m1-moltenvk-macos %s | FileCheck %s --check-prefix=M1
// TODO(antiagainst): Passing in lengthy strings as command-line options is not
// optimal. We should consider creating a dedicated test pass to pick up
// #vk.target_env in input assembly and convert them.
-// DEFAULT: #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
-// ADRENO: #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
-// MALI: #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>
-// TURINGT4: #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>
-// AMD5700XT: #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
-// AMD6900XTX: #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_NV_cooperative_matrix, SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>]>>
-// M1: #spirv.target_env<#spirv.vce<v1.3, [Shader, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], cooperative_matrix_properties_nv = []>>
+// DEFAULT: #spirv.target_env<#spirv.vce<v1.3,
+// DEFAULT-SAME: [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// DEFAULT-SAME: api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
+
+// ADRENO: #spirv.target_env<#spirv.vce<v1.4,
+// ADRENO-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer],
+// ADRENO-SAME: [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// ADRENO-SAME: api=Vulkan, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
+
+// VALHALL: #spirv.target_env<#spirv.vce<v1.4,
+// VALHALL-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer],
+// VALHALL-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// VALHALL-SAME: api=Vulkan, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>
+
+// TURING: #spirv.target_env<#spirv.vce<v1.6,
+// TURING-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV],
+// TURING-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>,
+// TURING-SAME: api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], min_subgroup_size = 32, max_subgroup_size = 32, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>
+
+// RDNA1: #spirv.target_env<#spirv.vce<v1.6,
+// RDNA1-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer],
+// RDNA1-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// RDNA1-SAME: api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_nv = []>>
+
+// RDNA3: #spirv.target_env<#spirv.vce<v1.6,
+// RDNA3-SAME: [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV],
+// RDNA3-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>,
+// RDNA3-SAME: api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_nv = [#spirv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>]>>
+
+// M1: #spirv.target_env<#spirv.vce<v1.3,
+// M1-SAME: [Shader, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer],
+// M1-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// M1-SAME: api=Vulkan, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], cooperative_matrix_properties_nv = []>>
stream.executable public @reduce_dispatch {