| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "./hal.h" |
| |
| #include <nanobind/nanobind.h> |
| #include <nanobind/stl/vector.h> |
| |
| #include <optional> |
| |
| #include "./local_dlpack.h" |
| #include "./numpy_interop.h" |
| #include "./vm.h" |
| #include "iree/base/internal/path.h" |
| #include "iree/hal/api.h" |
| #include "iree/hal/utils/allocators.h" |
| #include "iree/modules/hal/module.h" |
| #include "iree/tooling/device_util.h" |
| |
| namespace iree { |
| namespace python { |
| |
| namespace { |
| |
// Python-facing docstring for HalDevice.queue_alloca (see QueueAlloca below).
static const char kHalDeviceQueueAlloca[] =
    R"(Reserves and returns a device-local queue-ordered transient buffer.

Args:
  allocation_size: The size in bytes of the allocation.
  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
    a HalFence. The allocation will be made once these semaphores are
    satisfied.
  signal_semaphores: Semaphores/Fence to signal.

Returns:
  HalBuffer.
)";
| |
// Python-facing docstring for HalDevice.queue_dealloca (see QueueDealloca
// below). Fixed from a copy-paste of the alloca docstring: dealloca does not
// return a buffer, and the wait semaphores gate the deallocation.
static const char kHalDeviceQueueDealloca[] =
    R"(Deallocates a queue-ordered transient buffer.

Args:
  buffer: The HalBuffer to deallocate.
  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
    a HalFence. The deallocation will be made once these semaphores are
    satisfied.
  signal_semaphores: Semaphores/Fence to signal.
)";
| |
// Python-facing docstring for HalDevice.queue_execute (see QueueExecute
// below). Fixed from a copy-paste of the alloca docstring ("allocation" made
// no sense for execution).
static const char kHalDeviceQueueExecute[] =
    R"(Executes a sequence of command buffers.

Args:
  command_buffers: Sequence of command buffers to enqueue.
  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
    a HalFence. Execution will begin once these semaphores are
    satisfied.
  signal_semaphores: Semaphores/Fence to signal.
)";
| |
// Python-facing docstring for HalDevice.queue_copy (see QueueCopy below).
// Fixed from a copy-paste of the alloca docstring ("allocation" made no
// sense for a copy).
static const char kHalDeviceQueueCopy[] =
    R"(Copy data from a source buffer to destination buffer.

Args:
  source_buffer: `HalBuffer` that holds src data.
  target_buffer: `HalBuffer` that will receive data.
  wait_semaphores: `List[Tuple[HalSemaphore, int]]` of semaphore values or
    a HalFence. The copy will begin once these semaphores are
    satisfied.
  signal_semaphores: Semaphores/Fence to signal.
)";
| |
// Python-facing docstring shared by semaphore/fence wait entry points.
// Fixed typo: "deadine" -> "deadline".
static const char kHalWait[] =
    R"(Waits until the semaphore or fence is signalled or errored.

Three wait cases are supported:
  * timeout: Relative nanoseconds to wait.
  * deadline: Absolute nanoseconds to wait.
  * Neither: Waits for infinite time.

