blob: 460a61c32031e8202826a6482697c69031f424d0 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Converts one or more parameter files into a single IREE Parameter Archive.
// Allows for stripping and renaming parameters as basic editing features.
#include <ctype.h>
#include <stdio.h>
#include "iree/base/api.h"
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/flags.h"
#include "iree/hal/api.h"
#include "iree/io/formats/irpa/irpa_builder.h"
#include "iree/io/parameter_index.h"
#include "iree/io/scope_map.h"
#include "iree/tooling/parameter_util.h"
//===----------------------------------------------------------------------===//
// Parameter index logic
//===----------------------------------------------------------------------===//
IREE_FLAG_LIST(string, exclude,
"Excludes a named parameter from the resulting file.");
IREE_FLAG_LIST(string, rename,
"Renames a parameter when adding to the resulting file in the "
"form of `--rename=old=new`.");
IREE_FLAG(bool, strip, false,
"Strips all parameters by replacing them with zeros.");
IREE_FLAG_LIST(
string, splat,
"Turns a parameter into a splat of 0 (`--splat=name`) or a specific\n"
"sequence of typed values (`--splat=name=i8=123`, `--splat=name=f32=4.5`,\n"
"`--splat=name=x32=CAFEF00D`).");
static bool iree_tooling_is_parameter_excluded(iree_string_view_t name) {
for (iree_host_size_t i = 0; i < FLAG_exclude_list().count; ++i) {
if (iree_string_view_equal(FLAG_exclude_list().values[i], name)) {
return true;
}
}
return false;
}
static iree_string_view_t iree_tooling_get_renamed_parameter_name(
iree_string_view_t name) {
for (iree_host_size_t i = 0; i < FLAG_rename_list().count; ++i) {
iree_string_view_t old_name, new_name;
iree_string_view_split(FLAG_rename_list().values[i], '=', &old_name,
&new_name);
if (iree_string_view_equal(old_name, name)) {
return new_name;
}
}
return name;
}
// Expects `type=value` consistent with the HAL.
static iree_status_t iree_tooling_parse_splat(iree_string_view_t splat_value,
uint8_t* out_pattern_length,
uint8_t* out_pattern) {
if (iree_string_view_is_empty(splat_value)) {
*out_pattern_length = 1;
out_pattern[0] = 0;
return iree_ok_status();
}
iree_string_view_t type_str, value_str;
iree_string_view_split(splat_value, '=', &type_str, &value_str);
iree_hal_element_type_t type = IREE_HAL_ELEMENT_TYPE_NONE;
IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, &type));
iree_device_size_t byte_count = iree_hal_element_dense_byte_count(type);
if (byte_count > 16) {
return iree_make_status(
IREE_STATUS_OUT_OF_RANGE,
"element type size for %.*s out of range of splat patterns",
(int)type_str.size, type_str.data);
}
*out_pattern_length = (uint8_t)byte_count;
return iree_hal_parse_element(value_str, type,
iree_make_byte_span(out_pattern, 16));
}
static iree_status_t iree_tooling_replace_splatted_parameter(
iree_io_parameter_index_entry_t* entry) {
// Always favor specific splat values.
for (iree_host_size_t i = 0; i < FLAG_splat_list().count; ++i) {
iree_string_view_t name, splat_value;
iree_string_view_split(FLAG_splat_list().values[i], '=', &name,
&splat_value);
if (iree_string_view_equal(name, entry->key)) {
entry->type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT;
memset(&entry->storage, 0, sizeof(entry->storage));
return iree_tooling_parse_splat(splat_value,
&entry->storage.splat.pattern_length,
entry->storage.splat.pattern);
}
}
// If not specifically splatted then see if we are stripping and use that.
if (FLAG_strip) {
entry->type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT;
memset(&entry->storage, 0, sizeof(entry->storage));
entry->storage.splat.pattern_length = 1;
entry->storage.splat.pattern[0] = 0;
return iree_ok_status();
}
return iree_ok_status();
}
static iree_status_t iree_tooling_convert_parameter_index(
iree_io_parameter_index_t* source_index,
iree_io_parameter_index_t* target_index) {
for (iree_host_size_t i = 0; i < iree_io_parameter_index_count(source_index);
++i) {
// Get the existing entry we'll use as a template.
const iree_io_parameter_index_entry_t* source_entry = NULL;
IREE_RETURN_IF_ERROR(
iree_io_parameter_index_get(source_index, i, &source_entry));
iree_io_parameter_index_entry_t target_entry = *source_entry;
// If the parameter is in the exclude list then we just skip it.
if (iree_tooling_is_parameter_excluded(source_entry->key)) continue;
// If the parameter is in the rename list we'll add it with the new name.
target_entry.key =
iree_tooling_get_renamed_parameter_name(source_entry->key);
// If the parameter is turned into a splat we change its type. Note that it
// may have already been a splat but the user may want to change the value.
IREE_RETURN_IF_ERROR(
iree_tooling_replace_splatted_parameter(&target_entry));
// Add the entry (potentially modified) to the new index.
IREE_RETURN_IF_ERROR(
iree_io_parameter_index_add(target_index, &target_entry));
}
return iree_ok_status();
}
static iree_status_t iree_tooling_convert_parameters(
iree_io_scope_map_t* scope_map, iree_io_parameter_index_t* target_index,
iree_allocator_t host_allocator) {
for (iree_host_size_t i = 0; i < scope_map->count; ++i) {
IREE_RETURN_IF_ERROR(iree_tooling_convert_parameter_index(
scope_map->entries[i]->index, target_index));
}
return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// main
//===----------------------------------------------------------------------===//
IREE_FLAG(bool, quiet, false,
"Silences additional stdout output when not needed.");
IREE_FLAG(string, output, "", "Output .irpa file path.");
static void iree_io_file_handle_release_mapping(
void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
iree_file_contents_free((iree_file_contents_t*)user_data);
}
typedef struct {
iree_allocator_t host_allocator;
const char* path;
} iree_tooling_open_params_t;
static iree_status_t iree_tooling_open_output_parameter_file(
void* user_data, iree_io_physical_offset_t archive_offset,
iree_io_physical_size_t archive_length,
iree_io_file_handle_t** out_file_handle) {
iree_tooling_open_params_t* params = (iree_tooling_open_params_t*)user_data;
iree_file_contents_t* file_contents = NULL;
IREE_RETURN_IF_ERROR(
iree_file_create_mapped(params->path, archive_offset + archive_length,
archive_offset, (iree_host_size_t)archive_length,
params->host_allocator, &file_contents));
iree_io_file_handle_release_callback_t release_callback = {
.fn = iree_io_file_handle_release_mapping,
.user_data = file_contents,
};
iree_status_t status = iree_io_file_handle_wrap_host_allocation(
IREE_IO_FILE_ACCESS_WRITE, file_contents->buffer, release_callback,
params->host_allocator, out_file_handle);
if (!iree_status_is_ok(status)) {
iree_file_contents_free(file_contents);
}
return status;
}
int main(int argc, char** argv) {
IREE_TRACE_APP_ENTER();
IREE_TRACE_ZONE_BEGIN(z0);
iree_allocator_t host_allocator = iree_allocator_system();
int exit_code = EXIT_SUCCESS;
// Parse command line flags.
iree_flags_set_usage(
"iree-convert-parameters",
"Converts supported parameter file formats into IREE Parameter Archives\n"
"(.irpa) files. Provide one or more input parameter files in the same\n"
"form as expected by the iree-run-module tool (`--parameters=foo.gguf`)\n"
"and an output file with `--output=file.irpa`.\n"
"\n"
"Example converting from safetensors to IRPA:\n"
" iree-convert-parameters \\\n"
" --parameters=input.safetensors \\\n"
" --output=output.irpa\n"
"\n"
"Example mutating parameters:\n"
" iree-convert-parameters \\\n"
" --parameters=a.gguf \\\n"
" --parameters=b.safetensors \\\n"
" --exclude=unneeded_param \\\n"
" --rename=old_name=new_name \\\n"
" --splat=some_name=f32=4.2 \\\n"
" --output=ab.irpa\n"
"\n"
"Example stripping all parameters and replacing them with zeros except\n"
"for one that needs special handling:\n"
" iree-convert-parameters \\\n"
" --parameters=input.irpa \\\n"
" --strip \\\n"
" --splat=special_param=f32=1.0 \\\n"
" --output=output.irpa\n");
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
// Load parameter indices as specified by command line flags.
iree_io_scope_map_t scope_map = {0};
iree_io_scope_map_initialize(host_allocator, &scope_map);
iree_status_t status =
iree_tooling_build_parameter_indices_from_flags(&scope_map);
// Build the new combined/modified index in memory based on the inputs.
iree_io_parameter_index_t* new_index = NULL;
if (iree_status_is_ok(status)) {
status = iree_io_parameter_index_create(host_allocator, &new_index);
}
if (iree_status_is_ok(status)) {
status =
iree_tooling_convert_parameters(&scope_map, new_index, host_allocator);
}
iree_io_scope_map_deinitialize(&scope_map);
iree_io_parameter_index_t* built_index = NULL;
if (iree_status_is_ok(status)) {
status = iree_io_parameter_index_create(host_allocator, &built_index);
}
// Write out the new archive.
if (iree_status_is_ok(status)) {
iree_tooling_open_params_t open_params = {
.host_allocator = host_allocator,
.path = FLAG_output,
};
iree_io_parameter_archive_file_open_callback_t open_callback = {
.fn = iree_tooling_open_output_parameter_file,
.user_data = &open_params,
};
status = iree_io_build_parameter_archive(
new_index, built_index, open_callback,
/*target_file_offset=*/0, host_allocator);
}
// Dump the new index ala iree-dump-parameters to show the final file.
if (iree_status_is_ok(status) && !FLAG_quiet) {
status = iree_io_parameter_index_fprint(stdout, iree_string_view_empty(),
built_index);
}
iree_io_parameter_index_release(built_index);
iree_io_parameter_index_release(new_index);
fflush(stdout);
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
iree_status_free(status);
exit_code = EXIT_FAILURE;
}
fflush(stderr);
IREE_TRACE_ZONE_END(z0);
IREE_TRACE_APP_EXIT(exit_code);
return exit_code;
}