[Vulkan][SPIRV] Introduce `address` vulkan device property (#16282)
This is so that we can check
`VkPhysicalDeviceBufferDeviceAddressFeatures` and decide if the
`PhysicalStorageBufferAddresses` SPIR-V capability is available.
Issue: https://github.com/openxla/iree/pull/14977
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp
index 885686e..aac9bad 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp
@@ -15,6 +15,9 @@
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include <cstdint>
namespace mlir::iree_compiler {
@@ -54,14 +57,18 @@
// ("coopmatrix.<input-element-type>.<output-element-type>.<m>x<n>x<k>")
// * 0b01: coopmatrix.f16.f16.16x16x16
uint32_t coopMatrix;
+ // Physical storage buffer address bitfield
+ // ("address.<mode>")
+ // * ob01: address.physical64
+ uint32_t address;
KernelFeatures()
: computeFloat(0), computeInt(0), storage(0), subgroup(0), dotProduct(0),
- coopMatrix(0) {}
+ coopMatrix(0), address(0) {}
bool empty() const {
return computeFloat == 0 && computeInt == 0 && storage == 0 &&
- subgroup == 0 && dotProduct == 0 && coopMatrix == 0;
+ subgroup == 0 && dotProduct == 0 && coopMatrix == 0 && address == 0;
}
};
@@ -169,6 +176,13 @@
return success();
}
+ //===-------------------------------------------------------------------===//
+ // Address capabilities
+ case spirv::Capability::PhysicalStorageBufferAddresses:
+ // Vulkan only supports 64-bit device buffer addresses.
+ features.address |= 0b01;
+ return success();
+
default:
break;
}
@@ -219,6 +233,9 @@
if (features.coopMatrix) {
result = buildQueryOp("coopmatrix.ops", features.coopMatrix, result);
}
+ if (features.address) {
+ result = buildQueryOp("address.mode", features.address, result);
+ }
builder.create<IREE::HAL::ReturnOp>(loc, result);
}
@@ -245,6 +262,9 @@
if (features.coopMatrix) {
queries.push_back("coopmatrix.ops=" + std::to_string(features.coopMatrix));
}
+ if (features.address) {
+ queries.push_back("address.mode=" + std::to_string(features.address));
+ }
return queries;
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir
index 30ed5d6..38f2bba 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/materialize_executable_conditions.mlir
@@ -8,6 +8,14 @@
]>
]>
+#indirect_pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ <0, bindings = [
+ <0, storage_buffer, ReadOnly>,
+ <1, storage_buffer, ReadOnly>,
+ <2, storage_buffer>
+ ], flags = Indirect>
+]>
+
hal.executable private @dispatch_executable {
// CHECK-LABEL: hal.executable.variant public @test_assumed_capabilities
// CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb", {iree.spirv.features = ["vulkan"]}>)
@@ -233,4 +241,38 @@
}
}
}
+
+ // CHECK-LABEL: hal.executable.variant public @test_address_capabilities
+ // CHECK-SAME: target(<"vulkan", "vulkan-spirv-fb-ptr",
+ // CHECK-SAME: {hal.bindings.indirect, iree.spirv.features = ["vulkan", "compute.bitwidths.int=4", "address.mode=1"]}>)
+ // CHECK: %{{.+}}, %[[V0:.+]] = hal.device.query<%{{.+}} : !hal.device>
+ // CHECK-SAME: key("hal.dispatch" :: "compute.bitwidths.int") : i1, i32 = 0 : i32
+ // CHECK: %[[TARGET0:.+]] = arith.constant 4 : i32
+ // CHECK: %{{.+}} = arith.andi %[[V0]], %[[TARGET0]] : i32
+ // CHECK: %{{.+}}, %[[V1:.+]] = hal.device.query<%{{.+}} : !hal.device>
+ // CHECK-SAME: key("hal.dispatch" :: "address.mode") : i1, i32 = 0 : i32
+ // CHECK: %[[TARGET1:.+]] = arith.constant 1 : i32
+ // CHECK: %{{.+}} = arith.andi %[[V1]], %[[TARGET1]] : i32
+ hal.executable.variant public @test_address_capabilities target(
+ #hal.executable.target<"vulkan", "vulkan-spirv-fb-ptr", {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
+ [Int64, PhysicalStorageBufferAddresses],
+ [SPV_KHR_physical_storage_buffer]>,
+ #spirv.resource_limits<>>,
+ hal.bindings.indirect}>
+ ) {
+ hal.executable.export public @test_address_capabilities ordinal(0) layout(#indirect_pipeline_layout) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ spirv.module Physical64 GLSL450 requires
+ #spirv.vce<v1.5, [Int64, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ spirv.func @test_address_capabilities() "None" { spirv.Return }
+ spirv.EntryPoint "GLCompute" @test_address_capabilities
+ spirv.ExecutionMode @test_address_capabilities "LocalSize", 64, 1, 1
+ }
+ }
+ }
}
diff --git a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
index 4c12179..db828e6 100644
--- a/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
+++ b/runtime/src/iree/hal/drivers/vulkan/extensibility_util.h
@@ -140,6 +140,10 @@
// ("coopmatrix.<input-element-type>.<output-element-type>.<m>x<n>x<k>")
// * 0b01: coopmatrix.f16.f16.16x16x16
uint32_t cooperative_matrix : 8;
+ // Addressing more requirement bitfield:
+ // ("address.<mode>")
+ // * 0b01: address.physical64
+ uint32_t address : 8;
} iree_hal_vulkan_iree_hal_vulkan_device_properties_t;
#endif // IREE_HAL_DRIVERS_VULKAN_EXTENSIBILITY_UTIL_H_
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index ff59b82..aeebe69 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -902,9 +902,18 @@
coop_matrix_features.pNext = physical_device_features.pNext;
physical_device_features.pNext = &coop_matrix_features;
+ // + Physical storage buffer features.
+ VkPhysicalDeviceBufferDeviceAddressFeatures address_features;
+ memset(&address_features, 0, sizeof(address_features));
+ address_features.sType =
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES;
+ address_features.pNext = physical_device_features.pNext;
+ physical_device_features.pNext = &address_features;
+
instance_syms->vkGetPhysicalDeviceFeatures2(physical_device,
&physical_device_features);
+ // + Coop matrix properties.
VkPhysicalDeviceProperties2 physical_device_properties;
memset(&physical_device_properties, 0, sizeof(physical_device_properties));
physical_device_properties.sType =
@@ -992,6 +1001,10 @@
}
}
+ if (address_features.bufferDeviceAddress) {
+ device_properties->address |= 0x1u;
+ }
+
return iree_ok_status();
}
@@ -1453,6 +1466,10 @@
device->logical_device->supported_properties().cooperative_matrix;
return iree_ok_status();
}
+ if (iree_string_view_equal(key, IREE_SV("address.mode"))) {
+ *out_value = device->logical_device->supported_properties().address;
+ return iree_ok_status();
+ }
}
return iree_make_status(