Adding cuda.device :: compute_capability_major/minor queries. (#14033)
These queries can be made either from compiled programs (via
hal.device.query) or through the runtime HAL APIs.
Fixes #14031.
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 0740f11..5e721a1 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -309,21 +309,40 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_cuda_device_query_attribute(  // Reads one CUdevice_attribute of |device| via cuDeviceGetAttribute into |out_value|.
+ iree_hal_cuda_device_t* device, CUdevice_attribute attribute,
+ int64_t* out_value) {
+ int value = 0;  // The attribute is fetched into a plain int first.
+ CUDA_RETURN_IF_ERROR(device->context_wrapper.syms,
+ cuDeviceGetAttribute(&value, attribute,
+ device->context_wrapper.cu_device),
+ "cuDeviceGetAttribute");  // NOTE(review): macro presumably returns a failure status early if the CUDA call fails - confirm.
+ *out_value = value;  // Widen the int result to the int64_t the caller receives.
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_cuda_device_query_i64(
iree_hal_device_t* base_device, iree_string_view_t category,
iree_string_view_t key, int64_t* out_value) {
- // iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
*out_value = 0;
- if (iree_string_view_equal(category,
- iree_make_cstring_view("hal.executable.format"))) {
- *out_value =
- iree_string_view_equal(key, iree_make_cstring_view("cuda-nvptx-fb"))
- ? 1
- : 0;
+ if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) {
+ *out_value = iree_string_view_equal(key, IREE_SV("cuda-nvptx-fb")) ? 1 : 0;
return iree_ok_status();
}
+ if (iree_string_view_equal(category, IREE_SV("cuda.device"))) {
+ if (iree_string_view_equal(key, IREE_SV("compute_capability_major"))) {
+ return iree_hal_cuda_device_query_attribute(
+ device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, out_value);
+ } else if (iree_string_view_equal(key,
+ IREE_SV("compute_capability_minor"))) {
+ return iree_hal_cuda_device_query_attribute(
+ device, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, out_value);
+ }
+ }
+
return iree_make_status(
IREE_STATUS_NOT_FOUND,
"unknown device configuration key value '%.*s :: %.*s'",