blob: 25681e00b020e6aaf9488582b2e1ef1c0da5aad0 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/api.h"
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/span.h"
#include "iree/base/api.h"
#include "iree/base/api_util.h"
#include "iree/base/memory.h"
#include "iree/base/tracing.h"
#include "iree/hal/api_detail.h"
#include "iree/hal/buffer.h"
#include "iree/hal/command_buffer.h"
#include "iree/hal/device.h"
#include "iree/hal/driver.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/heap_buffer.h"
#include "iree/hal/host/host_local_allocator.h"
#include "iree/hal/semaphore.h"
namespace iree {
namespace hal {
// Generates the exported iree_hal_<type_name>_retain and
// iree_hal_<type_name>_release functions for a C handle type.
//
// The opaque C handle is reinterpreted as a |cc_type| pointer; a NULL handle
// is ignored, otherwise AddReference()/ReleaseReference() is invoked on it.
#define IREE_HAL_API_RETAIN_RELEASE(type_name, cc_type)           \
  IREE_API_EXPORT void iree_hal_##type_name##_retain(             \
      iree_hal_##type_name##_t* type_name) {                      \
    if (auto* object = reinterpret_cast<cc_type*>(type_name)) {   \
      object->AddReference();                                     \
    }                                                             \
  }                                                               \
  IREE_API_EXPORT void iree_hal_##type_name##_release(            \
      iree_hal_##type_name##_t* type_name) {                      \
    if (auto* object = reinterpret_cast<cc_type*>(type_name)) {   \
      object->ReleaseReference();                                 \
    }                                                             \
  }
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
// Parses an 'x'-separated list of dimensions (such as "4x2x3") from |value|.
//
// |out_shape_rank| is required and receives the number of parsed dimensions,
// even when that number exceeds |shape_capacity|. |out_shape| is optional and
// is only written when all dimensions fit within |shape_capacity|.
//
// Returns:
//   IREE_STATUS_OK on success (an empty |value| yields rank 0).
//   IREE_STATUS_INVALID_ARGUMENT if |out_shape_rank| is NULL, a dimension is
//     not an integer, or a dimension is negative.
//   IREE_STATUS_OUT_OF_RANGE if more than |shape_capacity| dimensions exist.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_shape(
    iree_string_view_t value, iree_host_size_t shape_capacity,
    iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) {
  if (out_shape_rank == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_shape_rank = 0;

  const absl::string_view shape_text(value.data, value.size);
  if (shape_text.empty()) return IREE_STATUS_OK;

  absl::InlinedVector<iree_hal_dim_t, 6> parsed_dims;
  for (absl::string_view dim_text : absl::StrSplit(shape_text, 'x')) {
    int parsed = 0;
    if (!absl::SimpleAtoi(dim_text, &parsed)) {
      LOG(ERROR) << "Invalid shape dimension '" << dim_text
                 << "' while parsing shape '" << shape_text << "'";
      return IREE_STATUS_INVALID_ARGUMENT;
    }
    if (parsed < 0) {
      LOG(ERROR) << "Unsupported shape dimension '" << dim_text << "'";
      return IREE_STATUS_INVALID_ARGUMENT;
    }
    parsed_dims.push_back(parsed);
  }

  // Always report the required rank so callers can size their storage.
  const iree_host_size_t rank = parsed_dims.size();
  *out_shape_rank = rank;
  if (rank > shape_capacity) return IREE_STATUS_OUT_OF_RANGE;

  if (out_shape != nullptr) {
    std::memcpy(out_shape, parsed_dims.data(), rank * sizeof(*out_shape));
  }
  return IREE_STATUS_OK;
}
// Formats |shape| as an 'x'-separated list of dimensions (such as "4x2x3").
//
// |buffer| may be NULL to query the required length. |out_buffer_length| is
// optional and receives the formatted length in characters, excluding the NUL
// terminator, even when the text does not fit.
//
// When |buffer| is provided with non-zero capacity it is always
// NUL-terminated, including for a rank-0 shape (which formats as "").
//
// Returns:
//   IREE_STATUS_OK if the full string was written to |buffer|.
//   IREE_STATUS_OUT_OF_RANGE if |buffer| is NULL or too small.
//   IREE_STATUS_FAILED_PRECONDITION if snprintf reports an encoding error.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_format_shape(const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
                      iree_host_size_t buffer_capacity, char* buffer,
                      iree_host_size_t* out_buffer_length) {
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  // Without this a rank-0 shape would return OK with the buffer untouched.
  if (buffer && buffer_capacity > 0) {
    buffer[0] = '\0';
  }
  iree_host_size_t buffer_length = 0;
  for (iree_host_size_t i = 0; i < shape_rank; ++i) {
    // Once |buffer| has been dropped we only keep measuring the length.
    const iree_host_size_t remaining =
        buffer ? buffer_capacity - buffer_length : 0;
    const int n = std::snprintf(buffer ? buffer + buffer_length : nullptr,
                                remaining,
                                (i + 1 < shape_rank) ? "%dx" : "%d", shape[i]);
    if (n < 0) {
      return IREE_STATUS_FAILED_PRECONDITION;
    }
    if (buffer && static_cast<iree_host_size_t>(n) >= remaining) {
      // Output was truncated; stop writing but keep counting.
      buffer = nullptr;
    }
    buffer_length += static_cast<iree_host_size_t>(n);
  }
  if (out_buffer_length) {
    *out_buffer_length = buffer_length;
  }
  return buffer ? IREE_STATUS_OK : IREE_STATUS_OUT_OF_RANGE;
}
// Parses an element type string made of a one-character numerical type prefix
// followed by a decimal bit count (for example "i8" or "f32").
//
// Prefixes: 'i' signed integer, 'u' unsigned integer, 'f' IEEE float, and
// 'x' or '*' for an unknown numerical type. Bit counts above 0xFF are
// rejected.
//
// |out_element_type| is required; it is reset to IREE_HAL_ELEMENT_TYPE_NONE
// before parsing and only assigned the parsed type on success.
//
// Returns IREE_STATUS_INVALID_ARGUMENT for a NULL output, an unrecognized
// prefix, or an invalid bit count.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element_type(
    iree_string_view_t value, iree_hal_element_type_t* out_element_type) {
  if (out_element_type == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;

  absl::string_view remaining(value.data, value.size);
  if (remaining.empty()) return IREE_STATUS_INVALID_ARGUMENT;

  iree_hal_numerical_type_t numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN;
  switch (remaining.front()) {
    case 'i':
      numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED;
      break;
    case 'u':
      numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED;
      break;
    case 'f':
      numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE;
      break;
    case 'x':
    case '*':
      numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN;
      break;
    default:
      return IREE_STATUS_INVALID_ARGUMENT;
  }
  remaining.remove_prefix(1);

  uint32_t bit_count = 0;
  if (!absl::SimpleAtoi(remaining, &bit_count)) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (bit_count > 0xFFu) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_element_type = iree_hal_make_element_type(numerical_type, bit_count);
  return IREE_STATUS_OK;
}
// Formats |element_type| as a numerical type prefix followed by its bit count
// (for example "i8" or "f32"). Numerical types other than signed integer,
// unsigned integer and IEEE float use the '*' prefix.
//
// |buffer| may be NULL to query the required length. |out_buffer_length| is
// optional and receives the formatted length in characters, excluding the NUL
// terminator, even when the text does not fit.
//
// Returns:
//   IREE_STATUS_OK if the full string was written to |buffer|.
//   IREE_STATUS_OUT_OF_RANGE if |buffer| is NULL or too small.
//   IREE_STATUS_FAILED_PRECONDITION if snprintf reports an encoding error.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element_type(
    iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  const char* prefix;
  switch (iree_hal_element_numerical_type(element_type)) {
    case IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED:
      prefix = "i";
      break;
    case IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED:
      prefix = "u";
      break;
    case IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE:
      prefix = "f";
      break;
    default:
      prefix = "*";
      break;
  }
  // A NULL buffer must be paired with a zero size for snprintf; this matches
  // the length-query convention used by the other format functions here.
  const int n = std::snprintf(
      buffer, buffer ? buffer_capacity : 0, "%s%d", prefix,
      static_cast<int>(iree_hal_element_bit_count(element_type)));
  if (n < 0) {
    return IREE_STATUS_FAILED_PRECONDITION;
  }
  if (out_buffer_length) {
    *out_buffer_length = static_cast<iree_host_size_t>(n);
  }
  if (!buffer || static_cast<iree_host_size_t>(n) >= buffer_capacity) {
    return IREE_STATUS_OUT_OF_RANGE;
  }
  return IREE_STATUS_OK;
}
// Decodes |num| bytes from |from|, which must hold 2 * |num| characters with
// each pair of characters encoding one byte (high nibble first).
// Both upper- and lower-case hex digits are accepted; any other character
// contributes a zero nibble. No validation or error reporting is performed.
static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to,
                                         ptrdiff_t num) {
  const auto nibble_of = [](char ch) -> unsigned {
    if (ch >= '0' && ch <= '9') return static_cast<unsigned>(ch - '0');
    if (ch >= 'A' && ch <= 'F') return static_cast<unsigned>(ch - 'A') + 10u;
    if (ch >= 'a' && ch <= 'f') return static_cast<unsigned>(ch - 'a') + 10u;
    return 0u;
  };
  const char* cursor = from;
  for (ptrdiff_t byte_index = 0; byte_index < num;
       ++byte_index, cursor += 2) {
    to[byte_index] = static_cast<uint8_t>((nibble_of(cursor[0]) << 4) |
                                          nibble_of(cursor[1]));
  }
}
// Parses a single element string, assuming that the caller has validated that
// |out_data| has enough storage space for the parsed element data.
//
// Integer and float element types are parsed as decimal text; FLOAT_16 is not
// implemented. Any other element type is treated as opaque binary encoded as
// two hex characters per byte.
//
// Returns true on success and false if the text cannot be parsed or the value
// does not fit in the element type.
static bool iree_hal_parse_element_unsafe(iree_string_view_t data_str,
                                          iree_hal_element_type_t element_type,
                                          uint8_t* out_data) {
  const absl::string_view text(data_str.data, data_str.size);
  switch (element_type) {
    case IREE_HAL_ELEMENT_TYPE_SINT_8: {
      int32_t temp = 0;
      // Both bounds must be checked: values below INT8_MIN would otherwise be
      // silently truncated by the narrowing cast.
      if (!absl::SimpleAtoi(text, &temp) || temp < INT8_MIN ||
          temp > INT8_MAX) {
        return false;
      }
      *reinterpret_cast<int8_t*>(out_data) = static_cast<int8_t>(temp);
      return true;
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_8: {
      uint32_t temp = 0;
      if (!absl::SimpleAtoi(text, &temp) || temp > UINT8_MAX) {
        return false;
      }
      *reinterpret_cast<uint8_t*>(out_data) = static_cast<uint8_t>(temp);
      return true;
    }
    case IREE_HAL_ELEMENT_TYPE_SINT_16: {
      int32_t temp = 0;
      // Same as SINT_8: reject values on either side of the 16-bit range.
      if (!absl::SimpleAtoi(text, &temp) || temp < INT16_MIN ||
          temp > INT16_MAX) {
        return false;
      }
      *reinterpret_cast<int16_t*>(out_data) = static_cast<int16_t>(temp);
      return true;
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_16: {
      uint32_t temp = 0;
      if (!absl::SimpleAtoi(text, &temp) || temp > UINT16_MAX) {
        return false;
      }
      *reinterpret_cast<uint16_t*>(out_data) = static_cast<uint16_t>(temp);
      return true;
    }
    // The native-width cases parse directly into the output storage.
    // NOTE(review): this assumes |out_data| is suitably aligned for the
    // element type - confirm against callers.
    case IREE_HAL_ELEMENT_TYPE_SINT_32:
      return absl::SimpleAtoi(text, reinterpret_cast<int32_t*>(out_data));
    case IREE_HAL_ELEMENT_TYPE_UINT_32:
      return absl::SimpleAtoi(text, reinterpret_cast<uint32_t*>(out_data));
    case IREE_HAL_ELEMENT_TYPE_SINT_64:
      return absl::SimpleAtoi(text, reinterpret_cast<int64_t*>(out_data));
    case IREE_HAL_ELEMENT_TYPE_UINT_64:
      return absl::SimpleAtoi(text, reinterpret_cast<uint64_t*>(out_data));
    case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
      LOG(ERROR) << "Unimplemented parser for element format FLOAT_16";
      return false;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
      return absl::SimpleAtof(text, reinterpret_cast<float*>(out_data));
    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
      return absl::SimpleAtod(text, reinterpret_cast<double*>(out_data));
    default: {
      // Treat any unknown format as binary.
      iree_host_size_t element_size = iree_hal_element_byte_count(element_type);
      if (data_str.size != element_size * 2) {
        LOG(ERROR) << "Element hex byte count mismatch (expected "
                   << element_size * 2 << " chars, have " << data_str.size
                   << ")";
        return false;
      }
      iree_hal_hex_string_to_bytes(data_str.data, out_data, element_size);
      return true;
    }
  }
}
// Parses a single element of |element_type| from |data_str| into |data_ptr|.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |data_ptr| is smaller than one
// element or if the element text fails to parse.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element(
    iree_string_view_t data_str, iree_hal_element_type_t element_type,
    iree_byte_span_t data_ptr) {
  const iree_host_size_t required_size =
      iree_hal_element_byte_count(element_type);
  if (data_ptr.data_length < required_size) {
    LOG(ERROR) << "Output data buffer overflow (needed " << required_size
               << ", have " << data_ptr.data_length << ")";
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // Storage has been validated above, so the unchecked parser is safe to use.
  if (!iree_hal_parse_element_unsafe(data_str, element_type, data_ptr.data)) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  return IREE_STATUS_OK;
}
// Encodes |num| bytes from |src| as upper-case hex, writing two characters per
// byte (high nibble first) to |dest|. |dest| must have room for 2 * |num|
// characters; no NUL terminator is written.
static void iree_hal_bytes_to_hex_string(const uint8_t* src, char* dest,
                                         ptrdiff_t num) {
  static constexpr char kUpperHexDigits[] = "0123456789ABCDEF";
  for (ptrdiff_t i = 0; i < num; ++i) {
    const uint8_t value = src[i];
    dest[2 * i] = kUpperHexDigits[value >> 4];
    dest[2 * i + 1] = kUpperHexDigits[value & 0x0F];
  }
}
// Formats the single element of |element_type| stored at the start of |data|.
//
// Integer types are printed in decimal, FLOAT_32 with "%F" and FLOAT_64 with
// "%E"; FLOAT_16 is not implemented. Any other element type is treated as
// opaque binary and printed as upper-case hex, two characters per byte.
//
// |buffer| may be NULL to query the required length. |out_buffer_length| is
// optional; it is reset to 0 on entry and receives the formatted length in
// characters, excluding the NUL terminator, even when the text does not fit.
//
// Returns:
//   IREE_STATUS_OK if the full string was written to |buffer|.
//   IREE_STATUS_OUT_OF_RANGE if |data| is smaller than one element or if
//     |buffer| is NULL or too small.
//   IREE_STATUS_UNIMPLEMENTED for FLOAT_16.
//   IREE_STATUS_FAILED_PRECONDITION if snprintf reports an encoding error.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element(
    iree_const_byte_span_t data, iree_hal_element_type_t element_type,
    iree_host_size_t buffer_capacity, char* buffer,
    iree_host_size_t* out_buffer_length) {
  // Matches the other format functions: never leave the output length
  // uninitialized on an early return.
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  iree_host_size_t element_size = iree_hal_element_byte_count(element_type);
  if (data.data_length < element_size) {
    LOG(ERROR) << "Data buffer underflow on element format";
    return IREE_STATUS_OUT_OF_RANGE;
  }
  // snprintf requires a zero size when no buffer is provided.
  const iree_host_size_t write_capacity = buffer ? buffer_capacity : 0;
  int n = 0;
  switch (element_type) {
    case IREE_HAL_ELEMENT_TYPE_SINT_8:
      n = std::snprintf(buffer, write_capacity, "%" PRIi8,
                        *reinterpret_cast<const int8_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_8:
      n = std::snprintf(buffer, write_capacity, "%" PRIu8,
                        *reinterpret_cast<const uint8_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_SINT_16:
      n = std::snprintf(buffer, write_capacity, "%" PRIi16,
                        *reinterpret_cast<const int16_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_16:
      n = std::snprintf(buffer, write_capacity, "%" PRIu16,
                        *reinterpret_cast<const uint16_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_SINT_32:
      n = std::snprintf(buffer, write_capacity, "%" PRIi32,
                        *reinterpret_cast<const int32_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_32:
      n = std::snprintf(buffer, write_capacity, "%" PRIu32,
                        *reinterpret_cast<const uint32_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_SINT_64:
      n = std::snprintf(buffer, write_capacity, "%" PRIi64,
                        *reinterpret_cast<const int64_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_UINT_64:
      n = std::snprintf(buffer, write_capacity, "%" PRIu64,
                        *reinterpret_cast<const uint64_t*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
      LOG(ERROR) << "Unimplemented formatter for element format FLOAT_16";
      return IREE_STATUS_UNIMPLEMENTED;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
      n = std::snprintf(buffer, write_capacity, "%F",
                        *reinterpret_cast<const float*>(data.data));
      break;
    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
      n = std::snprintf(buffer, write_capacity, "%E",
                        *reinterpret_cast<const double*>(data.data));
      break;
    default: {
      // Treat any unknown format as binary.
      n = static_cast<int>(2 * element_size);
      if (buffer && buffer_capacity > static_cast<iree_host_size_t>(n)) {
        iree_hal_bytes_to_hex_string(data.data, buffer, element_size);
        buffer[n] = 0;
      }
      break;
    }
  }
  if (n < 0) {
    return IREE_STATUS_FAILED_PRECONDITION;
  }
  if (buffer && static_cast<iree_host_size_t>(n) >= buffer_capacity) {
    // Output was truncated (or not written at all in the binary case).
    buffer = nullptr;
  }
  if (out_buffer_length) {
    *out_buffer_length = static_cast<iree_host_size_t>(n);
  }
  return buffer ? IREE_STATUS_OK : IREE_STATUS_OUT_OF_RANGE;
}
// Parses a list of elements of |element_type| from |data_str| into |data_ptr|.
//
// Elements are separated by any run of ASCII whitespace, ',', '[' or ']'.
// Special cases:
//   - An empty |data_str| zero-fills |data_ptr|.
//   - A single parsed element is splatted across the whole of |data_ptr|.
// Otherwise the number of elements must exactly match the capacity of
// |data_ptr| (data_length / element size).
//
// Returns:
//   IREE_STATUS_OK on success.
//   IREE_STATUS_INVALID_ARGUMENT if the element type has a zero byte size or
//     an element fails to parse.
//   IREE_STATUS_OUT_OF_RANGE if there are more or fewer elements than
//     |data_ptr| can hold.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_buffer_elements(
    iree_string_view_t data_str, iree_hal_element_type_t element_type,
    iree_byte_span_t data_ptr) {
  IREE_TRACE_SCOPE0("iree_hal_parse_buffer_elements");
  iree_host_size_t element_size = iree_hal_element_byte_count(element_type);
  if (element_size == 0) {
    // Guard the capacity computation below against a division by zero.
    LOG(ERROR) << "Element type has a zero byte size; unable to parse buffer "
                  "elements";
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  iree_host_size_t element_capacity = data_ptr.data_length / element_size;
  if (iree_string_view_is_empty(data_str)) {
    if (data_ptr.data && data_ptr.data_length > 0) {
      memset(data_ptr.data, 0, data_ptr.data_length);
    }
    return IREE_STATUS_OK;
  }
  size_t src_i = 0;
  size_t dst_i = 0;
  // Index of the first character of the token being scanned, or npos when the
  // scanner is between tokens.
  size_t token_start = std::string::npos;
  while (src_i < data_str.size) {
    char c = data_str.data[src_i++];
    bool is_separator =
        absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']';
    if (token_start == std::string::npos) {
      // Between tokens: a non-separator starts a new token.
      if (!is_separator) {
        token_start = src_i - 1;
      }
      continue;
    }
    if (!is_separator) {
      // Still inside the current token.
      continue;
    }
    // A separator terminates the token spanning [token_start, src_i - 1).
    if (dst_i >= element_capacity) {
      LOG(ERROR)
          << "Output data buffer overflow (too many elements present, have >= "
          << dst_i << ", expected " << element_capacity << ")";
      return IREE_STATUS_OUT_OF_RANGE;
    }
    if (!iree_hal_parse_element_unsafe(
            iree_string_view_t{data_str.data + token_start,
                               src_i - 1 - token_start},
            element_type, data_ptr.data + dst_i * element_size)) {
      return IREE_STATUS_INVALID_ARGUMENT;
    }
    ++dst_i;
    token_start = std::string::npos;
  }
  if (token_start != std::string::npos) {
    // The input ended in the middle of a token; parse the trailing token.
    if (dst_i >= element_capacity) {
      LOG(ERROR)
          << "Output data buffer overflow (too many elements present, have > "
          << dst_i << ", expected " << element_capacity << ")";
      return IREE_STATUS_OUT_OF_RANGE;
    }
    if (!iree_hal_parse_element_unsafe(
            iree_string_view_t{data_str.data + token_start,
                               data_str.size - token_start},
            element_type, data_ptr.data + dst_i * element_size)) {
      return IREE_STATUS_INVALID_ARGUMENT;
    }
    ++dst_i;
  }
  if (dst_i == 1 && element_capacity > 1) {
    // Splat the single value we got to the entire buffer.
    uint8_t* p = data_ptr.data + element_size;
    for (iree_host_size_t i = 1; i < element_capacity;
         ++i, p += element_size) {
      memcpy(p, data_ptr.data, element_size);
    }
  } else if (dst_i < element_capacity) {
    LOG(ERROR)
        << "Input data string contains fewer elements than the underlying "
           "buffer (expected "
        << element_capacity << ", have " << dst_i << ")";
    return IREE_STATUS_OUT_OF_RANGE;
  }
  return IREE_STATUS_OK;
}
// Recursively formats the elements in |data| for the given |shape|.
//
// Each entry of an outer dimension is wrapped in '[' ... ']' and the elements
// of the innermost dimension are separated by single spaces. A rank-0 shape is
// formatted as a single-element leaf. |max_element_count| is shared across the
// whole recursion and decremented as elements are emitted; a leaf that is cut
// short has "..." appended.
//
// |buffer| may be NULL to measure only. Once the output no longer fits,
// writing stops but the length keeps accumulating so that
// |out_buffer_length| (optional) reports the full required length, excluding
// the NUL terminator.
//
// Returns:
//   IREE_STATUS_OK if everything was written to |buffer|.
//   IREE_STATUS_OUT_OF_RANGE if |buffer| is NULL or too small, or if |data| is
//     smaller than |shape| requires.
//   Any other failure status from iree_hal_format_element.
static iree_status_t iree_hal_format_buffer_elements_recursive(
    iree_const_byte_span_t data, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
    iree_host_size_t* max_element_count, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  iree_host_size_t buffer_length = 0;
  auto append_char = [&](char c) {
    if (buffer) {
      // Room is needed for |c| plus the NUL terminator. Written as an addition
      // so that a zero |buffer_capacity| cannot underflow and permit an
      // out-of-bounds write.
      if (buffer_length + 1 < buffer_capacity) {
        buffer[buffer_length] = c;
        buffer[buffer_length + 1] = '\0';
      } else {
        buffer = nullptr;
      }
    }
    ++buffer_length;
  };
  if (shape_rank == 0) {
    // Scalar value; recurse to get on to the leaf dimension path.
    const iree_hal_dim_t one = 1;
    return iree_hal_format_buffer_elements_recursive(
        data, &one, 1, element_type, max_element_count, buffer_capacity, buffer,
        out_buffer_length);
  } else if (shape_rank > 1) {
    // Nested dimension; recurse into the next innermost dimension.
    iree_hal_dim_t dim_length = 1;
    for (iree_host_size_t i = 1; i < shape_rank; ++i) {
      dim_length *= shape[i];
    }
    // Byte size of one entry along the outermost dimension.
    iree_device_size_t dim_stride =
        dim_length * iree_hal_element_byte_count(element_type);
    if (data.data_length < dim_stride * shape[0]) {
      return IREE_STATUS_OUT_OF_RANGE;
    }
    iree_const_byte_span_t subdata;
    subdata.data = data.data;
    subdata.data_length = dim_stride;
    for (iree_hal_dim_t i = 0; i < shape[0]; ++i) {
      append_char('[');
      iree_host_size_t actual_length = 0;
      iree_status_t status = iree_hal_format_buffer_elements_recursive(
          subdata, shape + 1, shape_rank - 1, element_type, max_element_count,
          buffer ? buffer_capacity - buffer_length : 0,
          buffer ? buffer + buffer_length : nullptr, &actual_length);
      buffer_length += actual_length;
      if (iree_status_is_out_of_range(status)) {
        // Ran out of space; keep going to compute the required length.
        buffer = nullptr;
      } else if (!iree_status_is_ok(status)) {
        return status;
      }
      subdata.data += dim_stride;
      append_char(']');
    }
  } else {
    // Leaf dimension; output data.
    iree_host_size_t max_count =
        std::min(*max_element_count, static_cast<iree_host_size_t>(shape[0]));
    iree_device_size_t element_stride =
        iree_hal_element_byte_count(element_type);
    if (data.data_length < max_count * element_stride) {
      return IREE_STATUS_OUT_OF_RANGE;
    }
    *max_element_count -= max_count;
    iree_const_byte_span_t subdata;
    subdata.data = data.data;
    subdata.data_length = element_stride;
    for (iree_host_size_t i = 0; i < max_count; ++i) {
      if (i > 0) append_char(' ');
      iree_host_size_t actual_length = 0;
      iree_status_t status = iree_hal_format_element(
          subdata, element_type, buffer ? buffer_capacity - buffer_length : 0,
          buffer ? buffer + buffer_length : nullptr, &actual_length);
      subdata.data += element_stride;
      buffer_length += actual_length;
      if (iree_status_is_out_of_range(status)) {
        // Ran out of space; keep going to compute the required length.
        buffer = nullptr;
      } else if (!iree_status_is_ok(status)) {
        return status;
      }
    }
    if (max_count < static_cast<iree_host_size_t>(shape[0])) {
      // The element budget was exhausted before the end of this dimension.
      append_char('.');
      append_char('.');
      append_char('.');
    }
  }
  if (out_buffer_length) {
    *out_buffer_length = buffer_length;
  }
  return buffer ? IREE_STATUS_OK : IREE_STATUS_OUT_OF_RANGE;
}
// Formats up to |max_element_count| elements of |data|, interpreted with the
// given |shape| and |element_type|, into |buffer|.
//
// |out_buffer_length| (optional) is reset to 0 and, when |buffer| has any
// capacity, |buffer| is made an empty string before the formatting work is
// delegated to iree_hal_format_buffer_elements_recursive, whose status is
// returned unchanged.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_buffer_elements(
    iree_const_byte_span_t data, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
    iree_host_size_t max_element_count, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  IREE_TRACE_SCOPE0("iree_hal_format_buffer_elements");
  if (out_buffer_length != nullptr) *out_buffer_length = 0;
  if (buffer != nullptr && buffer_capacity != 0) buffer[0] = '\0';

  // The recursive helper consumes the element budget through a pointer.
  iree_host_size_t* remaining_elements = &max_element_count;
  return iree_hal_format_buffer_elements_recursive(
      data, shape, shape_rank, element_type, remaining_elements,
      buffer_capacity, buffer, out_buffer_length);
}
//===----------------------------------------------------------------------===//
// iree::hal::Allocator
//===----------------------------------------------------------------------===//
// Defines iree_hal_allocator_retain and iree_hal_allocator_release, which
// forward to AddReference()/ReleaseReference() on the underlying Allocator.
IREE_HAL_API_RETAIN_RELEASE(allocator, Allocator);
// Creates a new host::HostLocalAllocator and returns it through
// |out_allocator| as an opaque C handle.
//
// |allocator| is not used by this implementation; the object is created with
// operator new.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_allocator| is NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_allocator_create_host_local(iree_allocator_t allocator,
                                     iree_hal_allocator** out_allocator) {
  IREE_TRACE_SCOPE0("iree_hal_allocator_create_host_local");
  if (out_allocator == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  auto* host_allocator = new host::HostLocalAllocator();
  *out_allocator = reinterpret_cast<iree_hal_allocator_t*>(host_allocator);
  return IREE_STATUS_OK;
}
// Computes the byte size of a densely packed buffer with the given |shape| and
// |element_type|: the element byte count multiplied by every dimension.
//
// |out_allocation_size| is reset to 0 before |allocator| is validated.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_allocation_size| or
// |allocator| is NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_size(
    const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
    iree_device_size_t* out_allocation_size) {
  if (out_allocation_size == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_allocation_size = 0;
  if (allocator == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  // TODO(benvanik): layout/padding.
  iree_device_size_t total_bytes = iree_hal_element_byte_count(element_type);
  for (const iree_hal_dim_t* dim = shape; dim != shape + shape_rank; ++dim) {
    total_bytes *= *dim;
  }
  *out_allocation_size = total_bytes;
  return IREE_STATUS_OK;
}
// Computes the byte offset of the element at |indices| within a densely packed
// row-major buffer of the given |shape| and |element_type|.
//
// |out_offset| is reset to 0 on entry and only updated on success.
//
// Returns:
//   IREE_STATUS_OK on success.
//   IREE_STATUS_INVALID_ARGUMENT if |out_offset| or |allocator| is NULL, or if
//     |indices_count| differs from |shape_rank|.
//   IREE_STATUS_OUT_OF_RANGE if any index is >= its dimension.
//
// NOTE(review): no lower-bound check is made on |indices|; if iree_hal_dim_t
// is a signed type, negative indices are not rejected here - confirm.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_offset(
    const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
    const iree_hal_dim_t* indices, iree_host_size_t indices_count,
    iree_device_size_t* out_offset) {
  if (!out_offset) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_offset = 0;
  if (!allocator) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (shape_rank != indices_count) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // TODO(benvanik): layout/padding.
  // Row-major linearization in a single pass (Horner's scheme):
  //   ((i0 * s1 + i1) * s2 + i2) ... == sum(i_k * product(s_j for j > k)).
  iree_device_size_t element_offset = 0;
  for (iree_host_size_t i = 0; i < indices_count; ++i) {
    if (indices[i] >= shape[i]) {
      return IREE_STATUS_OUT_OF_RANGE;
    }
    element_offset = element_offset * shape[i] + indices[i];
  }
  *out_offset = element_offset * iree_hal_element_byte_count(element_type);
  return IREE_STATUS_OK;
}
// Computes the byte range covered by the region that starts at
// |start_indices| and extends |lengths| elements along each dimension.
//
// Only regions that are contiguous in memory are supported: the byte distance
// between the first and last element (inclusive) must equal the product of
// |lengths| times the element size.
//
// Both outputs are reset to 0 on entry and only updated on success.
//
// Returns:
//   IREE_STATUS_OK on success.
//   IREE_STATUS_INVALID_ARGUMENT if an output or |allocator| is NULL, or if
//     |indices_count|, |lengths_count| and |shape_rank| do not all match.
//   IREE_STATUS_UNIMPLEMENTED if the region is not contiguous.
//   Any failure from iree_hal_allocator_compute_offset for the start or end
//     indices.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_range(
    const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
    const iree_hal_dim_t* start_indices, iree_host_size_t indices_count,
    const iree_hal_dim_t* lengths, iree_host_size_t lengths_count,
    iree_device_size_t* out_start_offset, iree_device_size_t* out_length) {
  if (out_start_offset == nullptr || out_length == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_start_offset = 0;
  *out_length = 0;
  if (allocator == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  if (indices_count != lengths_count) return IREE_STATUS_INVALID_ARGUMENT;
  if (indices_count != shape_rank) return IREE_STATUS_INVALID_ARGUMENT;

  // TODO(benvanik): layout/padding.
  const iree_device_size_t bytes_per_element =
      iree_hal_element_byte_count(element_type);

  // Gather the inclusive end index per dimension and the total region size.
  absl::InlinedVector<iree_hal_dim_t, 6> last_indices(shape_rank);
  iree_device_size_t region_bytes = bytes_per_element;
  for (iree_host_size_t axis = 0; axis < lengths_count; ++axis) {
    region_bytes *= lengths[axis];
    last_indices[axis] = start_indices[axis] + lengths[axis] - 1;
  }

  iree_device_size_t first_byte = 0;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_offset(
      allocator, shape, shape_rank, element_type, start_indices, indices_count,
      &first_byte));
  iree_device_size_t last_byte = 0;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_offset(
      allocator, shape, shape_rank, element_type, last_indices.data(),
      last_indices.size(), &last_byte));

  // Non-contiguous regions not yet implemented. Will be easier to detect when
  // we have strides.
  const iree_device_size_t spanned_bytes =
      last_byte - first_byte + bytes_per_element;
  if (region_bytes != spanned_bytes) return IREE_STATUS_UNIMPLEMENTED;

  *out_start_offset = first_byte;
  *out_length = region_bytes;
  return IREE_STATUS_OK;
}
// Allocates a buffer of |allocation_size| bytes from |allocator| with the
// requested |memory_type| and |buffer_usage| by calling Allocator::Allocate.
//
// On success the new buffer is released into |out_buffer|. |out_buffer| is
// reset to NULL before any other validation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer| or |allocator| is NULL;
// failures from Allocate are returned through IREE_API_ASSIGN_OR_RETURN.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_allocate_buffer(
    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
    iree_hal_buffer_usage_t buffer_usage, iree_host_size_t allocation_size,
    iree_hal_buffer_t** out_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_allocator_allocate_buffer");
  if (out_buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_buffer = nullptr;

  Allocator* allocator_object = reinterpret_cast<Allocator*>(allocator);
  if (allocator_object == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  IREE_API_ASSIGN_OR_RETURN(
      auto allocated,
      allocator_object->Allocate(
          static_cast<MemoryTypeBitfield>(memory_type),
          static_cast<BufferUsageBitfield>(buffer_usage), allocation_size));
  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(allocated.release());
  return IREE_STATUS_OK;
}
// Wraps the caller-provided memory in |data| as a buffer by calling
// Allocator::WrapMutable with the given |memory_type|, |allowed_access| and
// |buffer_usage|.
//
// On success the new buffer is released into |out_buffer|. |out_buffer| is
// reset to NULL before any other validation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer| or |allocator| is NULL;
// failures from WrapMutable are returned through IREE_API_ASSIGN_OR_RETURN.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_wrap_buffer(
    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
    iree_hal_memory_access_t allowed_access,
    iree_hal_buffer_usage_t buffer_usage, iree_byte_span_t data,
    iree_hal_buffer_t** out_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_allocator_wrap_buffer");
  if (out_buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_buffer = nullptr;

  Allocator* allocator_object = reinterpret_cast<Allocator*>(allocator);
  if (allocator_object == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  IREE_API_ASSIGN_OR_RETURN(
      auto wrapped,
      allocator_object->WrapMutable(
          static_cast<MemoryTypeBitfield>(memory_type),
          static_cast<MemoryAccessBitfield>(allowed_access),
          static_cast<BufferUsageBitfield>(buffer_usage), data.data,
          data.data_length));
  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(wrapped.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::Buffer
//===----------------------------------------------------------------------===//
// Defines iree_hal_buffer_retain and iree_hal_buffer_release, which forward to
// AddReference()/ReleaseReference() on the underlying Buffer.
IREE_HAL_API_RETAIN_RELEASE(buffer, Buffer);
// Creates a buffer referencing the range [|byte_offset|, |byte_offset| +
// |byte_length|) of |buffer| by calling Buffer::Subspan.
//
// A reference to |buffer| is added for the duration of the call. On success
// the result is released into |out_buffer|, which is reset to NULL before any
// other validation. |allocator| is not used by this implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer| or |buffer| is NULL;
// failures from Subspan are returned through IREE_API_ASSIGN_OR_RETURN.
IREE_API_EXPORT iree_status_t iree_hal_buffer_subspan(
    iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
    iree_device_size_t byte_length, iree_allocator_t allocator,
    iree_hal_buffer_t** out_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_subspan");
  if (out_buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_buffer = nullptr;
  if (buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  auto parent = add_ref(reinterpret_cast<Buffer*>(buffer));
  IREE_API_ASSIGN_OR_RETURN(auto subspan,
                            Buffer::Subspan(parent, byte_offset, byte_length));
  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(subspan.release());
  return IREE_STATUS_OK;
}
// Returns the result of Buffer::allocator() for |buffer| as a C handle.
// |buffer| must not be NULL; a NULL handle fails the CHECK.
IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL
iree_hal_buffer_allocator(const iree_hal_buffer_t* buffer) {
  const Buffer* buffer_object = reinterpret_cast<const Buffer*>(buffer);
  CHECK(buffer_object) << "NULL buffer handle";
  auto* owner = buffer_object->allocator();
  return reinterpret_cast<iree_hal_allocator_t*>(owner);
}
// Returns the result of Buffer::byte_length() for |buffer|.
// |buffer| must not be NULL; a NULL handle fails the CHECK.
IREE_API_EXPORT iree_device_size_t
iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) {
  const Buffer* buffer_object = reinterpret_cast<const Buffer*>(buffer);
  CHECK(buffer_object) << "NULL buffer handle";
  return buffer_object->byte_length();
}
// Fills |byte_length| bytes of |buffer| starting at |byte_offset| with zeros
// by calling Buffer::Fill8 with a value of 0.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer| is NULL; failures from
// Fill8 are returned through IREE_API_RETURN_IF_ERROR.
IREE_API_EXPORT iree_status_t
iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
                     iree_device_size_t byte_length) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_zero");
  if (buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Buffer* buffer_object = reinterpret_cast<Buffer*>(buffer);
  IREE_API_RETURN_IF_ERROR(buffer_object->Fill8(byte_offset, byte_length, 0));
  return IREE_STATUS_OK;
}
// Fills |byte_length| bytes of |buffer| starting at |byte_offset| with the
// |pattern| of |pattern_length| bytes by calling Buffer::Fill.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer| is NULL; failures from Fill
// are returned through IREE_API_RETURN_IF_ERROR.
IREE_API_EXPORT iree_status_t
iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
                     iree_device_size_t byte_length, const void* pattern,
                     iree_host_size_t pattern_length) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_fill");
  if (buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Buffer* buffer_object = reinterpret_cast<Buffer*>(buffer);
  IREE_API_RETURN_IF_ERROR(
      buffer_object->Fill(byte_offset, byte_length, pattern, pattern_length));
  return IREE_STATUS_OK;
}
// Reads |data_length| bytes from |buffer| starting at |source_offset| into
// |target_buffer| by calling Buffer::ReadData.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer| is NULL; failures from
// ReadData are returned through IREE_API_RETURN_IF_ERROR.
IREE_API_EXPORT iree_status_t iree_hal_buffer_read_data(
    iree_hal_buffer_t* buffer, iree_device_size_t source_offset,
    void* target_buffer, iree_device_size_t data_length) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_read_data");
  if (buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Buffer* buffer_object = reinterpret_cast<Buffer*>(buffer);
  IREE_API_RETURN_IF_ERROR(
      buffer_object->ReadData(source_offset, target_buffer, data_length));
  return IREE_STATUS_OK;
}
// Writes |data_length| bytes from |source_buffer| into |buffer| starting at
// |target_offset| by calling Buffer::WriteData.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer| is NULL; failures from
// WriteData are returned through IREE_API_RETURN_IF_ERROR.
IREE_API_EXPORT iree_status_t iree_hal_buffer_write_data(
    iree_hal_buffer_t* buffer, iree_device_size_t target_offset,
    const void* source_buffer, iree_device_size_t data_length) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_write_data");
  if (buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Buffer* buffer_object = reinterpret_cast<Buffer*>(buffer);
  IREE_API_RETURN_IF_ERROR(
      buffer_object->WriteData(target_offset, source_buffer, data_length));
  return IREE_STATUS_OK;
}
// Maps |byte_length| bytes of |buffer| starting at |byte_offset| with the
// requested |memory_access| and describes the mapping in |out_mapped_memory|.
//
// The C++ MappedMemory<uint8_t> object is stored inside
// |out_mapped_memory->reserved|, and |out_mapped_memory->contents| is set to
// the mapped data pointer and size. iree_hal_buffer_unmap resets that stored
// object.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_mapped_memory| or |buffer| is
// NULL; failures from MapMemory are returned through
// IREE_API_ASSIGN_OR_RETURN.
IREE_API_EXPORT iree_status_t iree_hal_buffer_map(
iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
iree_device_size_t byte_offset, iree_device_size_t byte_length,
iree_hal_mapped_memory_t* out_mapped_memory) {
IREE_TRACE_SCOPE0("iree_hal_buffer_map");
if (!out_mapped_memory) {
LOG(ERROR) << "output mapped memory not set";
return IREE_STATUS_INVALID_ARGUMENT;
}
// Zero the whole output struct (including the reserved storage) up front so
// that failure paths leave it in a cleared state.
std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory));
if (!buffer) {
LOG(ERROR) << "buffer not set";
return IREE_STATUS_INVALID_ARGUMENT;
}
auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
IREE_API_ASSIGN_OR_RETURN(
auto mapping, buffer_handle->MapMemory<uint8_t>(
static_cast<MemoryAccessBitfield>(memory_access),
byte_offset, byte_length));
// The reserved bytes of the C struct are used as in-place storage for the
// C++ mapping object, so they must be at least as large as it.
static_assert(sizeof(iree_hal_mapped_memory_t::reserved) >=
sizeof(MappedMemory<uint8_t>),
"C mapped memory struct must have large enough storage for the "
"matching C++ struct");
auto* mapping_storage =
reinterpret_cast<MappedMemory<uint8_t>*>(out_mapped_memory->reserved);
// NOTE(review): this move-assigns into zero-filled bytes rather than into a
// constructed object; presumably an all-zero MappedMemory is equivalent to an
// empty one - confirm, or construct in place instead.
*mapping_storage = std::move(mapping);
out_mapped_memory->contents = {mapping_storage->unsafe_data(),
mapping_storage->size()};
return IREE_STATUS_OK;
}
// Releases a mapping previously produced by iree_hal_buffer_map.
//
// The MappedMemory<uint8_t> object stored in |mapped_memory->reserved| is
// reset and the whole |mapped_memory| struct is then zeroed.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer| or |mapped_memory| is NULL.
IREE_API_EXPORT iree_status_t iree_hal_buffer_unmap(
    iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory) {
  // Use this function's own name so traces distinguish unmap from map.
  IREE_TRACE_SCOPE0("iree_hal_buffer_unmap");
  auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
  if (!buffer_handle || !mapped_memory) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  auto* mapping =
      reinterpret_cast<MappedMemory<uint8_t>*>(mapped_memory->reserved);
  mapping->reset();
  std::memset(mapped_memory, 0, sizeof(*mapped_memory));
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::HeapBuffer
//===----------------------------------------------------------------------===//
// Allocates a heap buffer of |allocation_size| bytes with the given
// |memory_type| and |usage| by calling HeapBuffer::Allocate.
//
// The result is released into |out_buffer|, which is reset to NULL before any
// other validation. |contents_allocator| and |allocator| are not used by this
// implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer| is NULL or
// |allocation_size| is 0.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate(
    iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
    iree_host_size_t allocation_size, iree_allocator_t contents_allocator,
    iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate");
  if (out_buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_buffer = nullptr;
  if (allocation_size == 0) return IREE_STATUS_INVALID_ARGUMENT;

  auto heap_buffer = HeapBuffer::Allocate(
      static_cast<MemoryTypeBitfield>(memory_type),
      static_cast<BufferUsageBitfield>(usage), allocation_size);
  Buffer* raw_buffer = static_cast<Buffer*>(heap_buffer.release());
  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(raw_buffer);
  return IREE_STATUS_OK;
}
// Allocates a heap buffer holding a copy of |contents| by calling
// HeapBuffer::AllocateCopy with the given |usage| and |allowed_access|.
//
// The result is released into |out_buffer|, which is reset to NULL before any
// other validation. |memory_type|, |contents_allocator| and |allocator| are
// not used by this implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer| is NULL or |contents|
// is NULL or empty.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy(
    iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage,
    iree_hal_memory_access_t allowed_access, iree_byte_span_t contents,
    iree_allocator_t contents_allocator, iree_allocator_t allocator,
    iree_hal_buffer_t** out_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate_copy");
  if (out_buffer == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_buffer = nullptr;

  const bool has_contents = contents.data && contents.data_length;
  if (!has_contents) return IREE_STATUS_INVALID_ARGUMENT;

  auto heap_buffer = HeapBuffer::AllocateCopy(
      static_cast<BufferUsageBitfield>(usage),
      static_cast<MemoryAccessBitfield>(allowed_access), contents.data,
      contents.data_length);
  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(heap_buffer.release());
  return IREE_STATUS_OK;
}
// Creates a heap buffer over the caller-provided |contents| by handing the
// span to HeapBuffer::WrapMutable and returns it in |out_buffer|.
//
// |allocator| is accepted but not used by this implementation.
// NOTE(review): presumably |contents| must outlive the returned buffer since
// it is wrapped rather than copied - confirm against HeapBuffer::WrapMutable.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer| is NULL or |contents|
// has a NULL data pointer or zero length.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap(
    iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
    iree_hal_buffer_usage_t usage, iree_byte_span_t contents,
    iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_heap_buffer_wrap");
  if (out_buffer == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_buffer = nullptr;

  const bool has_contents =
      contents.data != nullptr && contents.data_length != 0;
  if (!has_contents) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  const auto memory_type_bits = static_cast<MemoryTypeBitfield>(memory_type);
  const auto access_bits = static_cast<MemoryAccessBitfield>(allowed_access);
  const auto usage_bits = static_cast<BufferUsageBitfield>(usage);
  auto wrapped_buffer =
      HeapBuffer::WrapMutable(memory_type_bits, access_bits, usage_bits,
                              contents.data, contents.data_length);

  // Ownership of the reference moves to the caller through the C handle.
  *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(wrapped_buffer.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::BufferView
//===----------------------------------------------------------------------===//
// Defines iree_hal_buffer_view_retain/_release. Both treat the handle as an
// iree_hal_buffer_view and call AddReference()/ReleaseReference() on it; NULL
// handles are ignored.
IREE_HAL_API_RETAIN_RELEASE(buffer_view, iree_hal_buffer_view);
// Creates a buffer view describing |buffer| as a dense array of
// |element_type| elements with the given |shape|.
//
// The view is allocated from |allocator| with the |shape_rank| dimensions
// stored inline after the struct, and it retains |buffer|. The cached byte
// length is the element byte count multiplied by every dimension.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer_view| or |buffer| is
// NULL, or if |shape| is NULL while |shape_rank| is non-zero. Allocation
// failures are propagated. |out_buffer_view| is NULL on any failure.
IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create(
    iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape,
    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
    iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_view_create");
  if (!out_buffer_view) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_buffer_view = nullptr;
  if (!buffer) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // A scalar (rank 0) view may pass a NULL shape; anything else would be
  // dereferenced below.
  if (shape_rank > 0 && !shape) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // Allocate and initialize the iree_hal_buffer_view struct.
  // Note that we have the dynamically-sized shape dimensions on the end.
  iree_hal_buffer_view* buffer_view = nullptr;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      allocator, sizeof(*buffer_view) + sizeof(iree_hal_dim_t) * shape_rank,
      reinterpret_cast<void**>(&buffer_view)));
  new (buffer_view) iree_hal_buffer_view();
  buffer_view->allocator = allocator;
  buffer_view->buffer = buffer;
  iree_hal_buffer_retain(buffer_view->buffer);
  buffer_view->element_type = element_type;
  buffer_view->byte_length =
      iree_hal_element_byte_count(buffer_view->element_type);
  buffer_view->shape_rank = shape_rank;
  for (iree_host_size_t i = 0; i < shape_rank; ++i) {
    buffer_view->shape[i] = shape[i];
    buffer_view->byte_length *= shape[i];
  }

  *out_buffer_view = buffer_view;
  return IREE_STATUS_OK;
}
// Creates a new buffer view covering the region of |buffer_view| that starts
// at |start_indices| and spans |lengths| elements per dimension.
//
// The byte range is computed with iree_hal_buffer_view_compute_range, a
// subspan buffer is created over that range, and a new view with shape
// |lengths| and the source element type is created over the subspan.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer_view| is NULL; other
// failures are propagated from the range/subspan/create calls.
// |out_buffer_view| is NULL on any failure.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_subview(
    const iree_hal_buffer_view_t* buffer_view,
    const iree_hal_dim_t* start_indices, iree_host_size_t indices_count,
    const iree_hal_dim_t* lengths, iree_host_size_t lengths_count,
    iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) {
  if (!out_buffer_view) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // Clear the output up front so early error returns never leave the caller
  // with an indeterminate pointer.
  *out_buffer_view = nullptr;

  // NOTE: we rely on the compute range call to do parameter validation
  // (including the NULL check of |buffer_view|).
  iree_device_size_t start_offset = 0;
  iree_device_size_t subview_length = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_compute_range(
      buffer_view, start_indices, indices_count, lengths, lengths_count,
      &start_offset, &subview_length));

  iree_hal_buffer_t* subview_buffer = nullptr;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan(buffer_view->buffer,
                                               start_offset, subview_length,
                                               allocator, &subview_buffer));

  iree_status_t result = iree_hal_buffer_view_create(
      subview_buffer, lengths, lengths_count, buffer_view->element_type,
      allocator, out_buffer_view);
  // The new view (if created) holds its own reference to the subspan buffer.
  iree_hal_buffer_release(subview_buffer);
  return result;
}
// Returns the buffer pointer stored in |buffer_view|. No additional reference
// is taken here. CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer(
    const iree_hal_buffer_view_t* buffer_view) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  iree_hal_buffer_t* backing_buffer = buffer_view->buffer;
  return backing_buffer;
}
// Returns the number of dimensions recorded in |buffer_view|.
// CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_host_size_t IREE_API_CALL
iree_hal_buffer_view_shape_rank(const iree_hal_buffer_view_t* buffer_view) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  const auto rank = buffer_view->shape_rank;
  return rank;
}
// Returns the size of dimension |index| of |buffer_view|, or 0 when |index|
// is not a valid dimension (|index| >= rank).
// CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_hal_buffer_view_shape_dim(
    const iree_hal_buffer_view_t* buffer_view, iree_host_size_t index) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  // Valid indices are [0, shape_rank); index == shape_rank is already one past
  // the end of the inline shape storage.
  if (index >= buffer_view->shape_rank) {
    return 0;
  }
  return buffer_view->shape[index];
}
// Returns the product of all dimensions of |buffer_view| (1 for rank 0).
// CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_host_size_t
iree_hal_buffer_view_element_count(const iree_hal_buffer_view_t* buffer_view) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  iree_host_size_t product = 1;
  iree_host_size_t dim = 0;
  while (dim < buffer_view->shape_rank) {
    product *= buffer_view->shape[dim];
    ++dim;
  }
  return product;
}
// Copies the dimensions of |buffer_view| into |out_shape|.
//
// |rank_capacity| is the number of elements available in |out_shape|.
// |out_shape_rank| is optional; when provided it is set to 0 on invalid
// arguments and to the view's rank otherwise - including when the capacity is
// too small, so callers can resize and retry.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer_view| or |out_shape| is NULL
// and IREE_STATUS_OUT_OF_RANGE if |rank_capacity| is smaller than the rank.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape(
    const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity,
    iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) {
  if (out_shape_rank) {
    *out_shape_rank = 0;
  }
  if (!buffer_view || !out_shape) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (out_shape_rank) {
    *out_shape_rank = buffer_view->shape_rank;
  }
  if (rank_capacity < buffer_view->shape_rank) {
    return IREE_STATUS_OUT_OF_RANGE;
  }
  // Unsigned index to match the type of the rank being compared against.
  for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) {
    out_shape[i] = buffer_view->shape[i];
  }
  return IREE_STATUS_OK;
}
// Returns the element type recorded in |buffer_view|.
// CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_hal_element_type_t IREE_API_CALL
iree_hal_buffer_view_element_type(const iree_hal_buffer_view_t* buffer_view) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  const auto stored_type = buffer_view->element_type;
  return stored_type;
}
// Returns iree_hal_element_byte_count() of the element type of |buffer_view|.
// CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_host_size_t
iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  const auto stored_type = buffer_view->element_type;
  return iree_hal_element_byte_count(stored_type);
}
// Returns the byte length cached on |buffer_view|.
// CHECK-fails if |buffer_view| is NULL.
IREE_API_EXPORT iree_device_size_t IREE_API_CALL
iree_hal_buffer_view_byte_length(const iree_hal_buffer_view_t* buffer_view) {
  CHECK(buffer_view) << "NULL buffer_view handle";
  const auto cached_length = buffer_view->byte_length;
  return cached_length;
}
// Computes the offset of the element at |indices| within |buffer_view| by
// forwarding the view's shape and element type to
// iree_hal_allocator_compute_offset on the backing buffer's allocator.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer_view| is NULL; otherwise
// returns whatever the allocator call returns.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_offset(
    const iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* indices,
    iree_host_size_t indices_count, iree_device_size_t* out_offset) {
  if (buffer_view == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  auto view_allocator = iree_hal_buffer_allocator(buffer_view->buffer);
  return iree_hal_allocator_compute_offset(
      view_allocator, buffer_view->shape, buffer_view->shape_rank,
      buffer_view->element_type, indices, indices_count, out_offset);
}
// Computes the start offset and length of the region of |buffer_view| that
// begins at |start_indices| and spans |lengths|, by forwarding the view's
// shape and element type to iree_hal_allocator_compute_range on the backing
// buffer's allocator.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer_view| is NULL; otherwise
// returns whatever the allocator call returns.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_range(
    const iree_hal_buffer_view_t* buffer_view,
    const iree_hal_dim_t* start_indices, iree_host_size_t indices_count,
    const iree_hal_dim_t* lengths, iree_host_size_t lengths_count,
    iree_device_size_t* out_start_offset, iree_device_size_t* out_length) {
  if (buffer_view == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  auto view_allocator = iree_hal_buffer_allocator(buffer_view->buffer);
  return iree_hal_allocator_compute_range(
      view_allocator, buffer_view->shape, buffer_view->shape_rank,
      buffer_view->element_type, start_indices, indices_count, lengths,
      lengths_count, out_start_offset, out_length);
}
// Parses a textual buffer view of the form `<shape>x<type>=<data>` (for
// example `1x2x3xf32=1 2 3 4 5 6`) into a newly allocated buffer view.
//
// Surrounding ASCII whitespace and one pair of enclosing double quotes are
// stripped first. The `=<data>` part is optional and a value without any `x`
// is treated as a scalar of the given type. The backing buffer is allocated
// from |buffer_allocator| as host-local/device-visible memory, the elements
// are parsed directly into a mapping of it, and the view struct itself is
// allocated from |allocator|.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_buffer_view| is NULL or the
// stripped string is empty; other failures are propagated from the parse and
// allocation helpers. |out_buffer_view| is NULL on any failure.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_parse(
    iree_string_view_t value, iree_hal_allocator_t* buffer_allocator,
    iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_view_parse");
  if (!out_buffer_view) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // Clear the output up front so that every early error return below leaves
  // the caller with a well-defined NULL.
  *out_buffer_view = nullptr;

  // Strip whitespace that may come along (linefeeds/etc).
  auto string_view =
      absl::StripAsciiWhitespace(absl::string_view(value.data, value.size));
  string_view = absl::StripPrefix(string_view, "\"");
  string_view = absl::StripSuffix(string_view, "\"");
  if (string_view.empty()) {
    // Empty lines are invalid; need at least the shape/type information.
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // The part of the string corresponding to the shape, e.g. 1x2x3.
  absl::string_view shape_str;
  // The part of the string corresponding to the type, e.g. f32
  absl::string_view type_str;
  // The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6
  absl::string_view data_str;

  absl::string_view shape_and_type_str;
  auto equal_index = string_view.find('=');
  if (equal_index == std::string::npos) {
    // Treat a lack of = as defaulting the data to zeros.
    shape_and_type_str = string_view;
  } else {
    shape_and_type_str = string_view.substr(0, equal_index);
    data_str = string_view.substr(equal_index + 1);
  }
  auto last_x_index = shape_and_type_str.rfind('x');
  if (last_x_index == std::string::npos) {
    // Scalar.
    type_str = shape_and_type_str;
  } else {
    // Has a shape.
    shape_str = shape_and_type_str.substr(0, last_x_index);
    type_str = shape_and_type_str.substr(last_x_index + 1);
  }

  // AxBxC...
  // Try with a small inline capacity first; on OUT_OF_RANGE the parser has
  // reported the required rank so we can resize and parse again.
  absl::InlinedVector<iree_hal_dim_t, 6> shape(6);
  iree_host_size_t shape_rank = 0;
  iree_status_t shape_result =
      iree_hal_parse_shape({shape_str.data(), shape_str.length()}, shape.size(),
                           shape.data(), &shape_rank);
  if (iree_status_is_ok(shape_result)) {
    shape.resize(shape_rank);
  } else if (iree_status_is_out_of_range(shape_result)) {
    shape.resize(shape_rank);
    IREE_RETURN_IF_ERROR(
        iree_hal_parse_shape({shape_str.data(), shape_str.length()},
                             shape.size(), shape.data(), &shape_rank));
  } else {
    return shape_result;
  }

  // f32, i32, etc
  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
  IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(
      {type_str.data(), type_str.length()}, &element_type));

  // Allocate the buffer we will parse into from the provided allocator.
  iree_device_size_t buffer_length = 0;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_size(
      buffer_allocator, shape.data(), shape.size(), element_type,
      &buffer_length));
  iree_hal_buffer_t* buffer = nullptr;
  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
      buffer_allocator,
      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
      buffer_length, &buffer));

  iree_status_t status;

  // Parse the elements directly into the buffer.
  iree_hal_mapped_memory_t mapped_buffer;
  status = iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
                               buffer_length, &mapped_buffer);
  if (!iree_status_is_ok(status)) {
    iree_hal_buffer_release(buffer);
    return status;
  }
  status = iree_hal_parse_buffer_elements({data_str.data(), data_str.length()},
                                          element_type, mapped_buffer.contents);
  iree_hal_buffer_unmap(buffer, &mapped_buffer);
  if (!iree_status_is_ok(status)) {
    iree_hal_buffer_release(buffer);
    return status;
  }

  // Wrap and pass ownership of the buffer to the buffer view.
  status =
      iree_hal_buffer_view_create(buffer, shape.data(), shape.size(),
                                  element_type, allocator, out_buffer_view);
  // The view retained the buffer on success; drop our local reference either
  // way.
  iree_hal_buffer_release(buffer);
  return status;
}
// Formats |buffer_view| as `<shape>x<type>=<elements>` (the shape and its `x`
// separator are omitted for rank-0 views) into |buffer|.
//
// |buffer| may be NULL (or |buffer_capacity| zero) to only query the length.
// |out_buffer_length| is optional; it is reset to 0 on entry and, when all
// formatting steps complete, receives the accumulated length reported by the
// shape/type/element formatters plus the separators - also when the provided
// storage was too small. At most |max_element_count| is forwarded to the
// element formatter.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |buffer_view| is NULL,
// IREE_STATUS_OUT_OF_RANGE if |buffer| was NULL or too small, and otherwise
// propagates failures from the formatters or from mapping the buffer.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_format(
    const iree_hal_buffer_view_t* buffer_view,
    iree_host_size_t max_element_count, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  IREE_TRACE_SCOPE0("iree_hal_buffer_view_format");
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  if (buffer && buffer_capacity) {
    buffer[0] = 0;
  }
  if (!buffer_view) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // A non-NULL buffer with zero capacity cannot even hold the NUL terminator.
  // Treat it as a pure length query so that |buffer_capacity - 1| below cannot
  // wrap around and permit writes past the end of the storage.
  if (buffer_capacity == 0) {
    buffer = nullptr;
  }

  iree_status_t status;
  iree_host_size_t buffer_length = 0;
  // Appends a single character (keeping the string NUL-terminated) if there is
  // room; otherwise switches to length-only mode by dropping |buffer|. The
  // length is always advanced.
  auto append_char = [&](char c) {
    if (buffer) {
      if (buffer_length < buffer_capacity - 1) {
        buffer[buffer_length] = c;
        buffer[buffer_length + 1] = '\0';
      } else {
        buffer = nullptr;
      }
    }
    ++buffer_length;
  };

  if (buffer_view->shape_rank > 0) {
    // Shape: 1x2x3
    iree_host_size_t shape_length = 0;
    status = iree_hal_format_shape(buffer_view->shape, buffer_view->shape_rank,
                                   buffer ? buffer_capacity - buffer_length : 0,
                                   buffer ? buffer + buffer_length : nullptr,
                                   &shape_length);
    buffer_length += shape_length;
    if (iree_status_is_out_of_range(status)) {
      buffer = nullptr;
    } else if (!iree_status_is_ok(status)) {
      return status;
    }

    // Separator: <shape>x<format>
    append_char('x');
  }

  // Element type: f32
  iree_host_size_t element_type_length = 0;
  status = iree_hal_format_element_type(
      buffer_view->element_type, buffer ? buffer_capacity - buffer_length : 0,
      buffer ? buffer + buffer_length : nullptr, &element_type_length);
  buffer_length += element_type_length;
  if (iree_status_is_out_of_range(status)) {
    buffer = nullptr;
  } else if (!iree_status_is_ok(status)) {
    return status;
  }

  // Separator: <meta>=<value>
  append_char('=');

  // Buffer contents: 0 1 2 3 ...
  iree_hal_mapped_memory_t mapped_buffer;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_map(buffer_view->buffer,
                                           IREE_HAL_MEMORY_ACCESS_READ, 0,
                                           IREE_WHOLE_BUFFER, &mapped_buffer));
  iree_host_size_t elements_length = 0;
  status = iree_hal_format_buffer_elements(
      iree_const_byte_span_t{mapped_buffer.contents.data,
                             mapped_buffer.contents.data_length},
      buffer_view->shape, buffer_view->shape_rank, buffer_view->element_type,
      max_element_count, buffer ? buffer_capacity - buffer_length : 0,
      buffer ? buffer + buffer_length : nullptr, &elements_length);
  buffer_length += elements_length;
  // Always unmap before inspecting the status so the mapping is not leaked.
  iree_hal_buffer_unmap(buffer_view->buffer, &mapped_buffer);
  if (iree_status_is_out_of_range(status)) {
    buffer = nullptr;
  } else if (!iree_status_is_ok(status)) {
    return status;
  }

  if (out_buffer_length) {
    *out_buffer_length = buffer_length;
  }
  // |buffer| is NULL here if it was never provided or overflowed at any step.
  return buffer ? IREE_STATUS_OK : IREE_STATUS_OUT_OF_RANGE;
}
//===----------------------------------------------------------------------===//
// iree::hal::CommandBuffer
//===----------------------------------------------------------------------===//
// Defines iree_hal_command_buffer_retain/_release. Both treat the handle as a
// CommandBuffer and call AddReference()/ReleaseReference() on it; NULL handles
// are ignored.
IREE_HAL_API_RETAIN_RELEASE(command_buffer, CommandBuffer);
// Creates a command buffer on |device| with the given |mode| and
// |command_categories| and returns it in |out_command_buffer|.
//
// |allocator| is accepted but not used by this implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_command_buffer| or |device| is
// NULL; failures from Device::CreateCommandBuffer are propagated.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_create(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories, iree_allocator_t allocator,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_create");
  if (out_command_buffer == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_command_buffer = nullptr;

  Device* device_handle = reinterpret_cast<Device*>(device);
  if (device_handle == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  const auto mode_bits = static_cast<CommandBufferModeBitfield>(mode);
  const auto category_bits =
      static_cast<CommandCategoryBitfield>(command_categories);
  IREE_API_ASSIGN_OR_RETURN(
      auto created_command_buffer,
      device_handle->CreateCommandBuffer(mode_bits, category_bits));

  // Ownership of the reference moves to the caller through the C handle.
  *out_command_buffer = reinterpret_cast<iree_hal_command_buffer_t*>(
      created_command_buffer.release());
  return IREE_STATUS_OK;
}
// Calls CommandBuffer::Begin() on |command_buffer| and converts the result to
// an API status. Returns IREE_STATUS_INVALID_ARGUMENT if the handle is NULL.
IREE_API_EXPORT iree_status_t
iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_begin");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  return ToApiStatus(cmd->Begin());
}
// Calls CommandBuffer::End() on |command_buffer| and converts the result to
// an API status. Returns IREE_STATUS_INVALID_ARGUMENT if the handle is NULL.
IREE_API_EXPORT iree_status_t
iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_end");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  return ToApiStatus(cmd->End());
}
// Records an execution barrier into |command_buffer| by forwarding the stage
// masks and the memory/buffer barrier arrays to
// CommandBuffer::ExecutionBarrier.
//
// The C barrier structs are reinterpreted as their C++ counterparts; only
// their sizes are verified at compile time.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| is NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_command_buffer_execution_barrier(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_execution_barrier");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // TODO(benvanik): refactor the C++ types to use the C types for storage so
  // that we can safely map between the two. For now assume size equality
  // is layout equality (as compilers aren't allowed to reorder structs).
  static_assert(sizeof(MemoryBarrier) == sizeof(iree_hal_memory_barrier_t),
                "Expecting identical layout");
  static_assert(sizeof(BufferBarrier) == sizeof(iree_hal_buffer_barrier_t),
                "Expecting identical layout");

  const auto source_stages =
      static_cast<ExecutionStageBitfield>(source_stage_mask);
  const auto target_stages =
      static_cast<ExecutionStageBitfield>(target_stage_mask);
  const auto memory_barrier_span = absl::MakeConstSpan(
      reinterpret_cast<const MemoryBarrier*>(memory_barriers),
      memory_barrier_count);
  const auto buffer_barrier_span = absl::MakeConstSpan(
      reinterpret_cast<const BufferBarrier*>(buffer_barriers),
      buffer_barrier_count);
  return ToApiStatus(cmd->ExecutionBarrier(
      source_stages, target_stages, memory_barrier_span, buffer_barrier_span));
}
// Records a fill of |target_buffer| into |command_buffer| by forwarding all
// arguments to CommandBuffer::FillBuffer.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| is NULL; the other
// arguments are not validated here.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_fill_buffer(
    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length,
    const void* pattern, iree_host_size_t pattern_length) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_fill_buffer");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Buffer* target = reinterpret_cast<Buffer*>(target_buffer);
  return ToApiStatus(
      cmd->FillBuffer(target, target_offset, length, pattern, pattern_length));
}
// Records an update of |target_buffer| from host memory into |command_buffer|
// by forwarding all arguments to CommandBuffer::UpdateBuffer.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| is NULL; the other
// arguments are not validated here.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_command_buffer_update_buffer(iree_hal_command_buffer_t* command_buffer,
                                      const void* source_buffer,
                                      iree_host_size_t source_offset,
                                      iree_hal_buffer_t* target_buffer,
                                      iree_device_size_t target_offset,
                                      iree_device_size_t length) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_update_buffer");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Buffer* target = reinterpret_cast<Buffer*>(target_buffer);
  return ToApiStatus(cmd->UpdateBuffer(source_buffer, source_offset, target,
                                       target_offset, length));
}
// Records a buffer-to-buffer copy into |command_buffer| by forwarding all
// arguments to CommandBuffer::CopyBuffer.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| is NULL; the other
// arguments are not validated here.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_copy_buffer(
    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer,
    iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_copy_buffer");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Buffer* source = reinterpret_cast<Buffer*>(source_buffer);
  Buffer* target = reinterpret_cast<Buffer*>(target_buffer);
  return ToApiStatus(
      cmd->CopyBuffer(source, source_offset, target, target_offset, length));
}
// Records a push of 32-bit constants into |command_buffer| by forwarding
// |values| (viewed as uint32_t words) to CommandBuffer::PushConstants.
//
// |values_length| is in bytes and must be a multiple of 4. A zero length is a
// successful no-op, but only after the NULL checks have passed.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer|,
// |executable_layout| or |values| is NULL, or if |values_length| is not a
// multiple of 4.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_command_buffer_push_constants(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
    const void* values, iree_host_size_t values_length) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_push_constants");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr || executable_layout == nullptr || values == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (values_length == 0) {
    return IREE_STATUS_OK;
  }
  if ((values_length % 4) != 0) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  auto* layout = reinterpret_cast<ExecutableLayout*>(executable_layout);
  const auto value_words =
      absl::MakeConstSpan(reinterpret_cast<const uint32_t*>(values),
                          values_length / sizeof(uint32_t));
  return ToApiStatus(cmd->PushConstants(layout, offset, value_words));
}
// Records a push of descriptor |bindings| for |set| into |command_buffer| by
// forwarding them to CommandBuffer::PushDescriptorSet.
//
// The C binding structs are reinterpreted as DescriptorSet::Binding; only the
// sizes are verified at compile time.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| or
// |executable_layout| is NULL, or if |bindings| is NULL while |binding_count|
// is non-zero.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_command_buffer_push_descriptor_set(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_layout_t* executable_layout, int32_t set,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_binding_t* bindings) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_push_descriptor_set");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (executable_layout == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (binding_count != 0 && bindings == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // TODO(benvanik): refactor the C++ types to use the C types for storage so
  // that we can safely map between the two. For now assume size equality
  // is layout equality (as compilers aren't allowed to reorder structs).
  static_assert(sizeof(DescriptorSet::Binding) ==
                    sizeof(iree_hal_descriptor_set_binding_t),
                "Expecting identical layout");

  auto* layout = reinterpret_cast<ExecutableLayout*>(executable_layout);
  const auto binding_span = absl::MakeConstSpan(
      reinterpret_cast<const DescriptorSet::Binding*>(bindings),
      binding_count);
  return ToApiStatus(cmd->PushDescriptorSet(layout, set, binding_span));
}
// Records a bind of |descriptor_set| to |set| into |command_buffer| by
// forwarding the arguments to CommandBuffer::BindDescriptorSet.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| or
// |executable_layout| is NULL, or if |dynamic_offsets| is NULL while
// |dynamic_offset_count| is non-zero. |descriptor_set| is not validated here.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_command_buffer_bind_descriptor_set(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_layout_t* executable_layout, int32_t set,
    iree_hal_descriptor_set_t* descriptor_set,
    iree_host_size_t dynamic_offset_count,
    const iree_device_size_t* dynamic_offsets) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_bind_descriptor_set");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (executable_layout == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (dynamic_offset_count != 0 && dynamic_offsets == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // The offsets are handed to the C++ API without conversion, so the C and
  // C++ device size types must have the same width.
  static_assert(sizeof(iree_device_size_t) == sizeof(device_size_t),
                "Device sizes must match");

  auto* layout = reinterpret_cast<ExecutableLayout*>(executable_layout);
  auto* bound_set = reinterpret_cast<DescriptorSet*>(descriptor_set);
  const auto offset_span =
      absl::MakeConstSpan(dynamic_offsets, dynamic_offset_count);
  return ToApiStatus(
      cmd->BindDescriptorSet(layout, set, bound_set, offset_span));
}
// Records a dispatch of |entry_point| in |executable| into |command_buffer|
// with the given workgroup counts, forwarding to CommandBuffer::Dispatch.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer| or |executable| is
// NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_dispatch(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_dispatch");
  CommandBuffer* cmd = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (cmd == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (executable == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Executable* target_executable = reinterpret_cast<Executable*>(executable);
  return ToApiStatus(cmd->Dispatch(target_executable, entry_point,
                                   {workgroup_x, workgroup_y, workgroup_z}));
}
// Records an indirect dispatch of |entry_point| in |executable| into
// |command_buffer|, forwarding |workgroups_buffer| and |workgroups_offset| to
// CommandBuffer::DispatchIndirect.
// NOTE(review): presumably the workgroup counts are read from
// |workgroups_buffer| at |workgroups_offset| at execution time - confirm
// against CommandBuffer::DispatchIndirect.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |command_buffer|, |executable| or
// |workgroups_buffer| is NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_command_buffer_dispatch_indirect(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    iree_hal_buffer_t* workgroups_buffer,
    iree_device_size_t workgroups_offset) {
  IREE_TRACE_SCOPE0("iree_hal_command_buffer_dispatch_indirect");
  auto* handle = reinterpret_cast<CommandBuffer*>(command_buffer);
  if (!handle) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // Both the executable and the workgroups buffer are required. (The check
  // used to reject every call that *did* provide a workgroups buffer.)
  if (!executable || !workgroups_buffer) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  return ToApiStatus(handle->DispatchIndirect(
      reinterpret_cast<Executable*>(executable), entry_point,
      reinterpret_cast<Buffer*>(workgroups_buffer), workgroups_offset));
}
//===----------------------------------------------------------------------===//
// iree::hal::DescriptorSet
//===----------------------------------------------------------------------===//
// Defines iree_hal_descriptor_set_retain/_release. Both treat the handle as a
// DescriptorSet and call AddReference()/ReleaseReference() on it; NULL handles
// are ignored.
IREE_HAL_API_RETAIN_RELEASE(descriptor_set, DescriptorSet);
// Creates a descriptor set on |device| for |set_layout| with the given
// |bindings| and returns it in |out_descriptor_set|.
//
// The C binding structs are reinterpreted as DescriptorSet::Binding; only the
// sizes are verified at compile time. |allocator| is accepted but not used by
// this implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_descriptor_set|, |set_layout|
// or |device| is NULL, or if |bindings| is NULL while |binding_count| is
// non-zero; failures from Device::CreateDescriptorSet are propagated.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_descriptor_set_create(
    iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_binding_t* bindings,
    iree_allocator_t allocator,
    iree_hal_descriptor_set_t** out_descriptor_set) {
  IREE_TRACE_SCOPE0("iree_hal_descriptor_set_create");
  if (out_descriptor_set == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_descriptor_set = nullptr;
  if (set_layout == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (binding_count != 0 && bindings == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Device* device_handle = reinterpret_cast<Device*>(device);
  if (device_handle == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // TODO(benvanik): refactor the C++ types to use the C types for storage so
  // that we can safely map between the two. For now assume size equality
  // is layout equality (as compilers aren't allowed to reorder structs).
  static_assert(sizeof(DescriptorSet::Binding) ==
                    sizeof(iree_hal_descriptor_set_binding_t),
                "Expecting identical layout");

  auto* layout = reinterpret_cast<DescriptorSetLayout*>(set_layout);
  const auto binding_span = absl::MakeConstSpan(
      reinterpret_cast<const DescriptorSet::Binding*>(bindings),
      binding_count);
  IREE_API_ASSIGN_OR_RETURN(
      auto created_set,
      device_handle->CreateDescriptorSet(layout, binding_span));

  // Ownership of the reference moves to the caller through the C handle.
  *out_descriptor_set =
      reinterpret_cast<iree_hal_descriptor_set_t*>(created_set.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::DescriptorSetLayout
//===----------------------------------------------------------------------===//
// Defines iree_hal_descriptor_set_layout_retain/_release. Both treat the
// handle as a DescriptorSetLayout and call AddReference()/ReleaseReference()
// on it; NULL handles are ignored.
IREE_HAL_API_RETAIN_RELEASE(descriptor_set_layout, DescriptorSetLayout);
// Creates a descriptor set layout on |device| with the given |usage_type| and
// |bindings| and returns it in |out_descriptor_set_layout|.
//
// The C binding structs are reinterpreted as DescriptorSetLayout::Binding;
// only the sizes are verified at compile time. |allocator| is accepted but
// not used by this implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_descriptor_set_layout| or
// |device| is NULL, or if |bindings| is NULL while |binding_count| is
// non-zero; failures from Device::CreateDescriptorSetLayout are propagated.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_descriptor_set_layout_create(
    iree_hal_device_t* device,
    iree_hal_descriptor_set_layout_usage_type_t usage_type,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_layout_binding_t* bindings,
    iree_allocator_t allocator,
    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
  IREE_TRACE_SCOPE0("iree_hal_descriptor_set_layout_create");
  if (out_descriptor_set_layout == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_descriptor_set_layout = nullptr;
  if (binding_count != 0 && bindings == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Device* device_handle = reinterpret_cast<Device*>(device);
  if (device_handle == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // TODO(benvanik): refactor the C++ types to use the C types for storage so
  // that we can safely map between the two. For now assume size equality
  // is layout equality (as compilers aren't allowed to reorder structs).
  static_assert(sizeof(DescriptorSetLayout::Binding) ==
                    sizeof(iree_hal_descriptor_set_layout_binding_t),
                "Expecting identical layout");

  const auto layout_usage =
      static_cast<DescriptorSetLayout::UsageType>(usage_type);
  const auto binding_span = absl::MakeConstSpan(
      reinterpret_cast<const DescriptorSetLayout::Binding*>(bindings),
      binding_count);
  IREE_API_ASSIGN_OR_RETURN(
      auto created_layout,
      device_handle->CreateDescriptorSetLayout(layout_usage, binding_span));

  // Ownership of the reference moves to the caller through the C handle.
  *out_descriptor_set_layout =
      reinterpret_cast<iree_hal_descriptor_set_layout_t*>(
          created_layout.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::Device
//===----------------------------------------------------------------------===//
// Defines iree_hal_device_retain/_release. Both treat the handle as a Device
// and call AddReference()/ReleaseReference() on it; NULL handles are ignored.
IREE_HAL_API_RETAIN_RELEASE(device, Device);
// Returns the allocator of |device| as a C handle, or NULL if |device| is
// NULL. No additional reference is taken here.
IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL
iree_hal_device_allocator(iree_hal_device_t* device) {
  Device* device_handle = reinterpret_cast<Device*>(device);
  if (device_handle == nullptr) {
    return nullptr;
  }
  return reinterpret_cast<iree_hal_allocator_t*>(device_handle->allocator());
}
// Returns a view of the id string from |device|'s info, or an empty view if
// |device| is NULL. The view points at storage reached through the device and
// is not copied.
IREE_API_EXPORT iree_string_view_t IREE_API_CALL
iree_hal_device_id(iree_hal_device_t* device) {
  Device* device_handle = reinterpret_cast<Device*>(device);
  if (device_handle == nullptr) {
    return IREE_STRING_VIEW_EMPTY;
  }
  const auto& device_id = device_handle->info().id();
  return iree_string_view_t{device_id.data(), device_id.size()};
}
// Submits |batch_count| |batches| of command buffers to one of |device|'s
// dispatch queues.
//
// The C semaphore lists of every batch are marshaled into a single contiguous
// array of C++ SemaphoreValue entries (wait values followed by signal values,
// batch after batch) that the C++ SubmissionBatch spans point into. The queue
// is selected as |queue_affinity| modulo the number of dispatch queues;
// |command_categories| is not consulted by this implementation.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |device| is NULL or |batches| is
// NULL while |batch_count| is non-zero; otherwise returns the converted
// result of the queue's Submit().
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_queue_submit(
    iree_hal_device_t* device, iree_hal_command_category_t command_categories,
    uint64_t queue_affinity, iree_host_size_t batch_count,
    const iree_hal_submission_batch_t* batches) {
  IREE_TRACE_SCOPE0("iree_hal_device_queue_submit");
  auto* handle = reinterpret_cast<Device*>(device);
  if (!handle) return IREE_STATUS_INVALID_ARGUMENT;
  if (batch_count > 0 && !batches) return IREE_STATUS_INVALID_ARGUMENT;

  // We need to allocate storage to marshal in the semaphores. Ideally we'd
  // change the C++ API to make this 1:1 with a reinterpret_cast, however that
  // makes the C API more difficult. Bleh.
  iree_host_size_t total_semaphore_count = 0;
  for (iree_host_size_t i = 0; i < batch_count; ++i) {
    total_semaphore_count += batches[i].wait_semaphores.count;
    total_semaphore_count += batches[i].signal_semaphores.count;
  }
  absl::InlinedVector<SemaphoreValue, 4> semaphore_values(
      total_semaphore_count);
  absl::InlinedVector<SubmissionBatch, 2> dst_batches(batch_count);

  // Walk the storage with a raw pointer rather than operator[]: when a batch
  // has no semaphores the cursor may legitimately sit at one-past-the-end,
  // which must not be indexed.
  SemaphoreValue* next_value = semaphore_values.data();
  for (iree_host_size_t i = 0; i < batch_count; ++i) {
    const auto& src_batch = batches[i];
    auto& dst_batch = dst_batches[i];

    SemaphoreValue* wait_values = next_value;
    for (iree_host_size_t j = 0; j < src_batch.wait_semaphores.count; ++j) {
      wait_values[j] = {
          reinterpret_cast<Semaphore*>(src_batch.wait_semaphores.semaphores[j]),
          src_batch.wait_semaphores.payload_values[j]};
    }
    dst_batch.wait_semaphores =
        absl::MakeConstSpan(wait_values, src_batch.wait_semaphores.count);
    next_value += src_batch.wait_semaphores.count;

    dst_batch.command_buffers =
        iree::ReinterpretSpan<CommandBuffer*>(absl::MakeConstSpan(
            src_batch.command_buffers, src_batch.command_buffer_count));

    SemaphoreValue* signal_values = next_value;
    for (iree_host_size_t j = 0; j < src_batch.signal_semaphores.count; ++j) {
      signal_values[j] = {
          reinterpret_cast<Semaphore*>(
              src_batch.signal_semaphores.semaphores[j]),
          src_batch.signal_semaphores.payload_values[j]};
    }
    dst_batch.signal_semaphores =
        absl::MakeConstSpan(signal_values, src_batch.signal_semaphores.count);
    next_value += src_batch.signal_semaphores.count;
  }

  // For now we always go to the first compute queue. TBD cleanup pending the
  // device modeling in the IR as to how we really want to handle this. We'll
  // want to use queue_affinity in a way that ensures we have some control over
  // things on the compiler side and may require that devices are declared by
  // the number and types of queues they support.
  // NOTE(review): this divides by the number of dispatch queues and so assumes
  // the device exposes at least one - confirm that is guaranteed by Device.
  uint64_t queue_index = queue_affinity % handle->dispatch_queues().size();
  auto* command_queue = handle->dispatch_queues()[queue_index];
  return ToApiStatus(command_queue->Submit(dst_batches));
}
// Blocks until the semaphores in |semaphore_list| reach their payload values
// (all of them or any one of them, depending on |wait_mode|) or until
// |deadline_ns| is reached.
//
// A NULL or empty |semaphore_list| returns IREE_STATUS_OK immediately.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |device| is NULL or |wait_mode| is
// not a known mode, IREE_STATUS_DEADLINE_EXCEEDED if the wait timed out, and
// otherwise the converted status of the underlying wait.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_device_wait_semaphores_with_deadline(
    iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
    const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) {
  IREE_TRACE_SCOPE0("iree_hal_device_wait_semaphores_with_deadline");
  auto* handle = reinterpret_cast<Device*>(device);
  if (!handle) return IREE_STATUS_INVALID_ARGUMENT;
  if (!semaphore_list || semaphore_list->count == 0) return IREE_STATUS_OK;

  // Marshal the C list into the C++ semaphore/value pairs.
  absl::InlinedVector<SemaphoreValue, 4> semaphore_values(
      semaphore_list->count);
  for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) {
    semaphore_values[i] = {
        reinterpret_cast<Semaphore*>(semaphore_list->semaphores[i]),
        semaphore_list->payload_values[i]};
  }

  Status wait_status;
  switch (wait_mode) {
    case IREE_HAL_WAIT_MODE_ALL:
      wait_status =
          handle->WaitAllSemaphores(semaphore_values, ToAbslTime(deadline_ns));
      break;
    case IREE_HAL_WAIT_MODE_ANY:
      // Only the status of the result is reported through this API.
      wait_status = std::move(handle->WaitAnySemaphore(semaphore_values,
                                                       ToAbslTime(deadline_ns)))
                        .status();
      break;
    default:
      return IREE_STATUS_INVALID_ARGUMENT;
  }

  // NOTE: we avoid capturing stack traces on deadline exceeded as it's not a
  // real error.
  if (IsDeadlineExceeded(wait_status)) {
    return iree_make_status(IREE_STATUS_DEADLINE_EXCEEDED);
  }
  return ToApiStatus(wait_status);
}
// Same as iree_hal_device_wait_semaphores_with_deadline but takes a relative
// |timeout_ns|, which is converted to a deadline with
// iree::RelativeTimeoutToDeadline before forwarding.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_device_wait_semaphores_with_timeout(
    iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
    const iree_hal_semaphore_list_t* semaphore_list,
    iree_duration_t timeout_ns) {
  const auto timeout = ToAbslDuration(timeout_ns);
  const auto deadline = iree::RelativeTimeoutToDeadline(timeout);
  const iree_time_t deadline_ns = FromAbslTime(deadline);
  return iree_hal_device_wait_semaphores_with_deadline(
      device, wait_mode, semaphore_list, deadline_ns);
}
//===----------------------------------------------------------------------===//
// iree::hal::Driver
//===----------------------------------------------------------------------===//
// Defines iree_hal_driver_retain/_release. Both treat the handle as a Driver
// and call AddReference()/ReleaseReference() on it; NULL handles are ignored.
IREE_HAL_API_RETAIN_RELEASE(driver, Driver);
// Enumerates the devices available from |driver|.
//
// On success |out_device_infos| points at a single allocation made from
// |allocator| that holds |out_device_info_count| iree_hal_device_info_t
// entries followed by the bytes of all device names; each entry's |name| view
// points into that trailing string storage (names are not NUL-terminated).
// NOTE(review): presumably the caller frees the block with |allocator| -
// confirm against the API header.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_device_info_count|,
// |out_device_infos| or |driver| is NULL; enumeration and allocation failures
// are propagated. On any failure the count is 0 and the infos pointer is NULL.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_driver_query_available_devices(
    iree_hal_driver_t* driver, iree_allocator_t allocator,
    iree_hal_device_info_t** out_device_infos,
    iree_host_size_t* out_device_info_count) {
  IREE_TRACE_SCOPE0("iree_hal_driver_query_available_devices");
  if (!out_device_info_count) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_device_info_count = 0;
  if (!out_device_infos) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_device_infos = nullptr;
  auto* handle = reinterpret_cast<Driver*>(driver);
  if (!handle) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  IREE_API_ASSIGN_OR_RETURN(auto device_infos,
                            handle->EnumerateAvailableDevices());

  // Size the single allocation: the info array plus all name bytes.
  size_t total_string_size = 0;
  for (const auto& device_info : device_infos) {
    total_string_size += device_info.name().size();
  }
  iree_hal_device_info_t* device_info_storage = nullptr;
  const size_t info_array_size =
      device_infos.size() * sizeof(*device_info_storage);
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      allocator, info_array_size + total_string_size,
      reinterpret_cast<void**>(&device_info_storage)));

  // Name bytes are packed back-to-back right after the info array.
  char* p = reinterpret_cast<char*>(device_info_storage) + info_array_size;
  for (size_t i = 0; i < device_infos.size(); ++i) {
    const auto& device_info = device_infos[i];
    device_info_storage[i].device_id = device_info.device_id();
    size_t name_size = device_info.name().size();
    std::memcpy(p, device_info.name().c_str(), name_size);
    device_info_storage[i].name = iree_string_view_t{p, name_size};
    p += name_size;
  }

  // Publish both outputs together only once everything has succeeded so that
  // a failed allocation never reports a non-zero count without storage.
  *out_device_infos = device_info_storage;
  *out_device_info_count = device_infos.size();
  return IREE_STATUS_OK;
}
// Creates the device identified by |device_id| using |driver|.
// |out_device| is cleared to nullptr up front and, on success, receives the
// pointer released from the handle returned by Driver::CreateDevice.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_device| or |driver| is null and
// otherwise propagates any error from Driver::CreateDevice.
// NOTE: |allocator| is not referenced by this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_create_device(
    iree_hal_driver_t* driver, iree_hal_device_id_t device_id,
    iree_allocator_t allocator, iree_hal_device_t** out_device) {
  IREE_TRACE_SCOPE0("iree_hal_driver_create_device");
  if (out_device == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_device = nullptr;
  if (driver == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Driver* driver_impl = reinterpret_cast<Driver*>(driver);
  IREE_API_ASSIGN_OR_RETURN(auto created_device,
                            driver_impl->CreateDevice(device_id));
  *out_device = reinterpret_cast<iree_hal_device_t*>(created_device.release());
  return IREE_STATUS_OK;
}
// Creates the default device of |driver| via Driver::CreateDefaultDevice.
// |out_device| is cleared to nullptr up front and, on success, receives the
// pointer released from the returned device handle.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_device| or |driver| is null and
// otherwise propagates any error from Driver::CreateDefaultDevice.
// NOTE: |allocator| is not referenced by this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_driver_create_default_device(iree_hal_driver_t* driver,
                                      iree_allocator_t allocator,
                                      iree_hal_device_t** out_device) {
  IREE_TRACE_SCOPE0("iree_hal_driver_create_default_device");
  if (out_device == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_device = nullptr;
  if (driver == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Driver* driver_impl = reinterpret_cast<Driver*>(driver);
  IREE_API_ASSIGN_OR_RETURN(auto default_device,
                            driver_impl->CreateDefaultDevice());
  *out_device = reinterpret_cast<iree_hal_device_t*>(default_device.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::DriverRegistry
//===----------------------------------------------------------------------===//
// Returns the result of DriverRegistry::HasDriver on the shared registry for
// the name described by |driver_name| (data pointer + size).
IREE_API_EXPORT bool IREE_API_CALL
iree_hal_driver_registry_has_driver(iree_string_view_t driver_name) {
  auto* registry = DriverRegistry::shared_registry();
  const absl::string_view name_view(driver_name.data, driver_name.size);
  return registry->HasDriver(name_view);
}
// Enumerates the names of the drivers available in the shared DriverRegistry.
//
// On success |out_driver_names| receives one block allocated from |allocator|
// that contains |out_driver_count| iree_string_view_t entries followed by the
// bytes of every driver name. Each string view points into that trailing
// region; the names are not NUL-terminated.
//
// On any failure |out_driver_count| is 0 and |out_driver_names| is nullptr
// (when those pointers are themselves valid).
//
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_driver_count| or
// |out_driver_names| is null; otherwise propagates errors from the allocator.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_driver_registry_query_available_drivers(
    iree_allocator_t allocator, iree_string_view_t** out_driver_names,
    iree_host_size_t* out_driver_count) {
  IREE_TRACE_SCOPE0("iree_hal_driver_registry_query_available_drivers");
  if (!out_driver_count) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_driver_count = 0;
  if (!out_driver_names) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  // Reset the output so that callers never observe a stale pointer on failure.
  *out_driver_names = nullptr;

  auto* registry = DriverRegistry::shared_registry();
  auto available_drivers = registry->EnumerateAvailableDrivers();
  const size_t driver_count = available_drivers.size();

  // Names are packed back-to-back after the string view array in the same
  // allocation.
  size_t total_string_size = 0;
  for (const auto& driver_name : available_drivers) {
    total_string_size += driver_name.size();
  }

  const size_t view_array_size = driver_count * sizeof(iree_string_view_t);
  iree_string_view_t* driver_name_storage = nullptr;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      allocator, view_array_size + total_string_size,
      reinterpret_cast<void**>(&driver_name_storage)));

  char* p = reinterpret_cast<char*>(driver_name_storage) + view_array_size;
  for (size_t i = 0; i < driver_count; ++i) {
    const auto& driver_name = available_drivers[i];
    const size_t name_size = driver_name.size();
    std::memcpy(p, driver_name.c_str(), name_size);
    driver_name_storage[i] = iree_string_view_t{p, name_size};
    p += name_size;
  }

  // Publish both outputs together, and only once everything has succeeded, so
  // that the count can never describe storage that was not returned.
  *out_driver_names = driver_name_storage;
  *out_driver_count = driver_count;
  return IREE_STATUS_OK;
}
// Creates the driver named |driver_name| through the shared DriverRegistry.
// |out_driver| is cleared to nullptr up front and, on success, receives the
// pointer released from the handle returned by DriverRegistry::Create.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_driver| is null and otherwise
// propagates any error from DriverRegistry::Create.
// NOTE: |allocator| is not referenced by this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_driver_registry_create_driver(iree_string_view_t driver_name,
                                       iree_allocator_t allocator,
                                       iree_hal_driver_t** out_driver) {
  IREE_TRACE_SCOPE0("iree_hal_driver_registry_create_driver");
  if (out_driver == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_driver = nullptr;

  auto* shared_registry = DriverRegistry::shared_registry();
  const absl::string_view name_view(driver_name.data, driver_name.size);
  IREE_API_ASSIGN_OR_RETURN(auto created_driver,
                            shared_registry->Create(name_view));
  *out_driver = reinterpret_cast<iree_hal_driver_t*>(created_driver.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::Executable
//===----------------------------------------------------------------------===//
// Defines iree_hal_executable_retain and iree_hal_executable_release, which
// forward to Executable::AddReference and Executable::ReleaseReference
// respectively; a null handle is ignored by both.
IREE_HAL_API_RETAIN_RELEASE(executable, Executable);
//===----------------------------------------------------------------------===//
// iree::hal::ExecutableCache
//===----------------------------------------------------------------------===//
// Defines iree_hal_executable_cache_retain and
// iree_hal_executable_cache_release, which forward to
// ExecutableCache::AddReference and ExecutableCache::ReleaseReference
// respectively; a null handle is ignored by both.
IREE_HAL_API_RETAIN_RELEASE(executable_cache, ExecutableCache);
// Creates an executable cache on |device| via Device::CreateExecutableCache.
// |out_executable_cache| is cleared to nullptr up front and, on success,
// receives the pointer released from the returned cache handle.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_executable_cache| or |device|
// is null.
// NOTE: |identifier| and |allocator| are not referenced by this
// implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_cache_create(
    iree_hal_device_t* device, iree_string_view_t identifier,
    iree_allocator_t allocator,
    iree_hal_executable_cache_t** out_executable_cache) {
  IREE_TRACE_SCOPE0("iree_hal_executable_cache_create");
  if (out_executable_cache == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_executable_cache = nullptr;
  if (device == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Device* device_impl = reinterpret_cast<Device*>(device);
  auto new_cache = device_impl->CreateExecutableCache();
  *out_executable_cache =
      reinterpret_cast<iree_hal_executable_cache_t*>(new_cache.release());
  return IREE_STATUS_OK;
}
// Returns ExecutableCache::CanPrepareFormat(|format|) for |executable_cache|,
// or false when |executable_cache| is null.
IREE_API_EXPORT bool IREE_API_CALL iree_hal_executable_cache_can_prepare_format(
    iree_hal_executable_cache_t* executable_cache,
    iree_hal_executable_format_t format) {
  if (executable_cache == nullptr) return false;
  ExecutableCache* cache_impl =
      reinterpret_cast<ExecutableCache*>(executable_cache);
  return cache_impl->CanPrepareFormat(format);
}
// Prepares an executable from |executable_data| through |executable_cache|.
// The data span is wrapped in an ExecutableSpec and handed to
// ExecutableCache::PrepareExecutable together with |executable_layout| and
// |caching_mode|.
// |out_executable| is cleared to nullptr up front and, on success, receives
// the pointer released from the returned executable handle.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_executable|,
// |executable_cache| or |executable_layout| is null and otherwise propagates
// any error from PrepareExecutable.
// NOTE: |allocator| is not referenced by this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_executable_cache_prepare_executable(
    iree_hal_executable_cache_t* executable_cache,
    iree_hal_executable_layout_t* executable_layout,
    iree_hal_executable_caching_mode_t caching_mode,
    iree_const_byte_span_t executable_data, iree_allocator_t allocator,
    iree_hal_executable_t** out_executable) {
  IREE_TRACE_SCOPE0("iree_hal_executable_cache_prepare_executable");
  if (out_executable == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_executable = nullptr;
  if (executable_cache == nullptr || executable_layout == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  ExecutableCache* cache_impl =
      reinterpret_cast<ExecutableCache*>(executable_cache);
  ExecutableLayout* layout_impl =
      reinterpret_cast<ExecutableLayout*>(executable_layout);

  ExecutableSpec spec;
  spec.executable_data = {executable_data.data, executable_data.data_length};

  IREE_API_ASSIGN_OR_RETURN(
      auto prepared_executable,
      cache_impl->PrepareExecutable(
          layout_impl, static_cast<ExecutableCachingMode>(caching_mode), spec));
  *out_executable =
      reinterpret_cast<iree_hal_executable_t*>(prepared_executable.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::ExecutableLayout
//===----------------------------------------------------------------------===//
// Defines iree_hal_executable_layout_retain and
// iree_hal_executable_layout_release, which forward to
// ExecutableLayout::AddReference and ExecutableLayout::ReleaseReference
// respectively; a null handle is ignored by both.
IREE_HAL_API_RETAIN_RELEASE(executable_layout, ExecutableLayout);
// Creates an executable layout on |device| from |set_layout_count| descriptor
// set layouts and |push_constants|, via Device::CreateExecutableLayout.
// |out_executable_layout| is cleared to nullptr up front and, on success,
// receives the pointer released from the returned layout handle.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_executable_layout| or |device|
// is null, or if |set_layout_count| is non-zero while |set_layouts| is null;
// otherwise propagates any error from CreateExecutableLayout.
// NOTE: |allocator| is not referenced by this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_layout_create(
    iree_hal_device_t* device, iree_host_size_t set_layout_count,
    iree_hal_descriptor_set_layout_t** set_layouts,
    iree_host_size_t push_constants, iree_allocator_t allocator,
    iree_hal_executable_layout_t** out_executable_layout) {
  IREE_TRACE_SCOPE0("iree_hal_executable_layout_create");
  if (out_executable_layout == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_executable_layout = nullptr;
  if (set_layout_count != 0 && set_layouts == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  if (device == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Device* device_impl = reinterpret_cast<Device*>(device);
  // View the C array of opaque handles as a span of C++ layout pointers.
  auto set_layout_span = absl::MakeConstSpan(
      reinterpret_cast<DescriptorSetLayout* const*>(set_layouts),
      set_layout_count);
  IREE_API_ASSIGN_OR_RETURN(
      auto new_layout,
      device_impl->CreateExecutableLayout(set_layout_span, push_constants));
  *out_executable_layout =
      reinterpret_cast<iree_hal_executable_layout_t*>(new_layout.release());
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::hal::Semaphore
//===----------------------------------------------------------------------===//
// Defines iree_hal_semaphore_retain and iree_hal_semaphore_release, which
// forward to Semaphore::AddReference and Semaphore::ReleaseReference
// respectively; a null handle is ignored by both.
IREE_HAL_API_RETAIN_RELEASE(semaphore, Semaphore);
// Creates a semaphore on |device| with |initial_value| via
// Device::CreateSemaphore.
// |out_semaphore| is cleared to nullptr up front and, on success, receives the
// pointer released from the returned semaphore handle.
// Returns IREE_STATUS_INVALID_ARGUMENT if |out_semaphore| or |device| is null
// and otherwise propagates any error from Device::CreateSemaphore.
// NOTE: |allocator| is not referenced by this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_semaphore_create(
    iree_hal_device_t* device, uint64_t initial_value,
    iree_allocator_t allocator, iree_hal_semaphore_t** out_semaphore) {
  IREE_TRACE_SCOPE0("iree_hal_semaphore_create");
  if (out_semaphore == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_semaphore = nullptr;
  if (device == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  Device* device_impl = reinterpret_cast<Device*>(device);
  IREE_API_ASSIGN_OR_RETURN(auto new_semaphore,
                            device_impl->CreateSemaphore(initial_value));
  *out_semaphore =
      reinterpret_cast<iree_hal_semaphore_t*>(new_semaphore.release());
  return IREE_STATUS_OK;
}
// Queries |semaphore| via Semaphore::Query and stores the value in
// |out_value|.
// |out_value| is zeroed up front and is only overwritten when the query
// succeeds. Returns IREE_STATUS_INVALID_ARGUMENT if |out_value| or
// |semaphore| is null; a failed query is converted with ToApiStatus.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_semaphore_query(iree_hal_semaphore_t* semaphore, uint64_t* out_value) {
  if (out_value == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_value = 0;
  if (semaphore == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Semaphore* semaphore_impl = reinterpret_cast<Semaphore*>(semaphore);
  auto query_result = semaphore_impl->Query();
  if (query_result.ok()) {
    *out_value = query_result.value();
    return IREE_STATUS_OK;
  }
  return ToApiStatus(std::move(query_result).status());
}
// Signals |semaphore| to |new_value| via Semaphore::Signal and returns the
// result converted with ToApiStatus.
// Returns IREE_STATUS_INVALID_ARGUMENT if |semaphore| is null.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_semaphore_signal(iree_hal_semaphore_t* semaphore, uint64_t new_value) {
  IREE_TRACE_SCOPE0("iree_hal_semaphore_signal");
  if (semaphore == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Semaphore* semaphore_impl = reinterpret_cast<Semaphore*>(semaphore);
  return ToApiStatus(semaphore_impl->Signal(new_value));
}
// Marks |semaphore| as failed by passing |status| (converted with
// FromApiStatus) to Semaphore::Fail. Does nothing when |semaphore| is null.
IREE_API_EXPORT void IREE_API_CALL
iree_hal_semaphore_fail(iree_hal_semaphore_t* semaphore, iree_status_t status) {
  IREE_TRACE_SCOPE0("iree_hal_semaphore_fail");
  Semaphore* semaphore_impl = reinterpret_cast<Semaphore*>(semaphore);
  if (semaphore_impl != nullptr) {
    semaphore_impl->Fail(FromApiStatus(status, IREE_LOC));
  }
}
// Waits via Semaphore::Wait on |semaphore| for |value| with |deadline_ns|
// converted through ToAbslTime; the result is converted with ToApiStatus.
// Returns IREE_STATUS_INVALID_ARGUMENT if |semaphore| is null.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_semaphore_wait_with_deadline(iree_hal_semaphore_t* semaphore,
                                      uint64_t value, iree_time_t deadline_ns) {
  IREE_TRACE_SCOPE0("iree_hal_semaphore_wait_with_deadline");
  if (semaphore == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Semaphore* semaphore_impl = reinterpret_cast<Semaphore*>(semaphore);
  const auto deadline = ToAbslTime(deadline_ns);
  return ToApiStatus(semaphore_impl->Wait(value, deadline));
}
// Waits via Semaphore::Wait on |semaphore| for |value| with |timeout_ns|
// converted through ToAbslDuration; the result is converted with ToApiStatus.
// Returns IREE_STATUS_INVALID_ARGUMENT if |semaphore| is null.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_semaphore_wait_with_timeout(iree_hal_semaphore_t* semaphore,
                                     uint64_t value,
                                     iree_duration_t timeout_ns) {
  IREE_TRACE_SCOPE0("iree_hal_semaphore_wait_with_timeout");
  if (semaphore == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Semaphore* semaphore_impl = reinterpret_cast<Semaphore*>(semaphore);
  const auto timeout = ToAbslDuration(timeout_ns);
  return ToApiStatus(semaphore_impl->Wait(value, timeout));
}
} // namespace hal
} // namespace iree