[vulkan] Request 8-/16-bit integer/floating-point features (#14848)
When the Vulkan implementation supports 8-/16-bit integer or
floating-point features, we can just request them. This helps to address
validation errors regarding them.
diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
index bf464c0..e23e91b 100644
--- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
@@ -216,6 +216,12 @@
} else if (strcmp(extension_name,
VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) == 0) {
extensions.buffer_device_address = true;
+ } else if (strcmp(extension_name, VK_KHR_8BIT_STORAGE_EXTENSION_NAME) ==
+ 0) {
+ extensions.shader_8bit_storage = true;
+ } else if (strcmp(extension_name,
+ VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) == 0) {
+ extensions.shader_float16_int8 = true;
}
}
return extensions;
diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
index b404f32..789dd71 100644
--- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
+++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
@@ -86,6 +86,10 @@
bool external_memory_host : 1;
// VK_KHR_buffer_device_address is enabled.
bool buffer_device_address : 1;
+ // VK_KHR_8bit_storage is enabled.
+ bool shader_8bit_storage : 1;
+ // VK_KHR_shader_float16_int8 is enabled.
+ bool shader_float16_int8 : 1;
} iree_hal_vulkan_device_extensions_t;
// Returns a bitfield with all of the provided extension names.
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index d93b259..2a596be 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -242,6 +242,19 @@
ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);
+ // VK_KHR_8bit_storage:
+ // This extension allows use of 8-bit types in uniform and storage buffers,
+ // and push constant blocks. It's promoted to core since Vulkan 1.2.
+ ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
+ VK_KHR_8BIT_STORAGE_EXTENSION_NAME);
+
+ // VK_KHR_shader_float16_int8:
+ // This extension allows use of 16-bit floating-point types and 8-bit integer
+ // types in shaders for arithmetic operations. It's promoted to core since
+ // Vulkan 1.2.
+ ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
+ VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);
+
//===--------------------------------------------------------------------===//
// Optional debugging features
//===--------------------------------------------------------------------===//
@@ -919,13 +932,45 @@
VkPhysicalDeviceFeatures2 available_features2;
memset(&available_features2, 0, sizeof(available_features2));
available_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+
+ // + Buffer device address features.
VkPhysicalDeviceBufferDeviceAddressFeatures
available_buffer_device_address_features;
memset(&available_buffer_device_address_features, 0,
sizeof(available_buffer_device_address_features));
available_buffer_device_address_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES;
+ available_buffer_device_address_features.pNext = available_features2.pNext;
available_features2.pNext = &available_buffer_device_address_features;
+
+ // + Shader 16 bit storage features.
+ VkPhysicalDevice16BitStorageFeatures available_16bit_storage_features;
+ memset(&available_16bit_storage_features, 0,
+ sizeof(available_16bit_storage_features));
+ available_16bit_storage_features.sType =
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;
+ available_16bit_storage_features.pNext = available_features2.pNext;
+ available_features2.pNext = &available_16bit_storage_features;
+
+ // + Shader 8 bit storage features.
+ VkPhysicalDevice8BitStorageFeatures available_8bit_storage_features;
+ memset(&available_8bit_storage_features, 0,
+ sizeof(available_8bit_storage_features));
+ available_8bit_storage_features.sType =
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES;
+ available_8bit_storage_features.pNext = available_features2.pNext;
+ available_features2.pNext = &available_8bit_storage_features;
+
+ // + Shader float16 and int8 features.
+ VkPhysicalDeviceShaderFloat16Int8Features
+ available_shader_float16_int8_features;
+ memset(&available_shader_float16_int8_features, 0,
+ sizeof(available_shader_float16_int8_features));
+ available_shader_float16_int8_features.sType =
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES;
+ available_shader_float16_int8_features.pNext = available_features2.pNext;
+ available_features2.pNext = &available_shader_float16_int8_features;
+
instance_syms->vkGetPhysicalDeviceFeatures2(physical_device,
&available_features2);
const VkPhysicalDeviceFeatures* available_features =
@@ -950,6 +995,9 @@
if (available_features->shaderInt64) {
enabled_features2.features.shaderInt64 = VK_TRUE;
}
+ if (available_features->shaderInt16) {
+ enabled_features2.features.shaderInt16 = VK_TRUE;
+ }
iree_hal_vulkan_features_t enabled_features = 0;
@@ -1030,6 +1078,18 @@
subgroup_control_features.subgroupSizeControl = VK_TRUE;
}
+ // Enable all available 16- or 8-bit integer/floating-point features.
+ available_16bit_storage_features.pNext = enabled_features2.pNext;
+ enabled_features2.pNext = &available_16bit_storage_features;
+ if (enabled_device_extensions.shader_8bit_storage) {
+ available_8bit_storage_features.pNext = enabled_features2.pNext;
+ enabled_features2.pNext = &available_8bit_storage_features;
+ }
+ if (enabled_device_extensions.shader_float16_int8) {
+ available_shader_float16_int8_features.pNext = enabled_features2.pNext;
+ enabled_features2.pNext = &available_shader_float16_int8_features;
+ }
+
auto logical_device = new VkDeviceHandle(
instance_syms, physical_device, enabled_features,
enabled_device_extensions,