blob: 1fb386307cda807d170c21d6bc5db252b258bd44 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
#define IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
#include <vector>
#include "./binding.h"
#include "./status_utils.h"
#include "./vm.h"
#include "iree/hal/api.h"
namespace iree {
namespace python {
//------------------------------------------------------------------------------
// Retain/release bindings
// Note that all HAL types have keep-alive relationships in addition to this
// (using the py::keep_alive<>() facility). These relationships form a chain
// such that any live Python leaf (like a buffer or buffer_view) must keep
// alive the allocator, device and driver that created it.
//
// The hierarchy is (each level keeps the level above it alive):
//   HalDriver
//     HalDevice
//       HalAllocator
//         HalBuffer
//           HalBufferView
//
// Any Python API which produces one of the above must be annotated with
// py::keep_alive<0, 1>() in order to establish the relationship with the
// parent.
//
// Any Python API which consumes one of these objects such that its lifetime
// may extend outside of the current invocation must arrange to retain/release
// all backing devices that may need to survive.
//------------------------------------------------------------------------------
// Retain/Release hooks for iree_hal_driver_t; each forwards to the matching
// IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_driver_t> {
  static void Retain(iree_hal_driver_t* driver) {
    iree_hal_driver_retain(driver);
  }
  static void Release(iree_hal_driver_t* driver) {
    iree_hal_driver_release(driver);
  }
};
// Retain/Release hooks for iree_hal_device_t; each forwards to the matching
// IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_device_t> {
  static void Retain(iree_hal_device_t* device) {
    iree_hal_device_retain(device);
  }
  static void Release(iree_hal_device_t* device) {
    iree_hal_device_release(device);
  }
};
// Retain/Release hooks for iree_hal_allocator_t; each forwards to the
// matching IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_allocator_t> {
  static void Retain(iree_hal_allocator_t* allocator) {
    iree_hal_allocator_retain(allocator);
  }
  static void Release(iree_hal_allocator_t* allocator) {
    iree_hal_allocator_release(allocator);
  }
};
// Retain/Release hooks for iree_hal_buffer_t; each forwards to the matching
// IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_buffer_t> {
  static void Retain(iree_hal_buffer_t* buffer) {
    iree_hal_buffer_retain(buffer);
  }
  static void Release(iree_hal_buffer_t* buffer) {
    iree_hal_buffer_release(buffer);
  }
};
// Retain/Release hooks for iree_hal_buffer_view_t; each forwards to the
// matching IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_buffer_view_t> {
  static void Retain(iree_hal_buffer_view_t* buffer_view) {
    iree_hal_buffer_view_retain(buffer_view);
  }
  static void Release(iree_hal_buffer_view_t* buffer_view) {
    iree_hal_buffer_view_release(buffer_view);
  }
};
// Retain/Release hooks for iree_hal_semaphore_t; each forwards to the
// matching IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_semaphore_t> {
  static void Retain(iree_hal_semaphore_t* semaphore) {
    iree_hal_semaphore_retain(semaphore);
  }
  static void Release(iree_hal_semaphore_t* semaphore) {
    iree_hal_semaphore_release(semaphore);
  }
};
// Retain/Release hooks for iree_hal_fence_t; each forwards to the matching
// IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_fence_t> {
  static void Retain(iree_hal_fence_t* f) {
    iree_hal_fence_retain(f);
  }
  static void Release(iree_hal_fence_t* f) {
    iree_hal_fence_release(f);
  }
};
// Retain/Release hooks for iree_hal_command_buffer_t; each forwards to the
// matching IREE C API reference-counting function.
template <>
struct ApiPtrAdapter<iree_hal_command_buffer_t> {
  static void Retain(iree_hal_command_buffer_t* command_buffer) {
    iree_hal_command_buffer_retain(command_buffer);
  }
  static void Release(iree_hal_command_buffer_t* command_buffer) {
    iree_hal_command_buffer_release(command_buffer);
  }
};
//------------------------------------------------------------------------------
// ApiRefCounted types
//------------------------------------------------------------------------------
class HalBuffer;
class HalSemaphore;
// Python-facing wrapper around an iree_hal_device_t. Apart from allocator(),
// the members are only declared here; their definitions live outside this
// header.
class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
 public:
  // Returns the allocator that iree_hal_device_allocator reports for this
  // device. No reference is added here; the pointer is not retained.
  iree_hal_allocator_t* allocator() {
    auto* device = raw_ptr();
    return iree_hal_device_allocator(device);
  }

  // Profiling control. Both BeginProfiling arguments are optional.
  void BeginProfiling(std::optional<std::string> mode,
                      std::optional<std::string> file_path);
  void EndProfiling();

  // Creates a semaphore starting at |initial_value|.
  HalSemaphore CreateSemaphore(uint64_t initial_value);

  // Queue operations. The wait/signal/command-buffer arguments are untyped
  // Python handles; the accepted forms are defined by the out-of-line
  // implementation.
  HalBuffer QueueAlloca(uint64_t allocation_size, py::handle wait_semaphores,
                        py::handle signal_semaphores);
  void QueueDealloca(HalBuffer& buffer, py::handle wait_semaphores,
                     py::handle signal_semaphores);
  void QueueExecute(py::handle command_buffers, py::handle wait_semaphores,
                    py::handle signal_semaphores);
};
// Python-facing wrapper around an iree_hal_driver_t. The members are only
// declared here; their definitions live outside this header.
class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
 public:
  // Driver discovery and construction. |driver_cache| is a caller-supplied
  // dict; presumably used to reuse drivers across calls — confirm in the
  // definition.
  static std::vector<std::string> Query();
  static py::object Create(const std::string& device_uri,
                           py::dict& driver_cache);

  // Device enumeration.
  py::list QueryAvailableDevices();

  // Device creation. Each Create* method takes an optional list of
  // allocators.
  HalDevice CreateDefaultDevice(std::optional<py::list> allocators);
  HalDevice CreateDevice(iree_hal_device_id_t device_id,
                         std::optional<py::list> allocators);
  // NOTE(review): |device_uri| is presumably only read but is taken by
  // non-const reference; making it const requires updating the out-of-line
  // definition at the same time.
  HalDevice CreateDeviceByURI(std::string& device_uri,
                              std::optional<py::list> allocators);
};
// Python-facing wrapper around an iree_hal_allocator_t. The members are only
// declared here; their definitions live outside this header.
class HalAllocator
    : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
 public:
  // Allocator statistics, returned as a dict or as a formatted string.
  py::dict QueryStatistics();
  py::str FormattedStatistics();

  // By its name, allocates on |device| and copies in the contents of the
  // Python object |buffer|. |element_type| is optional; the py::object
  // return presumably lets the result be a buffer or a buffer view — confirm.
  py::object AllocateBufferCopy(
      int memory_type, int allowed_usage, HalDevice& device, py::object buffer,
      std::optional<iree_hal_element_types_t> element_type);

  // By its name, allocates a host staging buffer on |device| holding a copy
  // of |buffer|.
  HalBuffer AllocateHostStagingBufferCopy(HalDevice& device, py::handle buffer);
};
// Plain holder for a shape: one iree_hal_dim_t per dimension, stored in |s|.
struct HalShape {
  // Copies |indices| into |s|. The vector is only read, so it is accepted by
  // const reference. This also allows construction from temporaries and from
  // const vectors, which the previous non-const reference rejected. The
  // constructor stays non-explicit so existing implicit conversions still
  // compile.
  HalShape(const std::vector<iree_hal_dim_t>& indices) : s(indices) {}

