blob: 23847f95448d0c066dd264b5663c90cb57302f84 [file]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/buffer_view.h"
#include <inttypes.h>
#include "iree/base/tracing.h"
#include "iree/hal/allocator.h"
#include "iree/hal/detail.h"
// A typed, shaped view over a HAL buffer.
//
// The structure is allocated as a single block by iree_hal_buffer_view_create
// with room for |shape_rank| dimensions stored inline after the fixed-size
// fields (see the flexible array member at the end).
struct iree_hal_buffer_view_s {
// Reference count; initialized on creation and adjusted by the view
// retain/release functions.
iree_atomic_ref_count_t ref_count;
// Buffer this view refers to. Retained on creation and released when the view
// is destroyed.
iree_hal_buffer_t* buffer;
// Element type of the view contents.
iree_hal_element_type_t element_type;
// Cached total size in bytes: the element byte count multiplied by every
// dimension in |shape|, computed once at creation time.
iree_device_size_t byte_length;
// Number of valid entries in |shape|.
iree_host_size_t shape_rank;
// Shape dimensions; |shape_rank| entries allocated inline with the struct.
iree_hal_dim_t shape[];
};
// Creates a buffer view over |buffer| with the given |element_type| and shape.
//
// The view retains |buffer| and copies the |shape_rank| dimensions from
// |shape|; |shape| may be NULL only when |shape_rank| is 0. The cached byte
// length is the element byte count multiplied by every dimension (no overflow
// checking is performed).
//
// |out_buffer_view| is reset to NULL up front and only receives the new view
// when the allocation succeeds; otherwise the allocation status is returned.
IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create(
    iree_hal_buffer_t* buffer, iree_hal_element_type_t element_type,
    const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
    iree_hal_buffer_view_t** out_buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer);
  IREE_ASSERT_ARGUMENT(out_buffer_view);
  *out_buffer_view = NULL;
  if (IREE_UNLIKELY(!shape && shape_rank > 0)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "no shape dimensions specified");
  }

  IREE_TRACE_ZONE_BEGIN(z0);

  // The view is allocated from the host allocator associated with the
  // buffer's allocator. The variable-length shape lives inline after the
  // fixed-size fields, so the allocation is sized for both.
  iree_allocator_t host_allocator =
      iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(buffer));
  iree_hal_buffer_view_t* view = NULL;
  const iree_host_size_t total_size =
      sizeof(*view) + sizeof(iree_hal_dim_t) * shape_rank;
  iree_status_t status =
      iree_allocator_malloc(host_allocator, total_size, (void**)&view);

  if (iree_status_is_ok(status)) {
    iree_atomic_ref_count_init(&view->ref_count);
    view->buffer = buffer;
    iree_hal_buffer_retain(view->buffer);
    view->element_type = element_type;
    view->shape_rank = shape_rank;

    // Copy the dimensions while accumulating the total byte length.
    iree_device_size_t total_bytes = iree_hal_element_byte_count(element_type);
    for (iree_host_size_t dim = 0; dim < shape_rank; ++dim) {
      view->shape[dim] = shape[dim];
      total_bytes *= shape[dim];
    }
    view->byte_length = total_bytes;

    *out_buffer_view = view;
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Increments the reference count of |buffer_view|.
// A NULL |buffer_view| is ignored.
IREE_API_EXPORT void IREE_API_CALL
iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view) {
  if (IREE_UNLIKELY(!buffer_view)) {
    return;
  }
  iree_atomic_ref_count_inc(&buffer_view->ref_count);
}
// Decrements the reference count of |buffer_view| and destroys the view when
// the decrement reports a value of 1. A NULL |buffer_view| is ignored.
IREE_API_EXPORT void IREE_API_CALL
iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view) {
  if (IREE_UNLIKELY(!buffer_view)) {
    return;
  }
  if (iree_atomic_ref_count_dec(&buffer_view->ref_count) == 1) {
    iree_hal_buffer_view_destroy(buffer_view);
  }
}
// Destroys |buffer_view| without consulting its reference count: releases the
// view's reference on its buffer and frees the view's memory.
// |buffer_view| must not be NULL (it is dereferenced unconditionally).
IREE_API_EXPORT void IREE_API_CALL
iree_hal_buffer_view_destroy(iree_hal_buffer_view_t* buffer_view) {
  iree_hal_buffer_t* backing_buffer = buffer_view->buffer;
  // Look up the host allocator through the buffer before the buffer reference
  // is dropped; the same allocator is then used to free the view.
  iree_allocator_t host_allocator = iree_hal_allocator_host_allocator(
      iree_hal_buffer_allocator(backing_buffer));
  iree_hal_buffer_release(backing_buffer);
  iree_allocator_free(host_allocator, buffer_view);
}
// Creates a new view covering the region of |buffer_view| that starts at
// |start_indices| and extends |lengths| elements along each dimension.
//
// The region is converted to a byte range, a subspan of the source buffer is
// taken for that range, and a new view with shape |lengths| and the source
// element type is created over the subspan. Argument validation is delegated
// to iree_hal_buffer_view_compute_range.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_subview(
    const iree_hal_buffer_view_t* buffer_view,
    const iree_hal_dim_t* start_indices, iree_host_size_t indices_count,
    const iree_hal_dim_t* lengths, iree_host_size_t lengths_count,
    iree_hal_buffer_view_t** out_buffer_view) {
  IREE_ASSERT_ARGUMENT(out_buffer_view);

  // Resolve the requested region to a byte offset/length pair.
  iree_device_size_t range_offset = 0;
  iree_device_size_t range_length = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_compute_range(
      buffer_view, start_indices, indices_count, lengths, lengths_count,
      &range_offset, &range_length));

  // Reference just that byte range of the source buffer.
  iree_hal_buffer_t* subspan_buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan(
      buffer_view->buffer, range_offset, range_length, &subspan_buffer));

  // The new view takes its own reference on the subspan during creation, so
  // the local reference is dropped regardless of whether creation succeeded.
  iree_status_t create_status = iree_hal_buffer_view_create(
      subspan_buffer, buffer_view->element_type, lengths, lengths_count,
      out_buffer_view);
  iree_hal_buffer_release(subspan_buffer);
  return create_status;
}
// Returns the buffer that |buffer_view| refers to.
// No additional reference is taken on the returned buffer by this call.
IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer(
    const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  iree_hal_buffer_t* backing_buffer = buffer_view->buffer;
  return backing_buffer;
}
// Returns the number of dimensions in the shape of |buffer_view|.
IREE_API_EXPORT iree_host_size_t IREE_API_CALL
iree_hal_buffer_view_shape_rank(const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_host_size_t rank = buffer_view->shape_rank;
  return rank;
}
// Returns a pointer to the first shape dimension stored inline in
// |buffer_view|. The pointer refers to the view's own storage, with
// iree_hal_buffer_view_shape_rank entries.
IREE_API_EXPORT const iree_hal_dim_t* IREE_API_CALL
iree_hal_buffer_view_shape_dims(const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  return &buffer_view->shape[0];
}
// Returns the shape dimension of |buffer_view| at |index|, or 0 when |index|
// is not a valid dimension index (that is, |index| >= shape rank).
IREE_API_EXPORT iree_hal_dim_t IREE_API_CALL iree_hal_buffer_view_shape_dim(
    const iree_hal_buffer_view_t* buffer_view, iree_host_size_t index) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  // Valid indices are [0, shape_rank). The bound must be inclusive of
  // shape_rank itself: shape[shape_rank] is one element past the end of the
  // inline dimension storage.
  if (IREE_UNLIKELY(index >= buffer_view->shape_rank)) {
    return 0;
  }
  return buffer_view->shape[index];
}
// Returns the total number of elements in |buffer_view|: the product of all
// shape dimensions, or 1 for a rank-0 view. No overflow checking is performed.
IREE_API_EXPORT iree_host_size_t
iree_hal_buffer_view_element_count(const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_hal_dim_t* dim = buffer_view->shape;
  const iree_hal_dim_t* const dims_end = dim + buffer_view->shape_rank;
  iree_host_size_t product = 1;
  for (; dim != dims_end; ++dim) {
    product *= *dim;
  }
  return product;
}
// Copies the shape dimensions of |buffer_view| into |out_shape|.
//
// |rank_capacity| is the number of entries available in |out_shape|. When
// |out_shape_rank| is non-NULL it always receives the rank of the view, even
// if the capacity is insufficient, so callers can use the call as a size
// query.
//
// Returns OK after copying all dimensions, or an IREE_STATUS_OUT_OF_RANGE
// status (without touching |out_shape|) when |rank_capacity| is smaller than
// the view's rank.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape(
    const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity,
    iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  IREE_ASSERT_ARGUMENT(out_shape);

  // Report the required rank before the capacity check so that the
  // insufficient-capacity path still tells the caller how much to allocate.
  if (out_shape_rank) {
    *out_shape_rank = buffer_view->shape_rank;
  }
  if (rank_capacity < buffer_view->shape_rank) {
    // Not an error; just a size query.
    return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
  }

  for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) {
    out_shape[i] = buffer_view->shape[i];
  }
  return iree_ok_status();
}
// Replaces the shape of |buffer_view| in place with |shape|.
//
// The new shape must have the same rank as the existing one and describe the
// same total number of elements; otherwise IREE_STATUS_INVALID_ARGUMENT is
// returned and the view is left unmodified. The cached byte length is not
// touched by this function.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_reshape(
    iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  IREE_ASSERT_ARGUMENT(shape);

  // The dimensions are stored inline in the view allocation, so the rank
  // cannot change after creation.
  const iree_host_size_t existing_rank = buffer_view->shape_rank;
  if (shape_rank != existing_rank) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "buffer view reshapes must have the same rank; "
                            "target=%zu, existing=%zu",
                            shape_rank, buffer_view->shape_rank);
  }

  // Compare total element counts of the requested and current shapes.
  iree_device_size_t target_count = 1;
  for (iree_host_size_t dim = 0; dim < shape_rank; ++dim) {
    target_count *= shape[dim];
  }
  iree_device_size_t current_count =
      iree_hal_buffer_view_element_count(buffer_view);
  if (target_count != current_count) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "buffer view reshapes must have the same element "
                            "count; target=%" PRIu64 ", existing=%" PRIu64,
                            target_count, current_count);
  }

  // Validation passed: overwrite the stored dimensions.
  for (iree_host_size_t dim = 0; dim < shape_rank; ++dim) {
    buffer_view->shape[dim] = shape[dim];
  }
  return iree_ok_status();
}
// Returns the element type that |buffer_view| was created with.
IREE_API_EXPORT iree_hal_element_type_t IREE_API_CALL
iree_hal_buffer_view_element_type(const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_hal_element_type_t type = buffer_view->element_type;
  return type;
}
// Returns the byte count of a single element of |buffer_view|, as reported by
// iree_hal_element_byte_count for the view's element type.
IREE_API_EXPORT iree_host_size_t
iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_hal_element_type_t type = buffer_view->element_type;
  return iree_hal_element_byte_count(type);
}
// Returns the total byte length of |buffer_view| as cached when the view was
// created (element byte count multiplied by every dimension).
IREE_API_EXPORT iree_device_size_t IREE_API_CALL
iree_hal_buffer_view_byte_length(const iree_hal_buffer_view_t* buffer_view) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_device_size_t cached_length = buffer_view->byte_length;
  return cached_length;
}
// Computes the byte offset of the element at |indices| within |buffer_view|.
//
// Forwards to iree_hal_buffer_compute_view_offset using the shape and element
// type stored in the view; that function performs validation of the
// remaining arguments.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_offset(
    const iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* indices,
    iree_host_size_t indices_count, iree_device_size_t* out_offset) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_hal_dim_t* view_shape = buffer_view->shape;
  const iree_host_size_t view_rank = buffer_view->shape_rank;
  return iree_hal_buffer_compute_view_offset(view_shape, view_rank,
                                             buffer_view->element_type,
                                             indices, indices_count,
                                             out_offset);
}
// Computes the byte range within |buffer_view| covered by the region starting
// at |start_indices| and spanning |lengths| elements per dimension.
//
// Forwards to iree_hal_buffer_compute_view_range using the shape and element
// type stored in the view; that function performs validation of the
// remaining arguments.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_range(
    const iree_hal_buffer_view_t* buffer_view,
    const iree_hal_dim_t* start_indices, iree_host_size_t indices_count,
    const iree_hal_dim_t* lengths, iree_host_size_t lengths_count,
    iree_device_size_t* out_start_offset, iree_device_size_t* out_length) {
  IREE_ASSERT_ARGUMENT(buffer_view);
  const iree_hal_dim_t* view_shape = buffer_view->shape;
  const iree_host_size_t view_rank = buffer_view->shape_rank;
  return iree_hal_buffer_compute_view_range(
      view_shape, view_rank, buffer_view->element_type, start_indices,
      indices_count, lengths, lengths_count, out_start_offset, out_length);
}
// Computes the number of bytes needed to store a densely packed view with the
// given |shape| and |element_type|: the element byte count multiplied by each
// of the |shape_rank| dimensions. No overflow checking is performed.
//
// Always returns OK; the result is written to |out_allocation_size|.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_size(
    const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
    iree_hal_element_type_t element_type,
    iree_device_size_t* out_allocation_size) {
  IREE_ASSERT_ARGUMENT(shape);
  IREE_ASSERT_ARGUMENT(out_allocation_size);
  *out_allocation_size = 0;

  iree_device_size_t total_bytes = iree_hal_element_byte_count(element_type);
  const iree_hal_dim_t* dim = shape;
  const iree_hal_dim_t* const dims_end = shape + shape_rank;
  for (; dim != dims_end; ++dim) {
    total_bytes *= *dim;
  }

  *out_allocation_size = total_bytes;
  return iree_ok_status();
}
// Computes the byte offset of the element at |indices| in a densely packed
// view of the given |shape| and |element_type|, with the last dimension
// varying fastest.
//
// |indices_count| must equal |shape_rank| (IREE_STATUS_INVALID_ARGUMENT
// otherwise) and every index must be below the matching dimension
// (IREE_STATUS_OUT_OF_RANGE otherwise). |out_offset| is zeroed up front and
// only receives the computed offset on success.
//
// NOTE(review): indices are only checked against the upper bound. If
// iree_hal_dim_t is signed (the %d format below suggests so), a negative
// index is not rejected here - confirm whether callers rely on that.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_offset(
    const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
    iree_hal_element_type_t element_type, const iree_hal_dim_t* indices,
    iree_host_size_t indices_count, iree_device_size_t* out_offset) {
  IREE_ASSERT_ARGUMENT(shape);
  IREE_ASSERT_ARGUMENT(indices);
  IREE_ASSERT_ARGUMENT(out_offset);
  *out_offset = 0;
  if (IREE_UNLIKELY(shape_rank != indices_count)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "shape rank/indices mismatch: %zu != %zu",
                            shape_rank, indices_count);
  }

  // Accumulate the linear element index one axis at a time:
  //   linear = (...((i0 * d1 + i1) * d2 + i2)...) * d(n-1) + i(n-1)
  // which expands to the sum over axes of index[i] * product(shape[i+1..]).
  iree_device_size_t linear_index = 0;
  for (iree_host_size_t axis = 0; axis < indices_count; ++axis) {
    if (IREE_UNLIKELY(indices[axis] >= shape[axis])) {
      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                              "index[%zu] out of bounds: %d >= %d", axis,
                              indices[axis], shape[axis]);
    }
    linear_index = linear_index * shape[axis] + indices[axis];
  }

  // Scale from elements to bytes.
  *out_offset = linear_index * iree_hal_element_byte_count(element_type);
  return iree_ok_status();
}
// Computes the byte range covered by the region that starts at
// |start_indices| and spans |lengths| elements along each dimension of a
// densely packed view with the given |shape| and |element_type|.
//
// |indices_count|, |lengths_count| and |shape_rank| must all be equal
// (IREE_STATUS_INVALID_ARGUMENT otherwise). Errors from computing the byte
// offsets of the first and last elements of the region are propagated. Only
// regions whose bytes are contiguous are supported; anything else yields
// IREE_STATUS_UNIMPLEMENTED.
//
// Both outputs are zeroed up front and only written on success.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_range(
    const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
    iree_hal_element_type_t element_type, const iree_hal_dim_t* start_indices,
    iree_host_size_t indices_count, const iree_hal_dim_t* lengths,
    iree_host_size_t lengths_count, iree_device_size_t* out_start_offset,
    iree_device_size_t* out_length) {
  IREE_ASSERT_ARGUMENT(shape);
  IREE_ASSERT_ARGUMENT(start_indices);
  IREE_ASSERT_ARGUMENT(lengths);
  IREE_ASSERT_ARGUMENT(out_start_offset);
  IREE_ASSERT_ARGUMENT(out_length);
  *out_start_offset = 0;
  *out_length = 0;
  if (IREE_UNLIKELY(indices_count != lengths_count)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "indices/lengths mismatch: %zu != %zu",
                            indices_count, lengths_count);
  }
  if (IREE_UNLIKELY(shape_rank != indices_count)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "shape rank/indices mismatch: %zu != %zu",
                            shape_rank, indices_count);
  }

  // Scratch storage for the indices of the last element in the region; all
  // three counts are known to be equal at this point.
  iree_hal_dim_t* last_indices =
      iree_alloca(shape_rank * sizeof(iree_hal_dim_t));

  // Requested size in bytes: element size times the product of the lengths.
  iree_device_size_t element_size = iree_hal_element_byte_count(element_type);
  iree_device_size_t requested_length = element_size;
  for (iree_host_size_t axis = 0; axis < lengths_count; ++axis) {
    requested_length *= lengths[axis];
    last_indices[axis] = start_indices[axis] + lengths[axis] - 1;
  }

  // Byte offsets of the first and last elements; these calls also perform
  // the bounds validation of the indices.
  iree_device_size_t first_byte_offset = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_offset(
      shape, shape_rank, element_type, start_indices, indices_count,
      &first_byte_offset));
  iree_device_size_t last_byte_offset = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_offset(
      shape, shape_rank, element_type, last_indices, shape_rank,
      &last_byte_offset));

  // Non-contiguous regions not yet implemented. Will be easier to detect when
  // we have strides. For now the region is accepted only when the span from
  // the first to the last element (inclusive) matches the requested size.
  iree_device_size_t spanned_length =
      last_byte_offset - first_byte_offset + element_size;
  if (requested_length != spanned_length) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "non-contiguous range region computation not implemented");
  }

  *out_start_offset = first_byte_offset;
  *out_length = requested_length;
  return iree_ok_status();
}