[vulkan][spirv] Plumb through support for KHR Integer Dot Product (#12424)
This implements the Vulkan-side of the integer dot product extension:
https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_shader_integer_dot_product.html.
For now, only enable it for the `Valhall` target.
Also simplify how we append extensions (NFC).
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
index ab6f357..793bc4f 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.td
@@ -104,6 +104,13 @@
OptionalParameter<"::mlir::UnitAttr">:$shaderFloat16,
OptionalParameter<"::mlir::UnitAttr">:$shaderInt8,
+ // VK_KHR_shader_integer_dot_product features.
+ //
+ // This corresponds to the `VkPhysicalDeviceShaderIntegerDotProductFeatures`
+ // structure:
+ // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR.html
+ OptionalParameter<"::mlir::UnitAttr">:$shaderIntegerDotProduct,
+
// VK_KHR_variable_pointers features.
// This corresponds to the `VkPhysicalDeviceVariablePointersFeatures`
// structure:
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
index c5f2bf6..0de0435 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td
@@ -84,18 +84,19 @@
def VK_KHR_16bit_storage : I32EnumAttrCase<"VK_KHR_16bit_storage", 0>;
def VK_KHR_8bit_storage : I32EnumAttrCase<"VK_KHR_8bit_storage", 1>;
def VK_KHR_shader_float16_int8 : I32EnumAttrCase<"VK_KHR_shader_float16_int8", 2>;
-def VK_KHR_spirv_1_4 : I32EnumAttrCase<"VK_KHR_spirv_1_4", 3>;
-def VK_KHR_storage_buffer_storage_class : I32EnumAttrCase<"VK_KHR_storage_buffer_storage_class", 4>;
-def VK_KHR_variable_pointers: I32EnumAttrCase<"VK_KHR_variable_pointers", 5>;
-def VK_EXT_subgroup_size_control : I32EnumAttrCase<"VK_EXT_subgroup_size_control", 6>;
-def VK_NV_cooperative_matrix : I32EnumAttrCase<"VK_NV_cooperative_matrix", 7>;
+def VK_KHR_shader_integer_dot_product : I32EnumAttrCase<"VK_KHR_shader_integer_dot_product", 3>;
+def VK_KHR_spirv_1_4 : I32EnumAttrCase<"VK_KHR_spirv_1_4", 4>;
+def VK_KHR_storage_buffer_storage_class : I32EnumAttrCase<"VK_KHR_storage_buffer_storage_class", 5>;
+def VK_KHR_variable_pointers: I32EnumAttrCase<"VK_KHR_variable_pointers", 6>;
+def VK_EXT_subgroup_size_control : I32EnumAttrCase<"VK_EXT_subgroup_size_control", 7>;
+def VK_NV_cooperative_matrix : I32EnumAttrCase<"VK_NV_cooperative_matrix", 8>;
def VK_ExtensionAttr :
VK_I32EnumAttr<"Extension", "supported Vulkan extension", "extension", [
VK_KHR_16bit_storage, VK_KHR_8bit_storage, VK_KHR_shader_float16_int8,
- VK_KHR_spirv_1_4, VK_KHR_storage_buffer_storage_class,
- VK_KHR_variable_pointers, VK_EXT_subgroup_size_control,
- VK_NV_cooperative_matrix
+ VK_KHR_shader_integer_dot_product, VK_KHR_spirv_1_4,
+ VK_KHR_storage_buffer_storage_class, VK_KHR_variable_pointers,
+ VK_EXT_subgroup_size_control, VK_NV_cooperative_matrix
]>;
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
index 4335026..0e33d12 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetEnvironment.cpp
@@ -59,6 +59,9 @@
case Extension::VK_KHR_shader_float16_int8:
// This extension allows using certain SPIR-V capabilities.
break;
+ case Extension::VK_KHR_shader_integer_dot_product:
+ extensions.push_back(spirv::Extension::SPV_KHR_integer_dot_product);
+ break;
case Extension::VK_KHR_spirv_1_4:
// This extension only affects SPIR-V version.
break;
@@ -141,6 +144,14 @@
if (vkCapabilities.getVariablePointersStorageBuffer()) {
capabilities.push_back(spirv::Capability::VariablePointersStorageBuffer);
}
+ if (vkCapabilities.getShaderIntegerDotProduct()) {
+ capabilities.push_back(spirv::Capability::DotProduct);
+ capabilities.push_back(spirv::Capability::DotProductInputAll);
+ capabilities.push_back(spirv::Capability::DotProductInput4x8BitPacked);
+ if (vkCapabilities.getShaderInt8()) {
+ capabilities.push_back(spirv::Capability::DotProductInput4x8Bit);
+ }
+ }
if (ArrayAttr attr = vkCapabilities.getCooperativeMatrixPropertiesNV()) {
if (!attr.empty()) {
capabilities.push_back(spirv::Capability::CooperativeMatrixNV);
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
index 2184cd1..c7e495f 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/TargetTriple.cpp
@@ -6,8 +6,7 @@
#include "iree/compiler/Dialect/Vulkan/Utils/TargetTriple.h"
-#include <array>
-
+#include "iree/compiler/Dialect/Vulkan/IR/VulkanTypes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -117,38 +116,39 @@
switch (triple.getArch()) {
case TargetTripleArch::Apple_M1: {
// Example: https://vulkan.gpuinfo.org/displayreport.php?id=14673
- const std::array<Extension, 5> list = {
+ const Extension list[] = {
Extension::VK_KHR_16bit_storage,
Extension::VK_KHR_8bit_storage,
Extension::VK_KHR_shader_float16_int8,
Extension::VK_KHR_storage_buffer_storage_class,
Extension::VK_KHR_variable_pointers,
};
- return extensions.append(list.begin(), list.end());
+ return append_range(extensions, list);
}
case TargetTripleArch::ARM_Valhall: {
// Example: https://vulkan.gpuinfo.org/displayreport.php?id=10312
- const std::array<Extension, 6> list = {
+ const Extension list[] = {
Extension::VK_KHR_16bit_storage,
Extension::VK_KHR_8bit_storage,
Extension::VK_KHR_shader_float16_int8,
+ Extension::VK_KHR_shader_integer_dot_product,
Extension::VK_KHR_spirv_1_4,
Extension::VK_KHR_storage_buffer_storage_class,
Extension::VK_KHR_variable_pointers,
};
- return extensions.append(list.begin(), list.end());
+ return append_range(extensions, list);
}
case TargetTripleArch::QC_Adreno: {
// Example: https://vulkan.gpuinfo.org/displayreport.php?id=10983 (11)
// Example: https://vulkan.gpuinfo.org/displayreport.php?id=16312 (12)
- const std::array<Extension, 5> list = {
+ const Extension list[] = {
Extension::VK_KHR_16bit_storage,
Extension::VK_KHR_shader_float16_int8,
Extension::VK_KHR_spirv_1_4,
Extension::VK_KHR_storage_buffer_storage_class,
Extension::VK_KHR_variable_pointers,
};
- extensions.append(list.begin(), list.end());
+ append_range(extensions, list);
if (triple.getOS() == TargetTripleOS::Android31) {
extensions.push_back(Extension::VK_KHR_8bit_storage);
}
@@ -169,11 +169,11 @@
if (triple.getArch() == TargetTripleArch::Unknown) {
// The following extensions have 90%+ device coverage from
// https://vulkan.gpuinfo.org/listextensions.php.
- const std::array<Extension, 2> list = {
+ const Extension list[] = {
Extension::VK_KHR_storage_buffer_storage_class,
Extension::VK_KHR_variable_pointers,
};
- return extensions.append(list.begin(), list.end());
+ return append_range(extensions, list);
}
// Desktop GPUs typically support all extensions we care.
@@ -215,6 +215,8 @@
bool shaderFloat16 = false, shaderFloat64 = false;
bool shaderInt8 = false, shaderInt16 = false, shaderInt64 = false;
+ bool shaderIntegerDotProduct = false;
+
bool storageBuffer16BitAccess = false, storagePushConstant16 = false;
bool uniformAndStorageBuffer16BitAccess = false;
bool storageBuffer8BitAccess = false, storagePushConstant8 = false;
@@ -315,6 +317,8 @@
shaderFloat16 = shaderInt8 = shaderInt16 = true;
+ shaderIntegerDotProduct = true;
+
storageBuffer16BitAccess = storagePushConstant16 = true;
uniformAndStorageBuffer16BitAccess = true;
storageBuffer8BitAccess = true, storagePushConstant8 = true;
@@ -426,6 +430,7 @@
getBoolAttr(storageBuffer8BitAccess), getBoolAttr(storagePushConstant8),
getBoolAttr(uniformAndStorageBuffer8BitAccess),
getBoolAttr(shaderFloat16), getBoolAttr(shaderInt8),
+ getBoolAttr(shaderIntegerDotProduct),
getBoolAttr(variablePointersStorageBuffer), getBoolAttr(variablePointers),
builder.getArrayAttr(coopmatCases));
}
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
index 0943a1d..c273ae2 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
@@ -20,8 +20,8 @@
// ADRENO-SAME: api=Vulkan, Qualcomm:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
// VALHALL: #spirv.target_env<#spirv.vce<v1.4,
-// VALHALL-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer],
-// VALHALL-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+// VALHALL-SAME: [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit],
+// VALHALL-SAME: [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
// VALHALL-SAME: api=Vulkan, ARM:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>
// TURING: #spirv.target_env<#spirv.vce<v1.6,