  std::vector<iree_hal_dim_t> s;
};
// Python-facing wrapper around an iree_hal_buffer_view_t.
class HalBufferView
    : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> {
 public:
  // Textual description of the view, defined outside this header.
  // Presumably backs the Python __repr__ — confirm in the bindings.
  py::str Repr();
};
// Python-facing wrapper around an iree_hal_buffer_t.
class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
 public:
  // Size of the buffer in bytes, as reported by iree_hal_buffer_byte_length.
  iree_device_size_t byte_length() const {
    return iree_hal_buffer_byte_length(raw_ptr());
  }

  // Zeroes |byte_length| bytes starting at |byte_offset| using
  // iree_hal_buffer_map_zero. Failures are reported through CheckApiStatus.
  void FillZero(iree_device_size_t byte_offset,
                iree_device_size_t byte_length) {
    CheckApiStatus(
        iree_hal_buffer_map_zero(raw_ptr(), byte_offset, byte_length),
        "Error zero filling buffer");
  }

  // Creates a dense row-major buffer view over this buffer with the given
  // |shape|. The element type is built with IREE_HAL_ELEMENT_TYPE_NONE in the
  // numerical-type field and a bit width of |element_size| * 8.
  // |shape| is only read, so it is taken by const reference; this lets
  // callers pass temporaries and const shapes. Failures are reported through
  // CheckApiStatus.
  // TODO(laurenzo): make this take element_type instead.
  HalBufferView CreateView(const HalShape& shape, size_t element_size) {
    // Initialized so the out-param never holds an indeterminate value.
    iree_hal_buffer_view_t* bv = nullptr;
    iree_hal_element_type_t element_type = iree_hal_make_element_type(
        IREE_HAL_ELEMENT_TYPE_NONE, element_size * 8);
    iree_hal_encoding_type_t encoding_type =
        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
    CheckApiStatus(iree_hal_buffer_view_create(
                       raw_ptr(), shape.s.size(), shape.s.data(), element_type,
                       encoding_type, iree_allocator_system(), &bv),
                   "Error creating buffer view");
    // The view was created with one reference that we now hand to the wrapper.
    return HalBufferView::StealFromRawPtr(bv);
  }

  // Textual description of the buffer, defined outside this header.
  py::str Repr();
};
// Python-facing wrapper around an iree_hal_semaphore_t. It adds no members
// beyond what ApiRefCounted provides.
class HalSemaphore
    : public ApiRefCounted<HalSemaphore, iree_hal_semaphore_t> {};
