| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/tooling/vm_util.h" |
| |
| #include <errno.h> |
| #include <stdint.h> |
| #include <stdio.h> |
| |
| #include "iree/base/api.h" |
| #include "iree/base/internal/file_io.h" |
| #include "iree/hal/api.h" |
| #include "iree/modules/hal/module.h" |
| #include "iree/tooling/numpy_io.h" |
| |
| static iree_status_t iree_allocate_and_copy_cstring_from_view( |
| iree_allocator_t allocator, iree_string_view_t view, char** cstring) { |
| IREE_RETURN_IF_ERROR( |
| iree_allocator_malloc(allocator, view.size + 1, (void**)cstring)); |
| memcpy(*cstring, view.data, view.size); |
| (*cstring)[view.size] = 0; |
| return iree_ok_status(); |
| } |
| |
| static iree_status_t iree_tooling_load_ndarrays_from_file( |
| iree_string_view_t file_path, iree_hal_device_t* device, |
| iree_hal_allocator_t* device_allocator, iree_vm_list_t* list) { |
| char* file_path_cstring = NULL; |
| IREE_RETURN_IF_ERROR(iree_allocate_and_copy_cstring_from_view( |
| iree_allocator_system(), file_path, &file_path_cstring)); |
| FILE* file = fopen(file_path_cstring, "rb"); |
| iree_allocator_free(iree_allocator_system(), file_path_cstring); |
| if (!file) { |
| return iree_make_status(iree_status_code_from_errno(errno), |
| "failed to open file '%.*s'", (int)file_path.size, |
| file_path.data); |
| } |
| |
| uint64_t file_length = 0; |
| iree_status_t status = iree_file_query_length(file, &file_length); |
| |
| iree_hal_buffer_params_t buffer_params = {0}; |
| buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; |
| buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ; |
| buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL; |
| |
| while (iree_status_is_ok(status) && !iree_file_is_at(file, file_length)) { |
| iree_hal_buffer_view_t* buffer_view = NULL; |
| status = iree_numpy_npy_load_ndarray( |
| file, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, buffer_params, device, |
| device_allocator, &buffer_view); |
| if (iree_status_is_ok(status)) { |
| iree_vm_ref_t buffer_view_ref = |
| iree_hal_buffer_view_retain_ref(buffer_view); |
| status = iree_vm_list_push_ref_move(list, &buffer_view_ref); |
| } |
| iree_hal_buffer_view_release(buffer_view); |
| } |
| |
| fclose(file); |
| return status; |
| } |
| |
| struct iree_create_buffer_from_file_generator_user_data_t { |
| FILE* file; |
| }; |
| |
| static iree_status_t iree_create_buffer_from_file_generator_callback( |
| iree_hal_buffer_mapping_t* mapping, void* user_data) { |
| struct iree_create_buffer_from_file_generator_user_data_t* read_params = |
| user_data; |
| size_t bytes_read = fread(mapping->contents.data, 1, |
| mapping->contents.data_length, read_params->file); |
| if (bytes_read != mapping->contents.data_length) { |
| return iree_make_status(IREE_STATUS_OUT_OF_RANGE, |
| "file contents truncated; expected %" PRIhsz |
| " bytes " |
| "based on buffer view size", |
| mapping->contents.data_length); |
| } |
| return iree_ok_status(); |
| } |
| |
| // Creates a HAL buffer view with the given |metadata| and reads the contents |
| // from the file at |file_path|. |
| // |
| // The file contents are directly read in to memory with no processing. |
| static iree_status_t iree_create_buffer_view_from_file( |
| iree_string_view_t metadata, iree_string_view_t file_path, |
| iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, |
| iree_hal_buffer_view_t** out_buffer_view) { |
| *out_buffer_view = NULL; |
| |
| // Parse shape and element type used to allocate the buffer view. |
| iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; |
| iree_host_size_t shape_rank = 0; |
| iree_status_t shape_result = iree_hal_parse_shape_and_element_type( |
| metadata, 0, &shape_rank, NULL, &element_type); |
| if (!iree_status_is_ok(shape_result) && |
| !iree_status_is_out_of_range(shape_result)) { |
| return shape_result; |
| } else if (shape_rank > 128) { |
| return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, |
| "a shape rank of %" PRIhsz |
| " is just a little bit excessive, eh?", |
| shape_rank); |
| } |
| iree_status_ignore(shape_result); |
| iree_hal_dim_t* shape = |
| (iree_hal_dim_t*)iree_alloca(shape_rank * sizeof(iree_hal_dim_t)); |
| IREE_RETURN_IF_ERROR(iree_hal_parse_shape_and_element_type( |
| metadata, shape_rank, &shape_rank, shape, &element_type)); |
| |
| // TODO(benvanik): allow specifying the encoding. |
| iree_hal_encoding_type_t encoding_type = |
| IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR; |
| |
| // Open the file for reading. |
| char* file_path_cstring = NULL; |
| IREE_RETURN_IF_ERROR(iree_allocate_and_copy_cstring_from_view( |
| iree_allocator_system(), file_path, &file_path_cstring)); |
| FILE* file = fopen(file_path_cstring, "rb"); |
| iree_allocator_free(iree_allocator_system(), file_path_cstring); |
| if (!file) { |
| return iree_make_status(iree_status_code_from_errno(errno), |
| "failed to open file '%.*s'", (int)file_path.size, |
| file_path.data); |
| } |
| |
| iree_hal_buffer_params_t buffer_params = {0}; |
| buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL; |
| buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; |
| struct iree_create_buffer_from_file_generator_user_data_t read_params = { |
| file, |
| }; |
| iree_status_t status = iree_hal_buffer_view_generate_buffer( |
| device, device_allocator, shape_rank, shape, element_type, encoding_type, |
| buffer_params, iree_create_buffer_from_file_generator_callback, |
| &read_params, out_buffer_view); |
| |
| fclose(file); |
| |
| return status; |
| } |
| |
| iree_status_t iree_tooling_parse_to_variant_list( |
| iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, |
| const iree_string_view_t* input_strings, |
| iree_host_size_t input_strings_count, iree_allocator_t host_allocator, |
| iree_vm_list_t** out_list) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| *out_list = NULL; |
| iree_vm_list_t* list = NULL; |
| |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| input_strings_count, host_allocator, &list)); |
| |
| iree_status_t status = iree_tooling_parse_into_variant_list( |
| device, device_allocator, input_strings, input_strings_count, |
| host_allocator, list); |
| if (iree_status_is_ok(status)) { |
| *out_list = list; |
| } else { |
| iree_vm_list_release(list); |
| } |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| iree_status_t iree_tooling_parse_into_variant_list( |
| iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, |
| const iree_string_view_t* input_strings, |
| iree_host_size_t input_strings_count, iree_allocator_t host_allocator, |
| iree_vm_list_t* list) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| // Reset the list and prepare for pushing items. |
| iree_vm_list_clear(list); |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_vm_list_reserve(list, input_strings_count)); |
| |
| iree_status_t status = iree_ok_status(); |
| for (size_t i = 0; i < input_strings_count; ++i) { |
| if (!iree_status_is_ok(status)) break; |
| iree_string_view_t input_view = iree_string_view_trim(input_strings[i]); |
| if (iree_string_view_is_empty(input_view)) { |
| status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "no value specified for input"); |
| break; |
| } else if (iree_string_view_consume_prefix(&input_view, IREE_SV("@"))) { |
| status = iree_tooling_load_ndarrays_from_file(input_view, device, |
| device_allocator, list); |
| continue; |
| } else if (iree_string_view_equal(input_view, IREE_SV("(null)")) || |
| iree_string_view_equal(input_view, IREE_SV("(ignored)"))) { |
| iree_vm_ref_t null_ref = iree_vm_ref_null(); |
| status = iree_vm_list_push_ref_retain(list, &null_ref); |
| continue; |
| } |
| bool has_equal = |
| iree_string_view_find_char(input_view, '=', 0) != IREE_STRING_VIEW_NPOS; |
| bool has_x = |
| iree_string_view_find_char(input_view, 'x', 0) != IREE_STRING_VIEW_NPOS; |
| if (has_equal || has_x) { |
| // Buffer view (either just a shape or a shape=value) or buffer. |
| bool is_storage_reference = iree_string_view_consume_prefix( |
| &input_view, iree_make_cstring_view("&")); |
| iree_hal_buffer_view_t* buffer_view = NULL; |
| bool has_at = iree_string_view_find_char(input_view, '@', 0) != |
| IREE_STRING_VIEW_NPOS; |
| if (has_at) { |
| // Referencing an external file; split into the portion used to |
| // initialize the buffer view and the file contents. |
| iree_string_view_t metadata, file_path; |
| iree_string_view_split(input_view, '@', &metadata, &file_path); |
| iree_string_view_consume_suffix(&metadata, iree_make_cstring_view("=")); |
| status = iree_create_buffer_view_from_file( |
| metadata, file_path, device, device_allocator, &buffer_view); |
| if (!iree_status_is_ok(status)) break; |
| } else { |
| status = iree_hal_buffer_view_parse(input_view, device, |
| device_allocator, &buffer_view); |
| if (!iree_status_is_ok(status)) { |
| status = |
| iree_status_annotate_f(status, "parsing value '%.*s'", |
| (int)input_view.size, input_view.data); |
| break; |
| } |
| } |
| if (is_storage_reference) { |
| // Storage buffer reference; just take the storage for the buffer view - |
| // it'll still have whatever contents were specified (or 0) but we'll |
| // discard the metadata. |
| iree_vm_ref_t buffer_ref = iree_hal_buffer_retain_ref( |
| iree_hal_buffer_view_buffer(buffer_view)); |
| iree_hal_buffer_view_release(buffer_view); |
| status = iree_vm_list_push_ref_move(list, &buffer_ref); |
| if (!iree_status_is_ok(status)) break; |
| } else { |
| iree_vm_ref_t buffer_view_ref = |
| iree_hal_buffer_view_move_ref(buffer_view); |
| status = iree_vm_list_push_ref_move(list, &buffer_view_ref); |
| if (!iree_status_is_ok(status)) break; |
| } |
| } else { |
| // Scalar. |
| bool has_dot = iree_string_view_find_char(input_view, '.', 0) != |
| IREE_STRING_VIEW_NPOS; |
| iree_vm_value_t val; |
| if (has_dot) { |
| // Float. |
| val = iree_vm_value_make_f32(0.0f); |
| if (!iree_string_view_atof(input_view, &val.f32)) { |
| status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "parsing value '%.*s' as f32", |
| (int)input_view.size, input_view.data); |
| break; |
| } |
| } else { |
| // Integer. |
| val = iree_vm_value_make_i32(0); |
| if (!iree_string_view_atoi_int32(input_view, &val.i32)) { |
| status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "parsing value '%.*s' as i32", |
| (int)input_view.size, input_view.data); |
| break; |
| } |
| } |
| status = iree_vm_list_push_value(list, &val); |
| if (!iree_status_is_ok(status)) break; |
| } |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| iree_status_t iree_tooling_append_async_fence_inputs( |
| iree_vm_list_t* list, const iree_vm_function_t* function, |
| iree_hal_device_t* device, iree_hal_fence_t* wait_fence, |
| iree_hal_fence_t** out_signal_fence) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_string_view_t model = |
| iree_vm_function_lookup_attr_by_name(function, IREE_SV("iree.abi.model")); |
| if (!iree_string_view_equal(model, IREE_SV("coarse-fences"))) { |
| // Ignore unknown models - the user may have provided their own fences. |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |
| |
| // Create the signal fence as a 0->1 transition. The caller will wait on that. |
| iree_hal_semaphore_t* semaphore = NULL; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); |
| iree_hal_fence_t* signal_fence = NULL; |
| iree_status_t status = iree_hal_fence_create_at( |
| semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); |
| iree_hal_semaphore_release(semaphore); |
| |
| // Append (wait, signal) fences. |
| if (iree_status_is_ok(status)) { |
| iree_vm_ref_t wait_fence_ref = iree_hal_fence_retain_ref(wait_fence); |
| status = iree_vm_list_push_ref_move(list, &wait_fence_ref); |
| iree_vm_ref_release(&wait_fence_ref); |
| } |
| if (iree_status_is_ok(status)) { |
| iree_vm_ref_t signal_fence_ref = iree_hal_fence_retain_ref(signal_fence); |
| status = iree_vm_list_push_ref_move(list, &signal_fence_ref); |
| iree_vm_ref_release(&signal_fence_ref); |
| } |
| |
| if (iree_status_is_ok(status)) { |
| *out_signal_fence = signal_fence; |
| } else { |
| iree_hal_fence_release(signal_fence); |
| } |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static bool iree_tooling_requires_buffer_transfer( |
| iree_hal_buffer_t* source_buffer, iree_hal_buffer_params_t target_params) { |
| return !iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer), |
| target_params.type) || |
| !iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer), |
| target_params.usage); |
| } |
| |
| static iree_status_t iree_tooling_setup_buffer_transfer( |
| iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, |
| iree_hal_allocator_t* target_allocator, |
| iree_hal_buffer_params_t target_params, |
| iree_hal_buffer_t** out_target_buffer) { |
| IREE_ASSERT_ARGUMENT(command_buffer); |
| IREE_ASSERT_ARGUMENT(source_buffer); |
| IREE_ASSERT_ARGUMENT(target_allocator); |
| IREE_ASSERT_ARGUMENT(out_target_buffer); |
| *out_target_buffer = NULL; |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_hal_buffer_t* target_buffer = NULL; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_allocator_allocate_buffer( |
| target_allocator, target_params, |
| iree_hal_buffer_allocation_size(source_buffer), &target_buffer)); |
| |
| iree_status_t status = iree_hal_command_buffer_copy_buffer( |
| command_buffer, source_buffer, 0, target_buffer, 0, |
| iree_hal_buffer_byte_length(source_buffer)); |
| |
| if (iree_status_is_ok(status)) { |
| *out_target_buffer = target_buffer; |
| } else { |
| iree_hal_buffer_release(target_buffer); |
| } |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static iree_status_t iree_tooling_submit_transfer( |
| iree_hal_device_t* device, iree_hal_fence_t* wait_fence, |
| iree_hal_queue_affinity_t queue_affinity, |
| iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_status_t status = iree_ok_status(); |
| |
| bool needs_wait = signal_fence == NULL; |
| if (needs_wait) { |
| iree_hal_semaphore_t* semaphore = NULL; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); |
| status = iree_hal_fence_create_at( |
| semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); |
| iree_hal_semaphore_release(semaphore); |
| } else { |
| iree_hal_fence_retain(signal_fence); |
| } |
| |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_device_queue_execute( |
| device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), |
| iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer); |
| } |
| |
| if (iree_status_is_ok(status) && needs_wait) { |
| status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout()); |
| } |
| |
| iree_hal_fence_release(signal_fence); |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| iree_status_t iree_tooling_transfer_variant_list( |
| iree_hal_device_t* device, iree_vm_list_t* list, |
| iree_hal_allocator_t* target_allocator, |
| iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, |
| iree_hal_fence_t* signal_fence) { |
| IREE_ASSERT_ARGUMENT(device); |
| IREE_ASSERT_ARGUMENT(list); |
| IREE_ASSERT_ARGUMENT(target_allocator); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| // If all buffers are already host-accessible we can skip the transfer. |
| bool requires_transfer = false; |
| for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { |
| iree_vm_ref_t value = iree_vm_ref_null(); |
| IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); |
| if (iree_hal_buffer_isa(value)) { |
| iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); |
| if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { |
| requires_transfer = true; |
| break; |
| } |
| } else if (iree_hal_buffer_view_isa(value)) { |
| iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); |
| iree_hal_buffer_t* source_buffer = |
| iree_hal_buffer_view_buffer(source_view); |
| if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { |
| requires_transfer = true; |
| break; |
| } |
| } |
| } |
| if (!requires_transfer) { |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |
| |
| iree_hal_command_buffer_t* command_buffer = NULL; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_command_buffer_create( |
| device, |
| IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | |
| IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, |
| IREE_HAL_COMMAND_CATEGORY_TRANSFER, target_params.queue_affinity, |
| /*binding_capacity=*/0, &command_buffer)); |
| |
| iree_status_t status = iree_hal_command_buffer_begin(command_buffer); |
| if (iree_status_is_ok(status)) { |
| for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { |
| iree_vm_ref_t value = iree_vm_ref_null(); |
| IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); |
| if (iree_hal_buffer_isa(value)) { |
| iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); |
| if (!iree_tooling_requires_buffer_transfer(source_buffer, |
| target_params)) { |
| // Already ok. |
| continue; |
| } |
| iree_hal_buffer_t* target_buffer = NULL; |
| status = iree_tooling_setup_buffer_transfer( |
| command_buffer, source_buffer, target_allocator, target_params, |
| &target_buffer); |
| if (!iree_status_is_ok(status)) break; |
| status = iree_vm_list_set_buffer_retain(list, i, target_buffer); |
| iree_hal_buffer_release(target_buffer); |
| if (!iree_status_is_ok(status)) break; |
| } else if (iree_hal_buffer_view_isa(value)) { |
| iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); |
| iree_hal_buffer_t* source_buffer = |
| iree_hal_buffer_view_buffer(source_view); |
| if (!iree_tooling_requires_buffer_transfer(source_buffer, |
| target_params)) { |
| // Already ok. |
| continue; |
| } |
| iree_hal_buffer_t* target_buffer = NULL; |
| status = iree_tooling_setup_buffer_transfer( |
| command_buffer, source_buffer, target_allocator, target_params, |
| &target_buffer); |
| if (!iree_status_is_ok(status)) break; |
| iree_hal_buffer_view_t* target_view = NULL; |
| status = iree_hal_buffer_view_create_like( |
| target_buffer, source_view, |
| iree_hal_allocator_host_allocator(target_allocator), &target_view); |
| iree_hal_buffer_release(target_buffer); |
| if (!iree_status_is_ok(status)) break; |
| status = iree_vm_list_set_buffer_view_retain(list, i, target_view); |
| iree_hal_buffer_view_release(target_view); |
| if (!iree_status_is_ok(status)) break; |
| } |
| } |
| } |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_command_buffer_end(command_buffer); |
| } |
| |
| if (iree_status_is_ok(status)) { |
| status = iree_tooling_submit_transfer(device, wait_fence, |
| target_params.queue_affinity, |
| command_buffer, signal_fence); |
| } |
| |
| iree_hal_command_buffer_release(command_buffer); |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| #define IREE_PRINTVARIANT_CASE_I(SIZE, B, V) \ |
| case IREE_VM_VALUE_TYPE_I##SIZE: \ |
| return iree_string_builder_append_format( \ |
| B, "i" #SIZE "=%" PRIi##SIZE "\n", (V).i##SIZE); |
| |
| #define IREE_PRINTVARIANT_CASE_F(SIZE, B, V) \ |
| case IREE_VM_VALUE_TYPE_F##SIZE: \ |
| return iree_string_builder_append_format(B, "f" #SIZE "=%g\n", (V).f##SIZE); |
| |
| // Prints variant description including a trailing newline. |
| static iree_status_t iree_variant_format(iree_vm_variant_t variant, |
| iree_host_size_t max_element_count, |
| iree_string_builder_t* builder) { |
| if (iree_vm_variant_is_empty(variant)) { |
| return iree_string_builder_append_string(builder, IREE_SV("(null)\n")); |
| } else if (iree_vm_variant_is_value(variant)) { |
| switch (iree_vm_type_def_as_value(variant.type)) { |
| IREE_PRINTVARIANT_CASE_I(8, builder, variant) |
| IREE_PRINTVARIANT_CASE_I(16, builder, variant) |
| IREE_PRINTVARIANT_CASE_I(32, builder, variant) |
| IREE_PRINTVARIANT_CASE_I(64, builder, variant) |
| IREE_PRINTVARIANT_CASE_F(32, builder, variant) |
| IREE_PRINTVARIANT_CASE_F(64, builder, variant) |
| default: |
| return iree_string_builder_append_string(builder, IREE_SV("?\n")); |
| } |
| } else if (iree_vm_variant_is_ref(variant)) { |
| iree_string_view_t type_name = |
| iree_vm_ref_type_name(iree_vm_type_def_as_ref(variant.type)); |
| IREE_RETURN_IF_ERROR(iree_string_builder_append_string(builder, type_name)); |
| IREE_RETURN_IF_ERROR( |
| iree_string_builder_append_string(builder, IREE_SV("\n"))); |
| if (iree_vm_list_isa(variant.ref)) { |
| iree_vm_list_t* child_list = iree_vm_list_deref(variant.ref); |
| IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines( |
| IREE_SV("child_list"), child_list, max_element_count, builder)); |
| return iree_string_builder_append_string(builder, IREE_SV("\n")); |
| } else if (iree_hal_buffer_view_isa(variant.ref)) { |
| iree_hal_buffer_view_t* buffer_view = |
| iree_hal_buffer_view_deref(variant.ref); |
| IREE_RETURN_IF_ERROR(iree_hal_buffer_view_append_to_builder( |
| buffer_view, max_element_count, builder)); |
| return iree_string_builder_append_string(builder, IREE_SV("\n")); |
| } else { |
| // TODO(benvanik): a way for ref types to describe themselves. |
| return iree_string_builder_append_string(builder, |
| IREE_SV("(no printer)\n")); |
| } |
| } else { |
| return iree_string_builder_append_string(builder, IREE_SV("(null)\n")); |
| } |
| return iree_ok_status(); |
| } |
| |
| static iree_status_t iree_variant_fprint(iree_vm_variant_t variant, |
| iree_host_size_t max_element_count, |
| FILE* file) { |
| iree_string_builder_t builder; |
| iree_string_builder_initialize(iree_allocator_system(), &builder); |
| iree_status_t status = |
| iree_variant_format(variant, max_element_count, &builder); |
| if (iree_status_is_ok(status)) { |
| size_t written = fwrite(iree_string_builder_buffer(&builder), 1, |
| iree_string_builder_size(&builder), file); |
| if (written != iree_string_builder_size(&builder)) { |
| status = iree_status_from_code(IREE_STATUS_PERMISSION_DENIED); |
| } |
| fflush(file); |
| } |
| iree_string_builder_deinitialize(&builder); |
| return status; |
| } |
| |
| iree_status_t iree_tooling_append_variant_list_lines( |
| iree_string_view_t list_name, iree_vm_list_t* list, |
| iree_host_size_t max_element_count, iree_string_builder_t* builder) { |
| IREE_TRACE_ZONE_BEGIN(z0); |
| for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { |
| iree_vm_variant_t variant = iree_vm_variant_empty(); |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_vm_list_get_variant_assign(list, i, &variant), |
| "variant %" PRIhsz " not present", i); |
| iree_string_builder_append_format( |
| builder, "%.*s[%" PRIhsz "]: ", (int)list_name.size, list_name.data, i); |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_variant_format(variant, max_element_count, builder)); |
| } |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |
| |
| iree_status_t iree_tooling_variant_list_fprint( |
| iree_string_view_t list_name, iree_vm_list_t* list, |
| iree_host_size_t max_element_count, FILE* file) { |
| iree_string_builder_t builder; |
| iree_string_builder_initialize(iree_allocator_system(), &builder); |
| iree_status_t status = iree_tooling_append_variant_list_lines( |
| list_name, list, max_element_count, &builder); |
| if (iree_status_is_ok(status)) { |
| size_t written = fwrite(iree_string_builder_buffer(&builder), 1, |
| iree_string_builder_size(&builder), file); |
| if (written != iree_string_builder_size(&builder)) { |
| status = iree_status_from_code(IREE_STATUS_PERMISSION_DENIED); |
| } |
| fflush(file); |
| } |
| iree_string_builder_deinitialize(&builder); |
| return status; |
| } |
| |
| static iree_status_t iree_tooling_output_variant( |
| iree_vm_variant_t variant, iree_string_view_t output_str, |
| iree_host_size_t max_element_count, FILE* default_file) { |
| if (iree_string_view_is_empty(output_str)) { |
| // Send into the void. |
| return iree_ok_status(); |
| } else if (iree_string_view_equal(output_str, IREE_SV("-"))) { |
| // Route to the provided file. |
| return iree_variant_fprint(variant, max_element_count, default_file); |
| } |
| |
| bool has_at = iree_string_view_consume_prefix(&output_str, IREE_SV("@")); |
| bool has_plus = iree_string_view_consume_prefix(&output_str, IREE_SV("+")); |
| if (!has_at && !has_plus) { |
| // Other types of outputs are not yet supported. We could allow for shapes |
| // and either verify metadata or output binary files ala |
| // `--input=4xf32=@foo.bin`. |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, |
| "unsupported output mode specification '%.*s'", |
| (int)output_str.size, output_str.data); |
| } |
| |
| // For now we just send buffer views to npy files as primitive values (like |
| // just a normal int) can't be round-tripped. We could wrap the primitives in |
| // a single-element buffer view if needed. |
| if (!iree_vm_variant_is_ref(variant) || |
| !iree_hal_buffer_view_isa(variant.ref)) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "only buffer views can be written to npy files"); |
| } |
| iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(variant.ref); |
| |
| // Open file for either overwriting or appending (npy files can contain |
| // multiple arrays). |
| iree_string_view_t file_path = output_str; |
| char* file_path_cstring = NULL; |
| IREE_RETURN_IF_ERROR(iree_allocate_and_copy_cstring_from_view( |
| iree_allocator_system(), file_path, &file_path_cstring)); |
| const char* mode = has_plus ? "ab" : "wb"; |
| FILE* file = fopen(file_path_cstring, mode); |
| iree_allocator_free(iree_allocator_system(), file_path_cstring); |
| if (!file) { |
| return iree_make_status(iree_status_code_from_errno(errno), |
| "failed to open file '%.*s'", (int)file_path.size, |
| file_path.data); |
| } |
| |
| // Append buffer view contents to the file stream. |
| iree_numpy_npy_save_options_t options = IREE_NUMPY_NPY_SAVE_OPTION_DEFAULT; |
| iree_status_t status = iree_numpy_npy_save_ndarray(file, options, buffer_view, |
| iree_allocator_system()); |
| |
| fclose(file); |
| return status; |
| } |
| |
| iree_status_t iree_tooling_output_variant_list( |
| iree_vm_list_t* list, const iree_string_view_t* output_strings, |
| iree_host_size_t output_strings_count, iree_host_size_t max_element_count, |
| FILE* file) { |
| IREE_ASSERT_ARGUMENT(list); |
| IREE_ASSERT_ARGUMENT(!output_strings_count || output_strings); |
| |
| // We only care if there are not enough outputs to satisfy the user |
| // request. We could force users to specify all outputs to make this a bit |
| // harder to misuse but saving off outputs is a power-user feature. |
| if (iree_vm_list_size(list) != output_strings_count) { |
| return iree_make_status( |
| IREE_STATUS_OUT_OF_RANGE, |
| "%" PRIhsz " outputs specified but the provided list only has %" PRIhsz |
| " elements", |
| output_strings_count, iree_vm_list_size(list)); |
| } |
| |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| for (iree_host_size_t i = 0; i < output_strings_count; ++i) { |
| iree_vm_variant_t variant = iree_vm_variant_empty(); |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_vm_list_get_variant_assign(list, i, &variant)); |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_tooling_output_variant(variant, output_strings[i], |
| max_element_count, file)); |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |