// blob: 7f53408f6e950af991df695242fa82718a3ee8df [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/python/pyiree/rt.h"
#include "bindings/python/pyiree/status_utils.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
namespace iree {
namespace python {
// Copies the contents of a Python buffer-protocol object into a newly
// allocated device-visible HAL buffer and returns a buffer view over it.
//
// The Python buffer is requested read-only. Its rank must be within
// [0, IREE_SHAPE_MAX_RANK], its dimensions must be non-negative and its
// elements must be laid out densely in row-major (C-contiguous) order, because
// the data is copied as one flat run of `size * itemsize` bytes.
//
// Errors are reported through RaiseValueError() (validation failures) and
// CheckApiStatus() (IREE API failures).
// NOTE(review): the validation below assumes RaiseValueError() does not return
// normally (i.e. it throws) -- confirm against status_utils.h.
HalBufferView RtContext::WrapPyBufferForInput(py::buffer py_buffer) {
  auto py_buffer_info = py_buffer.request(false /* writable */);
  if (py_buffer_info.ndim > IREE_SHAPE_MAX_RANK || py_buffer_info.ndim < 0) {
    RaiseValueError("Unsupported buffer rank");
  }
  if (py_buffer_info.size < 0) {
    RaiseValueError("Illegal buffer size");
  }

  // Populate shape. This is done before any allocation so that invalid input
  // is rejected without touching the device.
  iree_shape_t shape;
  shape.rank = py_buffer_info.ndim;
  for (int i = 0; i < shape.rank; ++i) {
    ssize_t dim = py_buffer_info.shape[i];
    if (dim < 0) {
      RaiseValueError("Unsupported negative dim");
    }
    shape.dims[i] = dim;
  }

  // Verify strides are row-major.
  // pybind11 reports strides in bytes, so for a dense row-major buffer the
  // innermost stride equals itemsize and each outer stride equals the inner
  // stride multiplied by the inner extent. Empty buffers have nothing to copy
  // and extent-1 dimensions are never stepped over, so their strides are not
  // constrained.
  // TODO(laurenzo): Test this with rank>1.
  if (py_buffer_info.size > 0) {
    ssize_t expected_stride = py_buffer_info.itemsize;
    for (int i = shape.rank - 1; i >= 0; --i) {
      ssize_t dim = py_buffer_info.shape[i];
      if (dim != 1 && py_buffer_info.strides[i] != expected_stride) {
        RaiseValueError("Expected row-major layout");
      }
      expected_stride *= dim;
    }
  }

  // For the moment, allocate a device visible buffer of equivalent size and
  // copy into it.
  // TODO(laurenzo): Once sequencer is in place, switch to HeapBuffer, wrap
  // and retain the original buffer.
  iree_host_size_t byte_size = py_buffer_info.size * py_buffer_info.itemsize;
  HalBuffer buffer =
      AllocateDeviceVisible(byte_size, IREE_HAL_BUFFER_USAGE_CONSTANT |
                                           IREE_HAL_BUFFER_USAGE_TRANSFER |
                                           IREE_HAL_BUFFER_USAGE_DISPATCH);
  CheckApiStatus(iree_hal_buffer_write_data(buffer.raw_ptr(), 0,
                                            py_buffer_info.ptr, byte_size),
                 "Error writing to input buffer");

  // Create the buffer view.
  // TODO(laurenzo): This does no validation on dtype and only cares if the
  // elementsize matches. Figure out where to enforce actual dtype.
  iree_hal_buffer_view_t* bv;
  CheckApiStatus(iree_hal_buffer_view_create(buffer.raw_ptr(), shape,
                                             py_buffer_info.itemsize,
                                             IREE_ALLOCATOR_DEFAULT, &bv),
                 "Error allocating buffer view");
  return HalBufferView::CreateRetained(bv);
}
// Registers the runtime ("rt") enums and classes on the Python module |m|.
// Registration order is preserved from top to bottom; BufferPlacement in
// particular is registered before the Context methods that use one of its
// values as a default argument.
void SetupRtBindings(pybind11::module m) {
  // BufferPlacement: enumerators are also exported into the module scope.
  py::enum_<BufferPlacement> placement_enum(m, "BufferPlacement");
  placement_enum.value("HEAP", BufferPlacement::kHeap);
  placement_enum.value("DEVICE_VISIBLE", BufferPlacement::kDeviceVisible);
  placement_enum.value("DEVICE_LOCAL", BufferPlacement::kDeviceLocal);
  placement_enum.export_values();

  // RtModule: name plus function lookup by ordinal or by name.
  py::class_<RtModule> module_class(m, "Module");
  module_class.def_property_readonly("name", &RtModule::name);
  module_class.def("lookup_function_by_ordinal",
                   &RtModule::lookup_function_by_ordinal);
  module_class.def("lookup_function_by_name",
                   &RtModule::lookup_function_by_name);

  // RtFunction and the read-only view of its signature struct.
  py::class_<RtFunction> function_class(m, "Function");
  function_class.def_property_readonly("name", &RtFunction::name);
  function_class.def_property_readonly("signature", &RtFunction::signature);

  py::class_<iree_rt_function_signature_t> signature_class(
      m, "FunctionSignature");
  signature_class.def_readonly("argument_count",
                               &iree_rt_function_signature_t::argument_count);
  signature_class.def_readonly("result_count",
                               &iree_rt_function_signature_t::result_count);

  // RtPolicy: constructed through its Create factory.
  py::class_<RtPolicy> policy_class(m, "Policy");
  policy_class.def(py::init(&RtPolicy::Create));

  // RtInstance: constructed through its Create factory; the driver name
  // defaults to an empty optional.
  py::class_<RtInstance> instance_class(m, "Instance");
  instance_class.def(py::init(&RtInstance::Create),
                     py::arg_v("driver_name", absl::optional<std::string>()));

  // RtContext: module registration, lookup, allocation and invocation.
  py::class_<RtContext> context_class(m, "Context");
  context_class.def(py::init(&RtContext::Create), py::arg("instance"),
                    py::arg("policy"));
  context_class.def_property_readonly("context_id", &RtContext::context_id);
  context_class.def("register_modules", &RtContext::RegisterModules,
                    py::arg("modules"));
  context_class.def("register_module", &RtContext::RegisterModule,
                    py::arg("module"));
  context_class.def("lookup_module_by_name", &RtContext::LookupModuleByName,
                    py::arg("name"));
  context_class.def("resolve_function", &RtContext::ResolveFunction,
                    py::arg("full_name"));
  context_class.def("allocate", &RtContext::Allocate,
                    py::arg("allocation_size"),
                    py::arg("placement") = BufferPlacement::kHeap,
                    py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL);
  context_class.def("allocate_device_local", &RtContext::AllocateDeviceLocal,
                    py::arg("allocation_size"),
                    py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL);
  context_class.def("allocate_device_visible",
                    &RtContext::AllocateDeviceVisible,
                    py::arg("allocation_size"),
                    py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL);
  context_class.def("wrap_for_input", &RtContext::WrapPyBufferForInput,
                    py::arg("v"));
  context_class.def(
      "invoke", &RtContext::Invoke, py::arg("f"), py::arg("policy"),
      py::arg("arguments"),
      py::arg("results") = absl::optional<std::vector<HalBufferView*>>());

  // RtInvocation: status query, waiting, and access to results.
  py::class_<RtInvocation> invocation_class(m, "Invocation");
  invocation_class.def("query_status", &RtInvocation::QueryStatus);
  invocation_class.def("await_ready", &RtInvocation::Await,
                       py::arg("deadline") = IREE_TIME_INFINITE_FUTURE);
  invocation_class.def("await_ready_optional", &RtInvocation::AwaitOptional,
                       py::arg("deadline") = IREE_TIME_INFINITE_FUTURE);
  invocation_class.def_property_readonly("results",
                                         &RtInvocation::ConsumeResults);
}
} // namespace python
} // namespace iree