// Python-facing wrapper around an iree_hal_fence_t. It adds no members beyond
// what ApiRefCounted provides.
class HalFence
    : public ApiRefCounted<HalFence, iree_hal_fence_t> {};
// Wrapper around an iree_hal_buffer_mapping_t and iree_hal_buffer_t
// which retains the latter and unmaps/releases on deallocation.
class HalMappedMemory {
public:
HalMappedMemory(iree_hal_buffer_mapping_t mapped_memory,
iree_hal_buffer_t* buffer)
: mapped_memory_(mapped_memory), buffer_(buffer) {
iree_hal_buffer_retain(buffer_);
}
~HalMappedMemory() {
if (buffer_) {
iree_hal_buffer_unmap_range(&mapped_memory_);
iree_hal_buffer_release(buffer_);
}
}
HalMappedMemory(HalMappedMemory&& other)
: mapped_memory_(other.mapped_memory_), buffer_(other.buffer_) {
other.buffer_ = nullptr;
}
static HalMappedMemory Create(iree_hal_buffer_t* buffer) {
iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
iree_hal_buffer_mapping_t mapped_memory = {{0}};
CheckApiStatus(
iree_hal_buffer_map_range(buffer, IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_READ, 0, byte_length,
&mapped_memory),
"Could not map memory");
return HalMappedMemory(mapped_memory, buffer);
}
static HalMappedMemory CreateFromBuffer(HalBuffer& b) {
return Create(b.raw_ptr());
}
static HalMappedMemory CreateFromBufferView(HalBufferView& bv) {
return Create(iree_hal_buffer_view_buffer(bv.raw_ptr()));
}
iree_hal_buffer_mapping_t& mapped_memory() { return mapped_memory_; }
private:
iree_hal_buffer_mapping_t mapped_memory_ = {{0}};
iree_hal_buffer_t* buffer_ = nullptr;
};
// Python-facing wrapper around an iree_hal_command_buffer_t.
class HalCommandBuffer
    : public ApiRefCounted<HalCommandBuffer, iree_hal_command_buffer_t> {
  // No members beyond those inherited from ApiRefCounted.
};
// Declared here; the definition is not part of this header. Presumably
// registers the HAL Python types and functions on module |m| — confirm in the
// implementation file.
void SetupHalBindings(nanobind::module_ m);
} // namespace python
} // namespace iree
#endif // IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_