Returns whether the wait succeeded (True) or timed out (False). If the fence was
asynchronously failed, an exception is raised.
)";
| |
| // RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes |
| // out of scope. |
| class PyBufferReleaser { |
| public: |
| PyBufferReleaser(Py_buffer& b) : b_(b) {} |
| ~PyBufferReleaser() { PyBuffer_Release(&b_); } |
| |
| private: |
| Py_buffer& b_; |
| }; |
| |
// Formats `length` bytes starting at `data` as a lowercase hex string
// (two characters per byte).
static std::string ToHexString(const uint8_t* data, size_t length) {
  static constexpr char kDigits[] = "0123456789abcdef";
  std::string result;
  result.reserve(length * 2);
  for (size_t i = 0; i < length; ++i) {
    result.push_back(kDigits[data[i] >> 4]);
    result.push_back(kDigits[data[i] & 0x0F]);
  }
  return result;
}
// Formats a 32-bit value as hex in host byte order.
static std::string ToHexString(uint32_t value) {
  return ToHexString(reinterpret_cast<const uint8_t*>(&value), sizeof(value));
}
| |
| iree_timeout_t NormalizeTimeout(std::optional<iree_duration_t> timeout, |
| std::optional<iree_time_t> deadline) { |
| if (!timeout && !deadline) { |
| return iree_infinite_timeout(); |
| } else if (timeout && deadline) { |
| throw std::invalid_argument("timeout and deadline cannot both be set"); |
| } else if (timeout) { |
| return iree_make_timeout_ns(*timeout); |
| } else { |
| return iree_timeout_t{IREE_TIMEOUT_ABSOLUTE, *deadline}; |
| } |
| } |
| |
| } // namespace |
| |
| //------------------------------------------------------------------------------ |
| // HalAllocator |
| //------------------------------------------------------------------------------ |
| |
| py::dict HalAllocator::QueryStatistics() { |
| py::dict items; |
| iree_hal_allocator_statistics_t stats; |
| iree_hal_allocator_query_statistics(raw_ptr(), &stats); |
| #if IREE_STATISTICS_ENABLE |
| items["host_bytes_peak"] = stats.host_bytes_peak; |
| items["host_bytes_allocated"] = stats.host_bytes_allocated; |
| items["host_bytes_freed"] = stats.host_bytes_freed; |
| items["device_bytes_peak"] = stats.device_bytes_peak; |
| items["device_bytes_allocated"] = stats.device_bytes_allocated; |
| items["device_bytes_freed"] = stats.device_bytes_freed; |
| #endif |
| return items; |
| } |
| |
// Returns the allocator statistics formatted as a Python str via
// iree_hal_allocator_statistics_format.
py::str HalAllocator::FormattedStatistics() {
  // Perform all allocating string manipulation without early exit.
  iree_string_builder_t builder;
  iree_string_builder_initialize(iree_allocator_system(), &builder);
  iree_hal_allocator_statistics_t stats;
  iree_hal_allocator_query_statistics(raw_ptr(), &stats);
  auto status = iree_hal_allocator_statistics_format(&stats, &builder);
  iree_string_view_t view = iree_string_builder_view(&builder);
  py::str result = py::str(view.data, view.size);
  // Deinitialize before any potential raise so the builder's storage is not
  // leaked when formatting failed.
  iree_string_builder_deinitialize(&builder);

  // Check/raise after all memory alloc/dealloc.
  CheckApiStatus(status, "unable to format statistics");
  return result;
}
| |
| py::object HalAllocator::AllocateBufferCopy( |
| int memory_type, int allowed_usage, HalDevice& device, py::object buffer, |
| std::optional<uint64_t> raw_element_type) { |
| IREE_TRACE_SCOPE_NAMED("HalAllocator::AllocateBufferCopy"); |
| // Request a view of the buffer (use the raw python C API to avoid |
| // some allocation and copying at the pybind level). |
| Py_buffer py_view; |
| // Note that only C-Contiguous ND-arrays are presently supported, so |
| // only request that via PyBUF_ND. Long term, we should consult an |
| // "oracle" in the runtime to determine the precise required format |
| // and set flags accordingly (and fallback/copy on failure). |
| int flags = PyBUF_FORMAT | PyBUF_ND; |
| |
| // Acquire the backing buffer and setup RAII release. |
| if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { |
| // The GetBuffer call is required to set an appropriate error. |
| throw py::python_error(); |
| } |
| PyBufferReleaser py_view_releaser(py_view); |
| |
| iree_hal_buffer_params_t params = {0}; |
| // TODO: Should not require host visible :( |
| params.type = memory_type | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; |
| params.usage = allowed_usage; |
| |
| iree_hal_buffer_t* hal_buffer = nullptr; |
| iree_status_t status = iree_ok_status(); |
| { |
| py::gil_scoped_release release; |
| status = iree_hal_allocator_allocate_buffer(raw_ptr(), params, py_view.len, |
| &hal_buffer); |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_device_transfer_h2d( |
| device.raw_ptr(), py_view.buf, hal_buffer, 0, py_view.len, |
| IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); |
| } |
| } |
| CheckApiStatus(status, "Failed to allocate device visible buffer"); |
| |
| if (!raw_element_type) { |
| return py::cast(HalBuffer::StealFromRawPtr(hal_buffer), |
| py::rv_policy::move); |
| } |
| |
| // Create the buffer_view. (note that numpy shape is ssize_t, so we need to |
| // copy). |
| iree_hal_element_types_t element_type = |
| (iree_hal_element_types_t)*raw_element_type; |
| iree_hal_encoding_type_t encoding_type = |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR; |
| std::vector<iree_hal_dim_t> dims(py_view.ndim); |
| std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin()); |
| iree_hal_buffer_view_t* hal_buffer_view; |
| CheckApiStatus( |
| iree_hal_buffer_view_create( |
| hal_buffer, dims.size(), dims.data(), element_type, encoding_type, |
| iree_hal_allocator_host_allocator(raw_ptr()), &hal_buffer_view), |
| "Error allocating buffer_view"); |
| iree_hal_buffer_release(hal_buffer); |
| |
| return py::cast(HalBufferView::StealFromRawPtr(hal_buffer_view), |
| py::rv_policy::move); |
| } |
| |
| HalBuffer HalAllocator::AllocateHostStagingBufferCopy(HalDevice& device, |
| py::handle buffer) { |
| IREE_TRACE_SCOPE_NAMED("HalAllocator::AllocateHostStagingBufferCopy"); |
| // Request a view of the buffer (use the raw python C API to avoid |
| // some allocation and copying at the pybind level). |
| Py_buffer py_view; |
| // Note that only C-Contiguous ND-arrays are presently supported, so |
| // only request that via PyBUF_ND. Long term, we should consult an |
| // "oracle" in the runtime to determine the precise required format |
| // and set flags accordingly (and fallback/copy on failure). |
| int flags = PyBUF_FORMAT | PyBUF_ND; |
| |
| // Acquire the backing buffer and setup RAII release. |
| if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { |
| // The GetBuffer call is required to set an appropriate error. |
| throw py::python_error(); |
| } |
| PyBufferReleaser py_view_releaser(py_view); |
| |
| iree_hal_buffer_params_t params = {0}; |
| std::memset(¶ms, 0, sizeof(params)); |
| params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; |
| params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER; |
| |
| iree_hal_buffer_t* hal_buffer = nullptr; |
| iree_status_t status = iree_ok_status(); |
| { |
| py::gil_scoped_release release; |
| status = iree_hal_allocator_allocate_buffer(raw_ptr(), params, py_view.len, |
| &hal_buffer); |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_device_transfer_h2d( |
| device.raw_ptr(), py_view.buf, hal_buffer, 0, py_view.len, |
| IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); |
| } |
| } |
| CheckApiStatus(status, "Failed to allocate device visible buffer"); |
| |
| return HalBuffer::StealFromRawPtr(hal_buffer); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // HalBuffer |
| //------------------------------------------------------------------------------ |
| |
| namespace { |
| |
| void AppendHalBufferRepr(iree_hal_buffer_t* buffer, std::string& repr) { |
| repr.append(std::to_string(iree_hal_buffer_byte_length(buffer))); |
| repr.append(" bytes (at offset "); |
| repr.append(std::to_string(iree_hal_buffer_byte_offset(buffer))); |
| repr.append(" into "); |
| repr.append(std::to_string(iree_hal_buffer_allocation_size(buffer))); |
| repr.append("), memory_type="); |
| |
| // Memory type. |
| iree_bitfield_string_temp_t tmp; |
| iree_string_view_t sv; |
| sv = iree_hal_memory_type_format(iree_hal_buffer_memory_type(buffer), &tmp); |
| repr.append(sv.data, sv.size); |
| |
| // Allowed access. |
| repr.append(", allowed_access="); |
| sv = iree_hal_memory_access_format(iree_hal_buffer_allowed_access(buffer), |
| &tmp); |
| repr.append(sv.data, sv.size); |
| |
| // Allowed usage. |
| repr.append(", allowed_usage="); |
| sv = |
| iree_hal_buffer_usage_format(iree_hal_buffer_allowed_usage(buffer), &tmp); |
| repr.append(sv.data, sv.size); |
| } |
| |
| } // namespace |
| |
| py::str HalBuffer::Repr() { |
| std::string repr("<HalBuffer "); |
| AppendHalBufferRepr(raw_ptr(), repr); |
| repr.append(">"); |
| return py::str(py::cast(repr)); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // HalBufferView |
| //------------------------------------------------------------------------------ |
| |
| py::str HalBufferView::Repr() { |
| std::string repr("<HalBufferView ("); |
| |
| // Shape. |
| iree_host_size_t rank = iree_hal_buffer_view_shape_rank(raw_ptr()); |
| for (iree_host_size_t i = 0; i < rank; ++i) { |
| if (i > 0) { |
| repr.append(", "); |
| } |
| repr.append(std::to_string(iree_hal_buffer_view_shape_dim(raw_ptr(), i))); |
| } |
| repr.append(")"); |
| |
| // Element type. |
| repr.append(", element_type=0x"); |
| auto element_type = iree_hal_buffer_view_element_type(raw_ptr()); |
| repr.append(ToHexString(static_cast<uint32_t>(element_type))); |
| |
| repr.append(", "); |
| AppendHalBufferRepr(iree_hal_buffer_view_buffer(raw_ptr()), repr); |
| repr.append(">"); |
| return py::str(py::cast(repr)); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // HalDevice |
| //------------------------------------------------------------------------------ |
| |
| void HalDevice::BeginProfiling(std::optional<std::string> mode, |
| std::optional<std::string> file_path) { |
| iree_hal_device_profiling_options_t options; |
| memset(&options, 0, sizeof(options)); |
| |
| options.mode = IREE_HAL_DEVICE_PROFILING_MODE_QUEUE_OPERATIONS; |
| if (mode) { |
| if (*mode == "queue") { |
| options.mode = IREE_HAL_DEVICE_PROFILING_MODE_QUEUE_OPERATIONS; |
| } else if (*mode == "dispatch") { |
| options.mode = IREE_HAL_DEVICE_PROFILING_MODE_DISPATCH_COUNTERS; |
| } else if (*mode == "executable") { |
| options.mode = IREE_HAL_DEVICE_PROFILING_MODE_EXECUTABLE_COUNTERS; |
| } else { |
| throw RaiseValueError("unrecognized profiling mode"); |
| } |
| } |
| |
| options.file_path = file_path ? file_path->c_str() : nullptr; |
| CheckApiStatus(iree_hal_device_profiling_begin(raw_ptr(), &options), |
| "starting device profiling"); |
| } |
| |
// Flushes any pending device profiling data (profiling started via
// BeginProfiling). Raises on failure.
void HalDevice::FlushProfiling() {
  CheckApiStatus(iree_hal_device_profiling_flush(raw_ptr()),
                 "flushing device profiling");
}
| |
// Ends a profiling session started via BeginProfiling. Raises on failure.
void HalDevice::EndProfiling() {
  CheckApiStatus(iree_hal_device_profiling_end(raw_ptr()),
                 "ending device profiling");
}
| |
// Creates a new HAL semaphore starting at |initial_value| and returns an
// owning wrapper. Raises on failure.
HalSemaphore HalDevice::CreateSemaphore(uint64_t initial_value) {
  iree_hal_semaphore_t* out_sem;
  CheckApiStatus(
      iree_hal_semaphore_create(raw_ptr(), initial_value,
                                IREE_HAL_SEMAPHORE_FLAG_NONE, &out_sem),
      "creating semaphore");
  return HalSemaphore::StealFromRawPtr(out_sem);
}
| |
// Reserves a queue-ordered transient buffer of |allocation_size| bytes
// (see kHalDeviceQueueAlloca). |wait_semaphores|/|signal_semaphores| may
// each be a HalFence or a sequence of (HalSemaphore, payload) tuples.
HalBuffer HalDevice::QueueAlloca(uint64_t allocation_size,
                                 py::handle wait_semaphores,
                                 py::handle signal_semaphores) {
  iree_hal_buffer_params_t params;
  memset(&params, 0, sizeof(params));
  // TODO: Accept explicit params in API.
  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;

  iree_hal_semaphore_list_t wait_list;
  iree_hal_semaphore_list_t signal_list;

  // Wait list.
  if (py::isinstance<HalFence>(wait_semaphores)) {
    wait_list = iree_hal_fence_semaphore_list(
        py::cast<HalFence*>(wait_semaphores)->raw_ptr());
  } else {
    // Unpack (semaphore, payload) tuples into alloca'd arrays; the stack
    // storage stays valid until this function returns.
    size_t wait_count = py::len(wait_semaphores);
    wait_list = {
        wait_count,
        /*semaphores=*/
        static_cast<iree_hal_semaphore_t**>(
            alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
        /*payload_values=*/
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
    };
    for (size_t i = 0; i < wait_count; ++i) {
      py::tuple pair = wait_semaphores[i];
      wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
      wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
    }
  }

  // Signal list (same unpacking as the wait list).
  if (py::isinstance<HalFence>(signal_semaphores)) {
    signal_list = iree_hal_fence_semaphore_list(
        py::cast<HalFence*>(signal_semaphores)->raw_ptr());
  } else {
    size_t signal_count = py::len(signal_semaphores);
    signal_list = {
        signal_count,
        /*semaphores=*/
        static_cast<iree_hal_semaphore_t**>(
            alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
        /*payload_values=*/
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
    };
    for (size_t i = 0; i < signal_count; ++i) {
      py::tuple pair = signal_semaphores[i];
      signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
      signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
    }
  }

  iree_hal_buffer_t* out_buffer;
  // TODO: Accept params for queue affinity and pool.
  CheckApiStatus(iree_hal_device_queue_alloca(
                     raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
                     signal_list, IREE_HAL_ALLOCATOR_POOL_DEFAULT, params,
                     allocation_size, &out_buffer),
                 "allocating memory on queue");
  return HalBuffer::StealFromRawPtr(out_buffer);
}
| |
// Enqueues deallocation of |buffer| on the device queue (see
// kHalDeviceQueueDealloca). |wait_semaphores|/|signal_semaphores| may each
// be a HalFence or a sequence of (HalSemaphore, payload) tuples.
void HalDevice::QueueDealloca(HalBuffer& buffer, py::handle wait_semaphores,
                              py::handle signal_semaphores) {
  iree_hal_semaphore_list_t wait_list;
  iree_hal_semaphore_list_t signal_list;

  // Wait list.
  if (py::isinstance<HalFence>(wait_semaphores)) {
    wait_list = iree_hal_fence_semaphore_list(
        py::cast<HalFence*>(wait_semaphores)->raw_ptr());
  } else {
    // Unpack (semaphore, payload) tuples into alloca'd arrays; the stack
    // storage stays valid until this function returns.
    size_t wait_count = py::len(wait_semaphores);
    wait_list = {
        wait_count,
        /*semaphores=*/
        static_cast<iree_hal_semaphore_t**>(
            alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
        /*payload_values=*/
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
    };
    for (size_t i = 0; i < wait_count; ++i) {
      py::tuple pair = wait_semaphores[i];
      wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
      wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
    }
  }

  // Signal list (same unpacking as the wait list).
  if (py::isinstance<HalFence>(signal_semaphores)) {
    signal_list = iree_hal_fence_semaphore_list(
        py::cast<HalFence*>(signal_semaphores)->raw_ptr());
  } else {
    size_t signal_count = py::len(signal_semaphores);
    signal_list = {
        signal_count,
        /*semaphores=*/
        static_cast<iree_hal_semaphore_t**>(
            alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
        /*payload_values=*/
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
    };
    for (size_t i = 0; i < signal_count; ++i) {
      py::tuple pair = signal_semaphores[i];
      signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
      signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
    }
  }

  CheckApiStatus(
      iree_hal_device_queue_dealloca(raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY,
                                     wait_list, signal_list, buffer.raw_ptr()),
      "deallocating memory on queue");
}
| |
// Enqueues execution of |command_buffers| (a sequence of HalCommandBuffer)
// on the device queue (see kHalDeviceQueueExecute).
// |wait_semaphores|/|signal_semaphores| may each be a HalFence or a
// sequence of (HalSemaphore, payload) tuples.
void HalDevice::QueueExecute(py::handle command_buffers,
                             py::handle wait_semaphores,
                             py::handle signal_semaphores) {
  iree_hal_semaphore_list_t wait_list;
  iree_hal_semaphore_list_t signal_list;

  // Wait list.
  if (py::isinstance<HalFence>(wait_semaphores)) {
    wait_list = iree_hal_fence_semaphore_list(
        py::cast<HalFence*>(wait_semaphores)->raw_ptr());
  } else {
    // Unpack (semaphore, payload) tuples into alloca'd arrays; the stack
    // storage stays valid until this function returns.
    size_t wait_count = py::len(wait_semaphores);
    wait_list = {
        wait_count,
        /*semaphores=*/
        static_cast<iree_hal_semaphore_t**>(
            alloca(sizeof(iree_hal_semaphore_t*) * wait_count)),
        /*payload_values=*/
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)),
    };
    for (size_t i = 0; i < wait_count; ++i) {
      py::tuple pair = wait_semaphores[i];
      wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
      wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
    }
  }

  // Signal list (same unpacking as the wait list).
  if (py::isinstance<HalFence>(signal_semaphores)) {
    signal_list = iree_hal_fence_semaphore_list(
        py::cast<HalFence*>(signal_semaphores)->raw_ptr());
  } else {
    size_t signal_count = py::len(signal_semaphores);
    signal_list = {
        signal_count,
        /*semaphores=*/
        static_cast<iree_hal_semaphore_t**>(
            alloca(sizeof(iree_hal_semaphore_t*) * signal_count)),
        /*payload_values=*/
        static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)),
    };
    for (size_t i = 0; i < signal_count; ++i) {
      py::tuple pair = signal_semaphores[i];
      signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr();
      signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]);
    }
  }

  // Unpack command buffers.
  size_t cb_count = py::len(command_buffers);
  iree_hal_command_buffer_t** cb_list =
      static_cast<iree_hal_command_buffer_t**>(
          alloca(sizeof(iree_hal_command_buffer_t*) * cb_count));
  for (size_t i = 0; i < cb_count; ++i) {
    cb_list[i] = py::cast<HalCommandBuffer*>(command_buffers[i])->raw_ptr();
  }

  CheckApiStatus(iree_hal_device_queue_execute(
                     raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
                     signal_list, cb_count, cb_list, /*binding_tables=*/NULL),
                 "executing command buffers");
}
| |
| void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer, |
| py::handle wait_semaphores, |
| py::handle signal_semaphores) { |
| iree_hal_semaphore_list_t wait_list; |
| iree_hal_semaphore_list_t signal_list; |
| |
| // Wait list. |
| if (py::isinstance<HalFence>(wait_semaphores)) { |
| wait_list = iree_hal_fence_semaphore_list( |
| py::cast<HalFence*>(wait_semaphores)->raw_ptr()); |
| } else { |
| size_t wait_count = py::len(wait_semaphores); |
| wait_list = { |
| wait_count, |
| /*semaphores=*/ |
| static_cast<iree_hal_semaphore_t**>( |
| alloca(sizeof(iree_hal_semaphore_t*) * wait_count)), |
| /*payload_values=*/ |
| static_cast<uint64_t*>(alloca(sizeof(uint64_t) * wait_count)), |
| }; |
| for (size_t i = 0; i < wait_count; ++i) { |
| py::tuple pair = wait_semaphores[i]; |
| wait_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr(); |
| wait_list.payload_values[i] = py::cast<uint64_t>(pair[1]); |
| } |
| } |
| |
| // Signal list. |
| if (py::isinstance<HalFence>(signal_semaphores)) { |
| signal_list = iree_hal_fence_semaphore_list( |
| py::cast<HalFence*>(signal_semaphores)->raw_ptr()); |
| } else { |
| size_t signal_count = py::len(signal_semaphores); |
| signal_list = { |
| signal_count, |
| /*semaphores=*/ |
| static_cast<iree_hal_semaphore_t**>( |
| alloca(sizeof(iree_hal_semaphore_t*) * signal_count)), |
| /*payload_values=*/ |
| static_cast<uint64_t*>(alloca(sizeof(uint64_t) * signal_count)), |
| }; |
| for (size_t i = 0; i < signal_count; ++i) { |
| py::tuple pair = signal_semaphores[i]; |
| signal_list.semaphores[i] = py::cast<HalSemaphore*>(pair[0])->raw_ptr(); |
| signal_list.payload_values[i] = py::cast<uint64_t>(pair[1]); |
| } |
| } |
| |
| // TODO: Accept params for src_offset and target_offset. Just check that |
| // the source will fit in the target buffer for now. |
| iree_device_size_t source_length = |
| iree_hal_buffer_byte_length(source_buffer.raw_ptr()); |
| if (source_length > iree_hal_buffer_byte_length(target_buffer.raw_ptr())) { |
| throw std::invalid_argument( |
| "Source and buffer length must be less than the target buffer length " |
| "and it does not. Please check allocations"); |
| } |
| CheckApiStatus(iree_hal_device_queue_copy( |
| raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, |
| signal_list, source_buffer.raw_ptr(), 0, |
| target_buffer.raw_ptr(), 0, source_length), |
| "Copying buffer on queue"); |
| } |
| |
// Exports |buffer_view| as a DLPack "dltensor" PyCapsule for zero-copy
// interop. The capsule owns an ExtDLManagedTensor that retains the backing
// HAL buffer until the consumer (or the capsule destructor, if never
// consumed) invokes the deleter. Raises on unsupported element types or if
// the device buffer cannot be exported.
py::object HalDevice::CreateDLPackCapsule(HalBufferView& buffer_view,
                                          int device_type_code, int device_id) {
  // Shapes up to this rank are stored inline; larger ranks heap-allocate.
  const size_t kStaticDimLimit = 6;
  struct ExtDLManagedTensor : public DLManagedTensor {
    ~ExtDLManagedTensor() {
      if (retained_buffer) {
        iree_hal_buffer_release(retained_buffer);
      }
      if (dl_tensor.ndim > kStaticDimLimit) {
        delete[] dim_storage.dynamic_shape;
      }
    }
    iree_hal_buffer_t* retained_buffer = nullptr;
    union {
      int64_t static_shape[kStaticDimLimit];
      int64_t* dynamic_shape;
    } dim_storage;
  };
  auto tensor = std::make_unique<ExtDLManagedTensor>();
  memset(static_cast<DLManagedTensor*>(tensor.get()), 0,
         sizeof(DLManagedTensor));
  // Per the DLPack protocol, a consumer renames the capsule to
  // "used_dltensor" when it takes ownership; only an unconsumed capsule
  // still named "dltensor" must be cleaned up here.
  auto capsule_destructor = +[](PyObject* capsule) {
    const char* actual_name = PyCapsule_GetName(capsule);
    if (!actual_name || strcmp(actual_name, "dltensor") != 0) {
      // Caller consumed the capsule. Do nothing.
      return;
    }

    // Capsule was dropped on the floor before consumed. Release resources.
    void* capsule_ptr = PyCapsule_GetPointer(capsule, "dltensor");
    if (!capsule_ptr) {
      return;
    }
    DLManagedTensor* tensor_ptr = static_cast<DLManagedTensor*>(capsule_ptr);
    tensor_ptr->deleter(tensor_ptr);
  };
  auto deleter = +[](struct DLManagedTensor* self) {
    auto* ext_self = static_cast<ExtDLManagedTensor*>(self);
    delete ext_self;
  };

  // Populate the DLManagedTensor.
  tensor->deleter = deleter;
  auto& dl_tensor = tensor->dl_tensor;
  dl_tensor.device.device_type = static_cast<DLDeviceType>(device_type_code);
  dl_tensor.device.device_id = device_id;

  // Convert metadata.
  iree_hal_element_type_t et =
      iree_hal_buffer_view_element_type(buffer_view.raw_ptr());
  dl_tensor.dtype.bits = iree_hal_element_bit_count(et);
  dl_tensor.dtype.lanes = 1;
  // Map the IREE numerical type onto the DLPack dtype code.
  switch (iree_hal_element_numerical_type(et)) {
    case IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE:
      dl_tensor.dtype.code = kDLFloat;
      break;
    case IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED:
      dl_tensor.dtype.code = kDLUInt;
      break;
    case IREE_HAL_NUMERICAL_TYPE_INTEGER:
    case IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED:
      dl_tensor.dtype.code = kDLInt;
      break;
    case IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN:
      dl_tensor.dtype.code = kDLBfloat;
      break;
    case IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX:
      dl_tensor.dtype.code = kDLComplex;
      break;
    case IREE_HAL_NUMERICAL_TYPE_BOOLEAN:
      dl_tensor.dtype.code = kDLBool;
      break;
    default:
      throw std::invalid_argument(
          "dlpack unsupported buffer view element type");
  }

  // Shape.
  // Leave strides nullptr to signify dense row-major.
  auto rank = iree_hal_buffer_view_shape_rank(buffer_view.raw_ptr());
  auto* bv_dims = iree_hal_buffer_view_shape_dims(buffer_view.raw_ptr());
  if (rank > kStaticDimLimit) {
    dl_tensor.shape = new int64_t[rank];
    tensor->dim_storage.dynamic_shape = dl_tensor.shape;
  } else {
    dl_tensor.shape = tensor->dim_storage.static_shape;
  }
  for (size_t i = 0; i < rank; ++i) {
    dl_tensor.shape[i] = bv_dims[i];
  }
  dl_tensor.ndim = rank;

  // Export buffer view.
  // The export targets the underlying allocation; the view's byte offset is
  // carried separately in dl_tensor.byte_offset.
  iree_hal_buffer_t* buffer =
      iree_hal_buffer_view_buffer(buffer_view.raw_ptr());
  auto offset = iree_hal_buffer_byte_offset(buffer);
  buffer = iree_hal_buffer_allocated_buffer(buffer);
  iree_hal_allocator_t* alloc = iree_hal_device_allocator(raw_ptr());
  iree_hal_external_buffer_t external_buffer;
  CheckApiStatus(
      iree_hal_allocator_export_buffer(
          alloc, buffer, IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION,
          IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE, &external_buffer),
      "Cannot export device buffer");
  static_assert(sizeof(dl_tensor.data) >=
                sizeof(external_buffer.handle.device_allocation.ptr));
  dl_tensor.data =
      reinterpret_cast<void*>(external_buffer.handle.device_allocation.ptr);
  dl_tensor.byte_offset = offset;

  // Create and return capsule.
  PyObject* capsule = PyCapsule_New(static_cast<DLManagedTensor*>(tensor.get()),
                                    "dltensor", capsule_destructor);
  if (!capsule) {
    throw py::python_error();
  }

  // Retain the backing buffer view bound to the capsule lifetime.
  // Only release() the unique_ptr once the capsule owns the tensor.
  tensor->retained_buffer = buffer;
  iree_hal_buffer_retain(buffer);
  tensor.release();
  return py::steal<py::object>(capsule);
}
| |
// Imports a DLPack "dltensor" capsule as a HalBufferView on this device.
// Takes ownership of the capsule's DLManagedTensor (renaming the capsule to
// "used_dltensor" per the DLPack protocol); the tensor's deleter is invoked
// either on an early error (via State's destructor) or later through the
// imported buffer's release callback. Only dense row-major tensors with
// byte_offset == 0 and single-lane dtypes are accepted.
HalBufferView HalDevice::FromDLPackCapsule(py::object input_capsule) {
  // RAII guard: runs the tensor deleter unless ownership was transferred to
  // the imported buffer (managed_tensor reset to nullptr below).
  struct State {
    ~State() {
      if (managed_tensor && managed_tensor->deleter) {
        managed_tensor->deleter(managed_tensor);
      }
    }
    py::object capsule;
    void* raw = nullptr;
    DLManagedTensor* managed_tensor = nullptr;
  } state;
  state.capsule = std::move(input_capsule);
  state.raw = PyCapsule_GetPointer(state.capsule.ptr(), "dltensor");
  if (!state.raw) {
    throw py::python_error();
  }
  state.managed_tensor = static_cast<DLManagedTensor*>(state.raw);
  // Takes ownership.
  if (PyCapsule_SetName(state.capsule.ptr(), "used_dltensor")) {
    throw py::python_error();
  }

  DLTensor* dlt = &state.managed_tensor->dl_tensor;

  // Some validation on what we accept.
  if (dlt->dtype.lanes != 1) {
    throw std::invalid_argument("Unsupported dtype lanes != 1");
  }

  // Map the DLPack dtype code onto an IREE element type.
  iree_hal_element_type_t et;
  switch (dlt->dtype.code) {
    case kDLInt:
      et = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED,
                                       dlt->dtype.bits);
      break;
    case kDLUInt:
      et = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED,
                                       dlt->dtype.bits);
      break;
    case kDLFloat:
      et = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE,
                                       dlt->dtype.bits);
      break;
    case kDLBfloat:
      et = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN,
                                       dlt->dtype.bits);
      break;
    case kDLComplex:
      et = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX,
                                       dlt->dtype.bits);
      break;
    case kDLBool:
      et = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_BOOLEAN,
                                       dlt->dtype.bits);
      break;
    default:
      throw std::invalid_argument("Unsupported dlpack dtype code");
  }

  // Verify dense row major strides (for now a requirement).
  if (dlt->strides && dlt->ndim > 0) {
    int64_t stride = 1;
    for (int32_t i = dlt->ndim - 1; i >= 0; --i) {
      auto dim = dlt->shape[i];
      // The stride value for 1 or 0 dims is undefined and dlpack can normalize
      // it, so we skip validation for these.
      // See:
      // https://github.com/pytorch/pytorch/issues/99803#issuecomment-1521214463
      if (dim == 1 || dim == 0) continue;
      if (dlt->strides[i] != stride) {
        throw std::invalid_argument("Unsupported strided tensor");
      }
      stride *= dim;
    }
  }

  // Verify no byte offset. We could technically allow this, but there are all
  // kinds of bugs and caveats listed, and would like to see how it is used.
  if (dlt->byte_offset != 0) {
    throw std::invalid_argument("NYI: dlpack byte_offset != 0");
  }

  // Compute size.
  // byte_size accumulates in bits first and is converted to bytes below.
  auto* dims = static_cast<iree_hal_dim_t*>(
      iree_alloca(sizeof(iree_hal_dim_t) * dlt->ndim));
  iree_device_size_t byte_size = iree_hal_element_bit_count(et);
  if (dlt->ndim > 0) {
    for (int32_t i = 0; i < dlt->ndim; ++i) {
      byte_size *= dlt->shape[i];
      dims[i] = dlt->shape[i];
    }
  }
  if ((byte_size % 8) != 0) {
    throw std::invalid_argument(
        "dlpack tensor does not have a byte aligned size");
  }
  byte_size /= 8;

  // Import the raw device pointer as an externally-owned HAL buffer whose
  // release callback runs the DLPack deleter.
  iree_hal_buffer_t* imported_buffer;
  iree_hal_allocator_t* allocator = iree_hal_device_allocator(raw_ptr());
  iree_hal_buffer_params_t params;
  memset(&params, 0, sizeof(params));
  params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
  params.access = IREE_HAL_MEMORY_ACCESS_ANY;
  params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
  iree_hal_external_buffer_t external_buffer;
  memset(&external_buffer, 0, sizeof(external_buffer));
  external_buffer.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION;
  external_buffer.size = byte_size;
  external_buffer.handle.device_allocation.ptr =
      reinterpret_cast<uint64_t>(dlt->data);
  iree_hal_buffer_release_callback_t release_callback = {
      +[](void* user_data, struct iree_hal_buffer_t* buffer) {
        auto managed_tensor = static_cast<DLManagedTensor*>(user_data);
        if (managed_tensor->deleter) {
          managed_tensor->deleter(managed_tensor);
        }
      },
      state.raw,
  };
  CheckApiStatus(
      iree_hal_allocator_import_buffer(allocator, params, &external_buffer,
                                       release_callback, &imported_buffer),
      "Could not import external device buffer");
  state.managed_tensor = nullptr;  // Ownership transferred.

  // Create Buffer View.
  iree_hal_buffer_view_t* buffer_view;
  iree_status_t status =
      iree_hal_buffer_view_create(imported_buffer, dlt->ndim, dims, et,
                                  IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
                                  iree_allocator_system(), &buffer_view);

  if (!iree_status_is_ok(status)) {
    iree_hal_buffer_release(imported_buffer);
    CheckApiStatus(status, "Failed to create buffer view");
  }

  return HalBufferView::StealFromRawPtr(buffer_view);
}
| |
| //------------------------------------------------------------------------------ |
| // HalDriver |
| //------------------------------------------------------------------------------ |
| |
| std::vector<std::string> HalDriver::Query() { |
| iree_host_size_t driver_info_count = 0; |
| iree_hal_driver_info_t* driver_infos = NULL; |
| CheckApiStatus( |
| iree_hal_driver_registry_enumerate(iree_hal_driver_registry_default(), |
| iree_allocator_system(), |
| &driver_info_count, &driver_infos), |
| "Error enumerating HAL drivers"); |
| std::vector<std::string> driver_names(driver_info_count); |
| for (iree_host_size_t i = 0; i < driver_info_count; ++i) { |
| driver_names[i] = std::string(driver_infos[i].driver_name.data, |
| driver_infos[i].driver_name.size); |
| } |
| iree_allocator_free(iree_allocator_system(), driver_infos); |
| return driver_names; |
| } |
| |
| HalDriver::DeviceUri::DeviceUri(const std::string& device_uri) { |
| iree_string_view_t device_uri_sv{ |
| device_uri.data(), static_cast<iree_host_size_t>(device_uri.size())}; |
| iree_uri_split(device_uri_sv, &driver_name, &device_path, ¶ms_str); |
| } |
| |
| py::object HalDriver::Create(const DeviceUri& device_uri) { |
| iree_hal_driver_t* driver; |
| CheckApiStatus(iree_hal_driver_registry_try_create( |
| iree_hal_driver_registry_default(), device_uri.driver_name, |
| iree_allocator_system(), &driver), |
| "Error creating driver"); |
| |
| py::object driver_obj = py::cast(HalDriver::StealFromRawPtr(driver)); |
| return driver_obj; |
| } |
| |
| py::object HalDriver::Create(const std::string& device_uri) { |
| DeviceUri parsed_uri(device_uri); |
| return HalDriver::Create(parsed_uri); |
| } |
| |
| py::object HalDriver::Create(const std::string& device_uri, |
| py::dict& driver_cache) { |
| // Look up the driver by driver name in the cache, and return it if found. |
| DeviceUri parsed_uri(device_uri); |
| py::str cache_key(parsed_uri.driver_name.data, parsed_uri.driver_name.size); |
| py::object cached = driver_cache.attr("get")(cache_key); |
| if (!cached.is_none()) { |
| return cached; |
| } |
| |
| // Create a new driver and put it in the cache. |
| py::object driver_obj = HalDriver::Create(parsed_uri); |
| driver_cache[cache_key] = driver_obj; |
| return driver_obj; |
| } |
| |
| py::list HalDriver::QueryAvailableDevices() { |
| iree_hal_device_info_t* device_infos; |
| iree_host_size_t count; |
| CheckApiStatus(iree_hal_driver_query_available_devices( |
| raw_ptr(), iree_allocator_system(), &count, &device_infos), |
| "Error querying devices"); |
| py::list results; |
| for (iree_host_size_t i = 0; i < count; ++i) { |
| py::dict device_data; |
| device_data["device_id"] = py::cast(device_infos[i].device_id); |
| device_data["path"] = |
| py::str(device_infos[i].path.data, device_infos[i].path.size); |
| device_data["name"] = |
| py::str(device_infos[i].name.data, device_infos[i].name.size); |
| results.append(device_data); |
| } |
| |
| iree_allocator_free(iree_allocator_system(), device_infos); |
| return results; |
| } |
| |
| // Configures |device| based on flags before returning it to the user. |
| static iree_status_t ConfigureDevice(iree_hal_device_t* device, |
| std::optional<py::list> allocators) { |
| // Optionally wrap the base device allocator with caching/pooling. |
| // Doing this here satisfies the requirement that no buffers have been |
| // allocated yet - if we returned the device without doing this the caller |
| // can more easily break the rules. |
| if (allocators) { |
| // NOTE: we need to pass string views that point to the std::string storage. |
| // We do that in two passes because as we grow spec_storage it may |
| // reallocate itself and invalidate the pointers - only after we're done |
| // can we capture them in views. |
| auto& spec_list = *allocators; |
| std::vector<std::string> spec_storage; |
| spec_storage.reserve(spec_list.size()); |
| for (auto item : spec_list) { |
| auto spec = py::cast<std::string>(item); |
| spec_storage.push_back(std::move(spec)); |
| } |
| std::vector<iree_string_view_t> spec_views; |
| spec_views.reserve(spec_list.size()); |
| for (const auto& spec : spec_storage) { |
| spec_views.push_back(iree_make_string_view(spec.data(), spec.size())); |
| } |
| IREE_RETURN_IF_ERROR(iree_hal_configure_allocator_from_specs( |
| spec_views.size(), spec_views.data(), device)); |
| } |
| |
| IREE_RETURN_IF_ERROR(iree_hal_device_set_default_channel_provider(device)); |
| |
| return iree_ok_status(); |
| } |
| |
| HalDevice HalDriver::CreateDefaultDevice(std::optional<py::list> allocators) { |
| iree_hal_device_t* device; |
| CheckApiStatus(iree_hal_driver_create_default_device( |
| raw_ptr(), iree_allocator_system(), &device), |
| "Error creating default device"); |
| CheckApiStatus(ConfigureDevice(device, allocators), |
| "Error configuring the device"); |
| return HalDevice::StealFromRawPtr(device); |
| } |
| |
| HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id, |
| std::optional<py::list> allocators) { |
| // Since the device ids are supposed to be opaque, we need to verify |
| // them by querying available devices. |
| py::list available_devices = QueryAvailableDevices(); |
| bool found = false; |
| py::object compare_device_id = py::cast(device_id); |
| for (auto record : available_devices) { |
| // Each record is a dict: |
| // {"device_id": obj, "path": str, "name": str}. |
| auto record_dict = py::cast<py::dict>(record); |
| py::object found_device_id = record_dict["device_id"]; |
| if (found_device_id.equal(compare_device_id)) { |
| found = true; |
| break; |
| } |
| } |
| |
| if (!found) { |
| std::string msg; |
| msg.append("Device id "); |
| msg.append(std::to_string(device_id)); |
| msg.append(" not found. Available devices: "); |
| msg.append(py::cast<std::string>(py::repr(available_devices))); |
| throw std::invalid_argument(std::move(msg)); |
| } |
| |
| std::vector<iree_string_pair_t> params; |
| iree_hal_device_t* device; |
| CheckApiStatus(iree_hal_driver_create_device_by_id( |
| raw_ptr(), device_id, params.size(), |
| (params.empty() ? nullptr : ¶ms.front()), |
| iree_allocator_system(), &device), |
| "Error creating default device"); |
| CheckApiStatus(ConfigureDevice(device, allocators), |
| "Error configuring the device"); |
| return HalDevice::StealFromRawPtr(device); |
| } |
| |
| HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri, |
| std::optional<py::list> allocators) { |
| iree_hal_device_t* device; |
| iree_string_view_t device_uri_sv{ |
| device_uri.data(), static_cast<iree_host_size_t>(device_uri.size())}; |
| CheckApiStatus( |
| iree_hal_driver_create_device_by_uri(raw_ptr(), device_uri_sv, |
| iree_allocator_system(), &device), |
| "Error creating device"); |
| CheckApiStatus(ConfigureDevice(device, allocators), |
| "Error configuring the device"); |
| return HalDevice::StealFromRawPtr(device); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // HAL module |
| //------------------------------------------------------------------------------ |
| |
| VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device, |
| std::optional<py::list> devices) { |
| if (device && devices) { |
| PyErr_SetString( |
| PyExc_ValueError, |
| "\"device\" and \"devices\" are mutually exclusive arguments."); |
| } |
| std::vector<iree_hal_device_t*> devices_vector; |
| iree_hal_device_t* device_ptr; |
| iree_hal_device_t** devices_ptr; |
| iree_host_size_t device_count; |
| iree_vm_module_t* module = NULL; |
| if (device) { |
| device_ptr = device.value()->raw_ptr(); |
| devices_ptr = &device_ptr; |
| device_count = 1; |
| } else { |
| // Set device related arguments in the case of multiple devices. |
| devices_vector.reserve(devices->size()); |
| for (auto devicesIt = devices->begin(); devicesIt != devices->end(); |
| ++devicesIt) { |
| devices_vector.push_back(py::cast<HalDevice*>(*devicesIt)->raw_ptr()); |
| } |
| devices_ptr = devices_vector.data(); |
| device_count = devices_vector.size(); |
| } |
| CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count, |
| devices_ptr, IREE_HAL_MODULE_FLAG_NONE, |
| iree_allocator_system(), &module), |
| "Error creating hal module"); |
| return VmModule::StealFromRawPtr(module); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // Bindings |
| //------------------------------------------------------------------------------ |
| |
| void SetupHalBindings(nanobind::module_ m) { |
| py::dict driver_cache; |
| |
| // Built-in module creation. |
| m.def("create_hal_module", &CreateHalModule, py::arg("instance"), |
| py::arg("device") = py::none(), py::arg("devices") = py::none()); |
| |
| // Enums. |
| py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType") |
| .value("NONE", IREE_HAL_MEMORY_TYPE_NONE) |
| .value("OPTIMAL", IREE_HAL_MEMORY_TYPE_OPTIMAL) |
| .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) |
| .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT) |
| .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED) |
| .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL) |
| .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE) |
| .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) |
| .export_values() |
| .def("__or__", [](uint64_t self, uint64_t other) { return self | other; }) |
| .def("__and__", |
| [](uint64_t self, uint64_t other) { return self & other; }) |
| .def("__int__", [](enum iree_hal_memory_type_bits_t self) { |
| return (uint64_t)self; |
| }); |
| |
| py::enum_<enum iree_hal_buffer_compatibility_bits_t>(m, "BufferCompatibility") |
| .value("NONE", IREE_HAL_BUFFER_COMPATIBILITY_NONE) |
| .value("ALLOCATABLE", IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE) |
| .value("IMPORTABLE", IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE) |
| .value("EXPORTABLE", IREE_HAL_BUFFER_COMPATIBILITY_EXPORTABLE) |
| .value("QUEUE_TRANSFER", IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER) |
| .value("QUEUE_DISPATCH", IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH) |
| .export_values() |
| .def("__or__", [](uint64_t self, uint64_t other) { return self | other; }) |
| .def("__and__", |
| [](uint64_t self, uint64_t other) { return self & other; }) |
| .def("__int__", [](enum iree_hal_buffer_compatibility_bits_t self) { |
| return (uint64_t)self; |
| }); |
| |
| py::enum_<enum iree_hal_buffer_usage_bits_t>(m, "BufferUsage") |
| .value("NONE", IREE_HAL_BUFFER_USAGE_NONE) |
| .value("TRANSFER_SOURCE", IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE) |
| .value("TRANSFER_TARGET", IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET) |
| .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER) |
| .value("DISPATCH_INDIRECT_PARAMS", |
| IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS) |
| .value("DISPATCH_UNIFORM_READ", |
| IREE_HAL_BUFFER_USAGE_DISPATCH_UNIFORM_READ) |
| .value("DISPATCH_STORAGE_READ", |
| IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE_READ) |
| .value("DISPATCH_STORAGE_WRITE", |
| IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE_WRITE) |
| .value("DISPATCH_STORAGE", IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE) |
| .value("DISPATCH_IMAGE_READ", IREE_HAL_BUFFER_USAGE_DISPATCH_IMAGE_READ) |
| .value("DISPATCH_IMAGE_WRITE", IREE_HAL_BUFFER_USAGE_DISPATCH_IMAGE_WRITE) |
| .value("DISPATCH_IMAGE", IREE_HAL_BUFFER_USAGE_DISPATCH_IMAGE) |
| .value("SHARING_EXPORT", IREE_HAL_BUFFER_USAGE_SHARING_EXPORT) |
| .value("SHARING_REPLICATE", IREE_HAL_BUFFER_USAGE_SHARING_REPLICATE) |
| .value("SHARING_CONCURRENT", IREE_HAL_BUFFER_USAGE_SHARING_CONCURRENT) |
| .value("SHARING_IMMUTABLE", IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE) |
| .value("MAPPING_SCOPED", IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED) |
| .value("MAPPING_PERSISTENT", IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT) |
| .value("MAPPING_OPTIONAL", IREE_HAL_BUFFER_USAGE_MAPPING_OPTIONAL) |
| .value("MAPPING_ACCESS_RANDOM", |
| IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM) |
| .value("MAPPING_ACCESS_SEQUENTIAL_WRITE", |
| IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE) |
| .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING) |
| .value("DEFAULT", IREE_HAL_BUFFER_USAGE_DEFAULT) |
| .export_values() |
| .def("__or__", [](enum iree_hal_buffer_usage_bits_t self, |
| uint64_t other) { return self | other; }) |
| .def("__and__", [](enum iree_hal_buffer_usage_bits_t self, |
| uint64_t other) { return self & other; }) |
| .def("__int__", [](enum iree_hal_buffer_usage_bits_t self) { |
| return (uint64_t)self; |
| }); |
| |
| py::enum_<enum iree_hal_memory_access_bits_t>(m, "MemoryAccess") |
| .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE) |
| .value("READ", IREE_HAL_MEMORY_ACCESS_READ) |
| .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE) |
| .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD) |
| .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) |
| .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL) |
| .export_values() |
| .def("__or__", [](uint64_t self, uint64_t other) { return self | other; }) |
| .def("__and__", |
| [](uint64_t self, uint64_t other) { return self & other; }) |
| .def("__int__", [](enum iree_hal_memory_access_bits_t self) { |
| return (uint64_t)self; |
| }); |
| |
| // Use compatibility type to enable def_static. |
| // See: https://github.com/wjakob/nanobind/issues/597 |
| auto hal_element_type = nanobind1_compat_enum_<enum iree_hal_element_types_t>( |
| m, "HalElementType"); |
| hal_element_type |
| .def_static("map_to_dtype", |
| [](iree_hal_element_type_t element_type) { |
| int typenum = numpy::ConvertHalElementTypeToNumPyTypeNum( |
| element_type); |
| return numpy::DescrNewFromType(typenum); |
| }) |
| .def_static("is_byte_aligned", |
| [](iree_hal_element_type_t element_type) { |
| return iree_hal_element_is_byte_aligned(element_type); |
| }) |
| .def_static("dense_byte_count", [](iree_hal_element_type_t element_type) { |
| return iree_hal_element_dense_byte_count(element_type); |
| }); |
| hal_element_type.value("NONE", IREE_HAL_ELEMENT_TYPE_NONE) |
| .value("OPAQUE_8", IREE_HAL_ELEMENT_TYPE_OPAQUE_8) |
| .value("OPAQUE_16", IREE_HAL_ELEMENT_TYPE_OPAQUE_16) |
| .value("OPAQUE_32", IREE_HAL_ELEMENT_TYPE_OPAQUE_32) |
| .value("OPAQUE_64", IREE_HAL_ELEMENT_TYPE_OPAQUE_64) |
| .value("BOOL_8", IREE_HAL_ELEMENT_TYPE_BOOL_8) |
| .value("INT_4", IREE_HAL_ELEMENT_TYPE_INT_4) |
| .value("INT_8", IREE_HAL_ELEMENT_TYPE_INT_8) |
| .value("INT_16", IREE_HAL_ELEMENT_TYPE_INT_16) |
| .value("INT_32", IREE_HAL_ELEMENT_TYPE_INT_32) |
| .value("INT_64", IREE_HAL_ELEMENT_TYPE_INT_64) |
| .value("SINT_4", IREE_HAL_ELEMENT_TYPE_SINT_4) |
| .value("SINT_8", IREE_HAL_ELEMENT_TYPE_SINT_8) |
| .value("SINT_16", IREE_HAL_ELEMENT_TYPE_SINT_16) |
| .value("SINT_32", IREE_HAL_ELEMENT_TYPE_SINT_32) |
| .value("SINT_64", IREE_HAL_ELEMENT_TYPE_SINT_64) |
| .value("UINT_4", IREE_HAL_ELEMENT_TYPE_UINT_4) |
| .value("UINT_8", IREE_HAL_ELEMENT_TYPE_UINT_8) |
| .value("UINT_16", IREE_HAL_ELEMENT_TYPE_UINT_16) |
| .value("UINT_32", IREE_HAL_ELEMENT_TYPE_UINT_32) |
| .value("UINT_64", IREE_HAL_ELEMENT_TYPE_UINT_64) |
| .value("FLOAT_16", IREE_HAL_ELEMENT_TYPE_FLOAT_16) |
| .value("FLOAT_32", IREE_HAL_ELEMENT_TYPE_FLOAT_32) |
| .value("FLOAT_64", IREE_HAL_ELEMENT_TYPE_FLOAT_64) |
| .value("BFLOAT_16", IREE_HAL_ELEMENT_TYPE_BFLOAT_16) |
| .value("COMPLEX_64", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64) |
| .value("COMPLEX_128", IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128) |
| .export_values() |
| .def("__int__", |
| [](enum iree_hal_element_types_t self) { return (uint64_t)self; }); |
| |
| py::class_<HalDevice>(m, "HalDevice") |
| .def_prop_ro( |
| "allocator", |
| [](HalDevice& self) { |
| return HalAllocator::BorrowFromRawPtr(self.allocator()); |
| }, |
| py::keep_alive<0, 1>()) |
| .def("begin_profiling", &HalDevice::BeginProfiling, |
| py::arg("mode") = py::none(), py::arg("file_path") = py::none()) |
| .def("flush_profiling", &HalDevice::FlushProfiling) |
| .def("end_profiling", &HalDevice::EndProfiling) |
| .def("create_semaphore", &HalDevice::CreateSemaphore, |
| py::arg("initial_value")) |
| .def("queue_alloca", &HalDevice::QueueAlloca, py::arg("allocation_size"), |
| py::arg("wait_semaphores"), py::arg("signal_semaphores"), |
| kHalDeviceQueueAlloca) |
| .def("queue_dealloca", &HalDevice::QueueDealloca, py::arg("buffer"), |
| py::arg("wait_semaphores"), py::arg("signal_semaphores"), |
| kHalDeviceQueueDealloca) |
| .def("queue_execute", &HalDevice::QueueExecute, |
| py::arg("command_buffers"), py::arg("wait_semaphores"), |
| py::arg("signal_semaphores"), kHalDeviceQueueExecute) |
| .def("queue_copy", &HalDevice::QueueCopy, py::arg("source_buffer"), |
| py::arg("target_buffer"), py::arg("wait_semaphores"), |
| py::arg("signal_semaphores"), kHalDeviceQueueCopy) |
| .def("create_dlpack_capsule", &HalDevice::CreateDLPackCapsule, |
| py::arg("buffer_view"), py::arg("device_type_code"), |
| py::arg("device_id")) |
| .def("from_dlpack_capsule", &HalDevice::FromDLPackCapsule) |
| .def("__repr__", [](HalDevice& self) { |
| auto id_sv = iree_hal_device_id(self.raw_ptr()); |
| return std::string(id_sv.data, id_sv.size); |
| }); |
| |
| py::class_<HalDriver>(m, "HalDriver") |
| .def_static("query", &HalDriver::Query) |
| |
| // All 'create_device' functions take optional kwargs that should be kept |
| // in sync. |
| .def("create_default_device", &HalDriver::CreateDefaultDevice, |
| py::keep_alive<0, 1>(), py::arg("allocators") = py::none()) |
| .def("create_device", &HalDriver::CreateDevice, py::keep_alive<0, 1>(), |
| py::arg("device_id"), py::arg("allocators") = py::none()) |
| .def("create_device_by_uri", &HalDriver::CreateDeviceByURI, |
| py::keep_alive<0, 1>(), py::arg("device_uri"), |
| py::arg("allocators") = py::none()) |
| .def( |
| "create_device", |
| [](HalDriver& self, py::dict device_info, |
| std::optional<py::list> allocators) -> HalDevice { |
| // Alias of create_device that takes a dict as returned from |
| // query_available_devices for convenience. |
| auto device_id = |
| py::cast<iree_hal_device_id_t>(device_info["device_id"]); |
| return self.CreateDevice(device_id, allocators); |
| }, |
| py::keep_alive<0, 1>(), py::arg("device_info"), |
| py::arg("allocators") = py::none()) |
| .def("query_available_devices", &HalDriver::QueryAvailableDevices) |
| .def("dump_device_info", |
| [](HalDriver& self, iree_hal_device_id_t device_id) { |
| iree_string_builder_t builder; |
| iree_string_builder_initialize(iree_allocator_system(), &builder); |
| CheckApiStatus(iree_hal_driver_dump_device_info( |
| self.raw_ptr(), device_id, &builder), |
| "Querying device info"); |
| iree_string_view_t view = iree_string_builder_view(&builder); |
| py::str result(view.data, view.size); |
| iree_string_builder_deinitialize(&builder); |
| return result; |
| }); |
| |
| m.def( |
| "get_cached_hal_driver", |
| [driver_cache](std::string device_uri) { |
| return HalDriver::Create(device_uri, |
| const_cast<py::dict&>(driver_cache)); |
| }, |
| py::arg("device_uri")); |
| |
| m.def( |
| "create_hal_driver", |
| [](std::string device_uri) { return HalDriver::Create(device_uri); }, |
| py::arg("device_uri")); |
| |
| m.def("clear_hal_driver_cache", |
| [driver_cache]() { const_cast<py::dict&>(driver_cache).clear(); }); |
| |
| py::class_<HalAllocator>(m, "HalAllocator") |
| .def("trim", |
| [](HalAllocator& self) { |
| CheckApiStatus(iree_hal_allocator_trim(self.raw_ptr()), |
| "Error trim()'ing HAL allocator"); |
| }) |
| .def_prop_ro( |
| "has_statistics", |
| [](HalAllocator& self) -> bool { return IREE_STATISTICS_ENABLE; }) |
| .def_prop_ro("statistics", &HalAllocator::QueryStatistics) |
| .def_prop_ro("formatted_statistics", &HalAllocator::FormattedStatistics) |
| .def( |
| "query_buffer_compatibility", |
| [](HalAllocator& self, int memory_type, int allowed_usage, |
| int intended_usage, iree_device_size_t allocation_size) -> int { |
| iree_hal_buffer_params_t params = {0}; |
| params.type = memory_type; |
| params.usage = allowed_usage & intended_usage; |
| return iree_hal_allocator_query_buffer_compatibility( |
| self.raw_ptr(), params, allocation_size, |
| /*out_params=*/nullptr, /*out_allocation_size=*/0); |
| }, |
| py::arg("memory_type"), py::arg("allowed_usage"), |
| py::arg("intended_usage"), py::arg("allocation_size")) |
| .def( |
| "allocate_buffer", |
| [](HalAllocator& self, int memory_type, int allowed_usage, |
| iree_device_size_t allocation_size) { |
| iree_hal_buffer_params_t params = {0}; |
| params.type = memory_type; |
| params.usage = allowed_usage; |
| iree_hal_buffer_t* buffer = nullptr; |
| CheckApiStatus( |
| iree_hal_allocator_allocate_buffer(self.raw_ptr(), params, |
| allocation_size, &buffer), |
| "could not allocate buffer"); |
| return HalBuffer::StealFromRawPtr(buffer); |
| }, |
| py::arg("memory_type"), py::arg("allowed_usage"), |
| py::arg("allocation_size"), py::keep_alive<0, 1>(), |
| "Allocates a new buffer with requested characteristics (does not " |
| "initialize with specific data).") |
| .def("allocate_buffer_copy", &HalAllocator::AllocateBufferCopy, |
| py::arg("memory_type"), py::arg("allowed_usage"), py::arg("device"), |
| py::arg("buffer"), py::arg("element_type") = py::none(), |
| py::keep_alive<0, 1>(), |
| "Allocates a new buffer and initializes it from a Python buffer " |
| "object. If an element type is specified, wraps in a BufferView " |
| "matching the characteristics of the Python buffer. The format is " |
| "requested as ND/C-Contiguous, which may incur copies if not " |
| "already in that format.") |
| .def("allocate_host_staging_buffer_copy", |
| &HalAllocator::AllocateHostStagingBufferCopy, py::arg("device"), |
| py::arg("initial_contents"), py::keep_alive<0, 1>(), |
| "Allocates a new buffer and initializes it from a Python buffer " |
| "object. The buffer is configured as optimal for use on the device " |
| "as a transfer buffer. For buffers of unknown providence, this is a " |
| "last resort method for making them compatible for transfer to " |
| "arbitrary devices."); |
| |
| auto hal_buffer = py::class_<HalBuffer>(m, "HalBuffer"); |
| VmRef::BindRefProtocol(hal_buffer, iree_hal_buffer_type, |
| iree_hal_buffer_retain_ref, iree_hal_buffer_deref, |
| iree_hal_buffer_isa); |
| hal_buffer |
| .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), |
| py::arg("byte_length")) |
| .def("byte_length", &HalBuffer::byte_length) |
| .def("memory_type", &HalBuffer::memory_type) |
| .def("allowed_usage", &HalBuffer::allowed_usage) |
| .def("create_view", &HalBuffer::CreateView, py::arg("shape"), |
| py::arg("element_size"), py::keep_alive<0, 1>()) |
| .def("map", HalMappedMemory::CreateFromBuffer, py::keep_alive<0, 1>()) |
| .def("__repr__", &HalBuffer::Repr); |
| |
| auto hal_buffer_view = py::class_<HalBufferView>(m, "HalBufferView"); |
| VmRef::BindRefProtocol(hal_buffer_view, iree_hal_buffer_view_type, |
| iree_hal_buffer_view_retain_ref, |
| iree_hal_buffer_view_deref, iree_hal_buffer_view_isa); |
| hal_buffer_view.def( |
| "__init__", |
| [](HalBufferView* new_self, HalBuffer& buffer, py::handle shape, |
| iree_hal_element_type_t element_type) { |
| size_t rank = py::len(shape); |
| iree_hal_dim_t* dims = |
| static_cast<iree_hal_dim_t*>(alloca(sizeof(iree_hal_dim_t) * rank)); |
| for (size_t i = 0; i < rank; ++i) { |
| dims[i] = py::cast<iree_hal_dim_t>(shape[i]); |
| } |
| iree_hal_buffer_view_t* out_bv; |
| CheckApiStatus(iree_hal_buffer_view_create( |
| buffer.raw_ptr(), rank, dims, element_type, |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, |
| iree_allocator_system(), &out_bv), |
| "creating buffer view"); |
| new (new_self) HalBufferView(); |
| *new_self = HalBufferView::StealFromRawPtr(out_bv); |
| }, |
| py::arg("buffer"), py::arg("shape"), py::arg("element_type")); |
| hal_buffer_view |
| .def("map", HalMappedMemory::CreateFromBufferView, py::keep_alive<0, 1>()) |
| .def("get_buffer", HalBuffer::CreateFromBufferView, |
| py::keep_alive<0, 1>()) |
| .def_prop_ro("shape", |
| [](HalBufferView& self) { |
| iree_host_size_t rank = |
| iree_hal_buffer_view_shape_rank(self.raw_ptr()); |
| auto* dims = |
| iree_hal_buffer_view_shape_dims(self.raw_ptr()); |
| py::list result; |
| for (iree_host_size_t i = 0; i < rank; ++i) { |
| result.append(dims[i]); |
| } |
| return result; |
| }) |
| .def_prop_ro("element_type", |
| [](HalBufferView& self) { |
| return iree_hal_buffer_view_element_type(self.raw_ptr()); |
| }) |
| .def_prop_ro("byte_length", |
| [](HalBufferView& self) { |
| return iree_hal_buffer_view_byte_length(self.raw_ptr()); |
| }) |
| .def("__repr__", &HalBufferView::Repr); |
| |
| py::class_<HalSemaphore>(m, "HalSemaphore") |
| .def( |
| "fail", |
| [](HalSemaphore& self, std::string& message) { |
| // TODO: Take some category enum and use that is available. |
| iree_status_t status = |
| iree_make_status(IREE_STATUS_UNKNOWN, "%s", message.c_str()); |
| iree_hal_semaphore_fail(self.raw_ptr(), status); |
| }, |
| py::arg("message")) |
| .def("query", |
| [](HalSemaphore& self) { |
| uint64_t out_value; |
| CheckApiStatus( |
| iree_hal_semaphore_query(self.raw_ptr(), &out_value), |
| "querying semaphore"); |
| return out_value; |
| }) |
| .def("signal", |
| [](HalSemaphore& self, uint64_t new_value) { |
| CheckApiStatus( |
| iree_hal_semaphore_signal(self.raw_ptr(), new_value), |
| "signaling semaphore"); |
| }) |
| .def( |
| "wait", |
| [](HalSemaphore& self, uint64_t payload, |
| std::optional<iree_duration_t> timeout, |
| std::optional<iree_time_t> deadline) -> bool { |
| iree_timeout_t t = NormalizeTimeout(timeout, deadline); |
| iree_status_t status; |
| uint64_t unused_value; |
| { |
| py::gil_scoped_release release; |
| status = iree_hal_semaphore_wait(self.raw_ptr(), payload, t); |
| } |
| if (iree_status_is_deadline_exceeded(status)) { |
| // Time out. |
| return false; |
| } else if (iree_status_is_aborted(status)) { |
| // Synchronous failure. |
| iree_status_ignore(status); |
| status = iree_hal_semaphore_query(self.raw_ptr(), &unused_value); |
| if (iree_status_is_ok(status)) { |
| status = iree_make_status( |
| IREE_STATUS_FAILED_PRECONDITION, |
| "expected synchronous status failure missing"); |
| } |
| CheckApiStatus(status, "synchronous semaphore failure"); |
| } else { |
| // General failure check. |
| CheckApiStatus(status, "waiting for semaphore"); |
| } |
| |
| // Asynchronous failure. |
| status = iree_hal_semaphore_query(self.raw_ptr(), &unused_value); |
| if (iree_status_is_deferred(status)) { |
| return false; |
| } |
| CheckApiStatus(status, "asynchronous semaphore failure"); |
| return true; |
| }, |
| py::arg("payload"), py::arg("timeout") = py::none(), |
| py::arg("deadline") = py::none(), kHalWait); |
| |
| auto hal_fence = py::class_<HalFence>(m, "HalFence"); |
| VmRef::BindRefProtocol(hal_fence, iree_hal_fence_type, |
| iree_hal_fence_retain_ref, iree_hal_fence_deref, |
| iree_hal_fence_isa); |
| hal_fence |
| .def( |
| "__init__", |
| [](HalFence* new_fence, iree_host_size_t capacity) { |
| iree_hal_fence_t* out_fence; |
| CheckApiStatus(iree_hal_fence_create( |
| capacity, iree_allocator_system(), &out_fence), |
| "creating fence"); |
| new (new_fence) HalFence(); |
| (*new_fence) = HalFence::StealFromRawPtr(out_fence); |
| }, |
| py::arg("capacity")) |
| .def_static( |
| "create_at", |
| [](HalSemaphore& sem, uint64_t value) { |
| iree_hal_fence_t* out_fence; |
| CheckApiStatus( |
| iree_hal_fence_create_at(sem.raw_ptr(), value, |
| iree_allocator_system(), &out_fence), |
| "creating fence"); |
| return HalFence::StealFromRawPtr(out_fence); |
| }, |
| py::arg("sem"), py::arg("value")) |
| .def_static( |
| "join", |
| [](py::sequence fences) { |
| size_t count = py::len(fences); |
| iree_hal_fence_t** fence_ptrs = static_cast<iree_hal_fence_t**>( |
| alloca(sizeof(iree_hal_fence_t*) * count)); |
| for (size_t i = 0; i < count; ++i) { |
| fence_ptrs[i] = py::cast<HalFence*>(fences[i])->raw_ptr(); |
| } |
| iree_hal_fence_t* out_fence; |
| CheckApiStatus( |
| iree_hal_fence_join(count, fence_ptrs, iree_allocator_system(), |
| &out_fence), |
| "joining fences"); |
| return HalFence::StealFromRawPtr(out_fence); |
| }, |
| py::arg("fences")) |
| .def_prop_ro("timepoint_count", |
| [](HalFence& self) { |
| return iree_hal_fence_timepoint_count(self.raw_ptr()); |
| }) |
| .def( |
| "insert", |
| [](HalFence& self, HalSemaphore& sem, uint64_t value) { |
| CheckApiStatus( |
| iree_hal_fence_insert(self.raw_ptr(), sem.raw_ptr(), value), |
| "insertint into fence"); |
| }, |
| py::arg("sem"), py::arg("value")) |
| .def( |
| "extend", |
| [](HalFence& self, HalFence& from_fence) { |
| CheckApiStatus( |
| iree_hal_fence_extend(self.raw_ptr(), from_fence.raw_ptr()), |
| "extending fence"); |
| }, |
| py::arg("from_fence")) |
| .def( |
| "fail", |
| [](HalFence& self, std::string& message) { |
| // TODO: Take some category enum and use that is available. |
| iree_status_t status = |
| iree_make_status(IREE_STATUS_UNKNOWN, "%s", message.c_str()); |
| iree_hal_fence_fail(self.raw_ptr(), status); |
| }, |
| py::arg("message")) |
| .def("signal", |
| [](HalFence& self) { |
| CheckApiStatus(iree_hal_fence_signal(self.raw_ptr()), |
| "signalling fence"); |
| }) |
| .def( |
| "wait", |
| [](HalFence& self, std::optional<iree_duration_t> timeout, |
| std::optional<iree_time_t> deadline) -> bool { |
| iree_timeout_t t = NormalizeTimeout(timeout, deadline); |
| iree_status_t status; |
| { |
| py::gil_scoped_release release; |
| status = iree_hal_fence_wait(self.raw_ptr(), t); |
| } |
| if (iree_status_is_deadline_exceeded(status)) { |
| // Time out. |
| return false; |
| } else if (iree_status_is_aborted(status)) { |
| // Synchronous failure. |
| iree_status_ignore(status); |
| status = iree_hal_fence_query(self.raw_ptr()); |
| if (iree_status_is_ok(status)) { |
| status = iree_make_status( |
| IREE_STATUS_FAILED_PRECONDITION, |
| "expected synchronous status failure missing"); |
| } |
| CheckApiStatus(status, "synchronous fence failure"); |
| } else { |
| // General failure check. |
| CheckApiStatus(status, "waiting for fence"); |
| } |
| |
| // Asynchronous failure. |
| status = iree_hal_fence_query(self.raw_ptr()); |
| if (iree_status_is_deferred(status)) { |
| return false; |
| } |
| CheckApiStatus(status, "asynchronous fence failure"); |
| return true; |
| }, |
| py::arg("timeout") = py::none(), py::arg("deadline") = py::none(), |
| kHalWait); |
| |
| py::class_<HalMappedMemory>(m, "MappedMemory") |
| .def( |
| "asarray", |
| [](HalMappedMemory* self, py::handle shape, py::object dtype_descr) { |
| py::object py_mapped_memory = py::cast(self); |
| size_t rank = py::len(shape); |
| intptr_t* dims = |
| static_cast<intptr_t*>(alloca(sizeof(intptr_t) * rank)); |
| for (size_t i = 0; i < rank; ++i) { |
| dims[i] = py::cast<intptr_t>(shape[i]); |
| } |
| int typenum = numpy::TypenumFromDescr(dtype_descr); |
| return numpy::SimpleNewFromData(rank, dims, typenum, |
| self->mapped_memory().contents.data, |
| py_mapped_memory); |
| }, |
| py::arg("shape"), py::arg("numpy_dtype_descr")); |
| |
| py::class_<HalShape>(m, "Shape") |
| .def("__init__", [](HalShape* self, std::vector<iree_hal_dim_t> indices) { |
| new (self) HalShape(indices); |
| }); |
| |
| py::class_<HalCommandBuffer>(m, "HalCommandBuffer") |
| .def( |
| "__init__", |
| [](HalCommandBuffer* new_self, HalDevice& device, |
| iree_host_size_t binding_capacity, bool begin) { |
| iree_hal_command_buffer_t* out_cb; |
| CheckApiStatus(iree_hal_command_buffer_create( |
| device.raw_ptr(), |
| /*mode=*/IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, |
| /*categories=*/IREE_HAL_COMMAND_CATEGORY_ANY, |
| /*queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY, |
| binding_capacity, &out_cb), |
| "creating command buffer"); |
| HalCommandBuffer cb = HalCommandBuffer::StealFromRawPtr(out_cb); |
| if (begin) { |
| CheckApiStatus(iree_hal_command_buffer_begin(cb.raw_ptr()), |
| "command buffer begin"); |
| } |
| new (new_self) HalCommandBuffer(); |
| *new_self = std::move(cb); |
| }, |
| py::arg("device"), py::arg("binding_capacity") = 0, |
| py::arg("begin") = true) |
| .def("begin", |
| [](HalCommandBuffer& self) { |
| CheckApiStatus(iree_hal_command_buffer_begin(self.raw_ptr()), |
| "command buffer begin"); |
| }) |
| .def("end", |
| [](HalCommandBuffer& self) { |
| CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()), |
| "command buffer end"); |
| }) |
| .def( |
| "copy", |
| [](HalCommandBuffer& self, HalBuffer& source_buffer, |
| HalBuffer& target_buffer, iree_device_size_t source_offset, |
| iree_device_size_t target_offset, |
| std::optional<iree_device_size_t> length, bool end) { |
| iree_device_size_t resolved_length; |
| if (length) { |
| resolved_length = *length; |
| } else { |
| resolved_length = |
| iree_hal_buffer_byte_length(source_buffer.raw_ptr()); |
| if (resolved_length != |
| iree_hal_buffer_byte_length(target_buffer.raw_ptr())) { |
| throw std::invalid_argument( |
| "If length is not provided, source and target bufer length " |
| "must match and it does not. Provide explicit length="); |
| } |
| } |
| CheckApiStatus( |
| iree_hal_command_buffer_copy_buffer( |
| self.raw_ptr(), |
| iree_hal_make_buffer_ref(source_buffer.raw_ptr(), |
| source_offset, resolved_length), |
| iree_hal_make_buffer_ref(target_buffer.raw_ptr(), |
| target_offset, resolved_length)), |
| "copy command"); |
| if (end) { |
| CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()), |
| "command buffer end"); |
| } |
| }, |
| py::arg("source_buffer"), py::arg("target_buffer"), |
| py::arg("source_offset") = 0, py::arg("target_offset") = 0, |
| py::arg("length") = py::none(), py::arg("end") = false, |
| "Copies a range from a source to target buffer. If the length is " |
| "not specified, then it is taken from the source/target buffer, " |
| "which must match.") |
| .def( |
| "fill", |
| [](HalCommandBuffer& self, HalBuffer& target_buffer, |
| py::handle pattern, iree_device_size_t target_offset, |
| std::optional<iree_device_size_t> length, bool end) { |
| Py_buffer pattern_view; |
| int flags = PyBUF_FORMAT | PyBUF_ND; |
| if (PyObject_GetBuffer(pattern.ptr(), &pattern_view, flags) != 0) { |
| // The GetBuffer call is required to set an appropriate error. |
| throw py::python_error(); |
| } |
| PyBufferReleaser py_pattern_releaser(pattern_view); |
| |
| iree_device_size_t resolved_length; |
| if (length) { |
| resolved_length = *length; |
| } else { |
| resolved_length = |
| iree_hal_buffer_byte_length(target_buffer.raw_ptr()); |
| } |
| CheckApiStatus( |
| iree_hal_command_buffer_fill_buffer( |
| self.raw_ptr(), |
| iree_hal_make_buffer_ref(target_buffer.raw_ptr(), |
| target_offset, resolved_length), |
| pattern_view.buf, pattern_view.len), |
| "command buffer fill"); |
| if (end) { |
| CheckApiStatus(iree_hal_command_buffer_end(self.raw_ptr()), |
| "command buffer end"); |
| } |
| }, |
| py::arg("target_buffer"), py::arg("pattern"), |
| py::arg("target_offset") = 0, py::arg("length") = py::none(), |
| py::arg("end") = false); |
| } |
| |
| } // namespace python |
| } // namespace iree |