[vulkan] Request 8-/16-bit integer/floating-point features (#14848)

When the Vulkan implementation supports 8-/16-bit integer or
floating-point features, we can just request them. This helps to address
validation errors regarding them.
diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
index bf464c0..e23e91b 100644
--- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.cc
@@ -216,6 +216,12 @@
     } else if (strcmp(extension_name,
                       VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) == 0) {
       extensions.buffer_device_address = true;
+    } else if (strcmp(extension_name, VK_KHR_8BIT_STORAGE_EXTENSION_NAME) ==
+               0) {
+      extensions.shader_8bit_storage = true;
+    } else if (strcmp(extension_name,
+                      VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) == 0) {
+      extensions.shader_float16_int8 = true;
     }
   }
   return extensions;
diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
index b404f32..789dd71 100644
--- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
+++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
@@ -86,6 +86,10 @@
   bool external_memory_host : 1;
   // VK_KHR_buffer_device_address is enabled.
   bool buffer_device_address : 1;
+  // VK_KHR_8bit_storage is enabled.
+  bool shader_8bit_storage : 1;
+  // VK_KHR_shader_float16_int8 is enabled.
+  bool shader_float16_int8 : 1;
 } iree_hal_vulkan_device_extensions_t;
 
 // Returns a bitfield with all of the provided extension names.
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index d93b259..2a596be 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -242,6 +242,19 @@
   ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
           VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);
 
+  // VK_KHR_8bit_storage:
+  // This extension allows use of 8-bit types in uniform and storage buffers,
+  // and push constant blocks. It's promoted to core since Vulkan 1.2.
+  ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
+          VK_KHR_8BIT_STORAGE_EXTENSION_NAME);
+
+  // VK_KHR_shader_float16_int8:
+  // This extension allows use of 16-bit floating-point types and 8-bit integer
+  // types in shaders for arithmetic operations. It's promoted to core since
+  // Vulkan 1.2.
+  ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
+          VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);
+
   //===--------------------------------------------------------------------===//
   // Optional debugging features
   //===--------------------------------------------------------------------===//
@@ -919,13 +932,45 @@
   VkPhysicalDeviceFeatures2 available_features2;
   memset(&available_features2, 0, sizeof(available_features2));
   available_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+
+  // + Buffer device address features.
   VkPhysicalDeviceBufferDeviceAddressFeatures
       available_buffer_device_address_features;
   memset(&available_buffer_device_address_features, 0,
          sizeof(available_buffer_device_address_features));
   available_buffer_device_address_features.sType =
       VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES;
+  available_buffer_device_address_features.pNext = available_features2.pNext;
   available_features2.pNext = &available_buffer_device_address_features;
+
+  // + Shader 16 bit storage features.
+  VkPhysicalDevice16BitStorageFeatures available_16bit_storage_features;
+  memset(&available_16bit_storage_features, 0,
+         sizeof(available_16bit_storage_features));
+  available_16bit_storage_features.sType =
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES;
+  available_16bit_storage_features.pNext = available_features2.pNext;
+  available_features2.pNext = &available_16bit_storage_features;
+
+  // + Shader 8 bit storage features.
+  VkPhysicalDevice8BitStorageFeatures available_8bit_storage_features;
+  memset(&available_8bit_storage_features, 0,
+         sizeof(available_8bit_storage_features));
+  available_8bit_storage_features.sType =
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES;
+  available_8bit_storage_features.pNext = available_features2.pNext;
+  available_features2.pNext = &available_8bit_storage_features;
+
+  // + Shader float16 and int8 features.
+  VkPhysicalDeviceShaderFloat16Int8Features
+      available_shader_float16_int8_features;
+  memset(&available_shader_float16_int8_features, 0,
+         sizeof(available_shader_float16_int8_features));
+  available_shader_float16_int8_features.sType =
+      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES;
+  available_shader_float16_int8_features.pNext = available_features2.pNext;
+  available_features2.pNext = &available_shader_float16_int8_features;
+
   instance_syms->vkGetPhysicalDeviceFeatures2(physical_device,
                                               &available_features2);
   const VkPhysicalDeviceFeatures* available_features =
@@ -950,6 +995,9 @@
   if (available_features->shaderInt64) {
     enabled_features2.features.shaderInt64 = VK_TRUE;
   }
+  if (available_features->shaderInt16) {
+    enabled_features2.features.shaderInt16 = VK_TRUE;
+  }
 
   iree_hal_vulkan_features_t enabled_features = 0;
 
@@ -1030,6 +1078,18 @@
     subgroup_control_features.subgroupSizeControl = VK_TRUE;
   }
 
+  // Enable all available 16- or 8-bit integer/floating-point features.
+  available_16bit_storage_features.pNext = enabled_features2.pNext;
+  enabled_features2.pNext = &available_16bit_storage_features;
+  if (enabled_device_extensions.shader_8bit_storage) {
+    available_8bit_storage_features.pNext = enabled_features2.pNext;
+    enabled_features2.pNext = &available_8bit_storage_features;
+  }
+  if (enabled_device_extensions.shader_float16_int8) {
+    available_shader_float16_int8_features.pNext = enabled_features2.pNext;
+    enabled_features2.pNext = &available_shader_float16_int8_features;
+  }
+
   auto logical_device = new VkDeviceHandle(
       instance_syms, physical_device, enabled_features,
       enabled_device_extensions,