Ben Vanik | fce839f | 2023-11-28 17:12:58 -0800 | [diff] [blame] | 1 | // Copyright 2023 The IREE Authors |
| 2 | // |
| 3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
// Converts one or more parameter files into a single IREE Parameter Archive.
// Supports basic editing while converting: excluding, renaming, splatting, and
// stripping (zero-filling) parameters.
| 9 | |
| 10 | #include <ctype.h> |
| 11 | #include <stdio.h> |
| 12 | |
| 13 | #include "iree/base/api.h" |
| 14 | #include "iree/base/internal/file_io.h" |
| 15 | #include "iree/base/internal/flags.h" |
| 16 | #include "iree/hal/api.h" |
| 17 | #include "iree/io/formats/irpa/irpa_builder.h" |
| 18 | #include "iree/io/parameter_index.h" |
| 19 | #include "iree/io/scope_map.h" |
| 20 | #include "iree/tooling/parameter_util.h" |
| 21 | |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | // Parameter index logic |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | |
// Each value is compared by exact match against a source parameter's original
// (pre-rename) key; matching parameters are dropped from the output.
IREE_FLAG_LIST(string, exclude,
               "Excludes a named parameter from the resulting file.");
// Each value is split at its first '=' into `old` and `new`; a parameter whose
// key equals `old` is written to the output under `new`.
IREE_FLAG_LIST(string, rename,
               "Renames a parameter when adding to the resulting file in the "
               "form of `--rename=old=new`.");
// Turns every parameter into a single-zero-byte splat unless a more specific
// `--splat=` value matches it (explicit splats take precedence).
IREE_FLAG(bool, strip, false,
          "Strips all parameters by replacing them with zeros.");
// Each value is `name` (zero splat) or `name=type=value`, where the
// `type=value` part is parsed by iree_tooling_parse_splat.
IREE_FLAG_LIST(
    string, splat,
    "Turns a parameter into a splat of 0 (`--splat=name`) or a specific\n"
    "sequence of typed values (`--splat=name=i8=123`, `--splat=name=f32=4.5`,\n"
    "`--splat=name=x32=CAFEF00D`).");
| 38 | |
| 39 | static bool iree_tooling_is_parameter_excluded(iree_string_view_t name) { |
| 40 | for (iree_host_size_t i = 0; i < FLAG_exclude_list().count; ++i) { |
| 41 | if (iree_string_view_equal(FLAG_exclude_list().values[i], name)) { |
| 42 | return true; |
| 43 | } |
| 44 | } |
| 45 | return false; |
| 46 | } |
| 47 | |
| 48 | static iree_string_view_t iree_tooling_get_renamed_parameter_name( |
| 49 | iree_string_view_t name) { |
| 50 | for (iree_host_size_t i = 0; i < FLAG_rename_list().count; ++i) { |
| 51 | iree_string_view_t old_name, new_name; |
| 52 | iree_string_view_split(FLAG_rename_list().values[i], '=', &old_name, |
| 53 | &new_name); |
| 54 | if (iree_string_view_equal(old_name, name)) { |
| 55 | return new_name; |
| 56 | } |
| 57 | } |
| 58 | return name; |
| 59 | } |
| 60 | |
| 61 | // Expects `type=value` consistent with the HAL. |
| 62 | static iree_status_t iree_tooling_parse_splat(iree_string_view_t splat_value, |
| 63 | uint8_t* out_pattern_length, |
| 64 | uint8_t* out_pattern) { |
| 65 | if (iree_string_view_is_empty(splat_value)) { |
| 66 | *out_pattern_length = 1; |
| 67 | out_pattern[0] = 0; |
| 68 | return iree_ok_status(); |
| 69 | } |
| 70 | |
| 71 | iree_string_view_t type_str, value_str; |
| 72 | iree_string_view_split(splat_value, '=', &type_str, &value_str); |
| 73 | |
| 74 | iree_hal_element_type_t type = IREE_HAL_ELEMENT_TYPE_NONE; |
| 75 | IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, &type)); |
| 76 | |
| 77 | iree_device_size_t byte_count = iree_hal_element_dense_byte_count(type); |
| 78 | if (byte_count > 16) { |
| 79 | return iree_make_status( |
| 80 | IREE_STATUS_OUT_OF_RANGE, |
| 81 | "element type size for %.*s out of range of splat patterns", |
| 82 | (int)type_str.size, type_str.data); |
| 83 | } |
| 84 | *out_pattern_length = (uint8_t)byte_count; |
| 85 | |
| 86 | return iree_hal_parse_element(value_str, type, |
| 87 | iree_make_byte_span(out_pattern, 16)); |
| 88 | } |
| 89 | |
| 90 | static iree_status_t iree_tooling_replace_splatted_parameter( |
| 91 | iree_io_parameter_index_entry_t* entry) { |
| 92 | // Always favor specific splat values. |
| 93 | for (iree_host_size_t i = 0; i < FLAG_splat_list().count; ++i) { |
| 94 | iree_string_view_t name, splat_value; |
| 95 | iree_string_view_split(FLAG_splat_list().values[i], '=', &name, |
| 96 | &splat_value); |
| 97 | if (iree_string_view_equal(name, entry->key)) { |
| 98 | entry->type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT; |
| 99 | memset(&entry->storage, 0, sizeof(entry->storage)); |
| 100 | return iree_tooling_parse_splat(splat_value, |
| 101 | &entry->storage.splat.pattern_length, |
| 102 | entry->storage.splat.pattern); |
| 103 | } |
| 104 | } |
| 105 | |
| 106 | // If not specifically splatted then see if we are stripping and use that. |
| 107 | if (FLAG_strip) { |
| 108 | entry->type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT; |
| 109 | memset(&entry->storage, 0, sizeof(entry->storage)); |
| 110 | entry->storage.splat.pattern_length = 1; |
| 111 | entry->storage.splat.pattern[0] = 0; |
| 112 | return iree_ok_status(); |
| 113 | } |
| 114 | |
| 115 | return iree_ok_status(); |
| 116 | } |
| 117 | |
| 118 | static iree_status_t iree_tooling_convert_parameter_index( |
| 119 | iree_io_parameter_index_t* source_index, |
| 120 | iree_io_parameter_index_t* target_index) { |
| 121 | for (iree_host_size_t i = 0; i < iree_io_parameter_index_count(source_index); |
| 122 | ++i) { |
| 123 | // Get the existing entry we'll use as a template. |
| 124 | const iree_io_parameter_index_entry_t* source_entry = NULL; |
| 125 | IREE_RETURN_IF_ERROR( |
| 126 | iree_io_parameter_index_get(source_index, i, &source_entry)); |
| 127 | iree_io_parameter_index_entry_t target_entry = *source_entry; |
| 128 | |
| 129 | // If the parameter is in the exclude list then we just skip it. |
| 130 | if (iree_tooling_is_parameter_excluded(source_entry->key)) continue; |
| 131 | |
| 132 | // If the parameter is in the rename list we'll add it with the new name. |
| 133 | target_entry.key = |
| 134 | iree_tooling_get_renamed_parameter_name(source_entry->key); |
| 135 | |
| 136 | // If the parameter is turned into a splat we change its type. Note that it |
| 137 | // may have already been a splat but the user may want to change the value. |
| 138 | IREE_RETURN_IF_ERROR( |
| 139 | iree_tooling_replace_splatted_parameter(&target_entry)); |
| 140 | |
| 141 | // Add the entry (potentially modified) to the new index. |
| 142 | IREE_RETURN_IF_ERROR( |
| 143 | iree_io_parameter_index_add(target_index, &target_entry)); |
| 144 | } |
| 145 | return iree_ok_status(); |
| 146 | } |
| 147 | |
| 148 | static iree_status_t iree_tooling_convert_parameters( |
| 149 | iree_io_scope_map_t* scope_map, iree_io_parameter_index_t* target_index, |
| 150 | iree_allocator_t host_allocator) { |
| 151 | for (iree_host_size_t i = 0; i < scope_map->count; ++i) { |
| 152 | IREE_RETURN_IF_ERROR(iree_tooling_convert_parameter_index( |
| 153 | scope_map->entries[i]->index, target_index)); |
| 154 | } |
| 155 | return iree_ok_status(); |
| 156 | } |
| 157 | |
| 158 | //===----------------------------------------------------------------------===// |
| 159 | // main |
| 160 | //===----------------------------------------------------------------------===// |
| 161 | |
// When set, skips printing the resulting archive index to stdout after the
// archive has been written.
IREE_FLAG(bool, quiet, false,
          "Silences additional stdout output when not needed.");

// Destination path for the generated .irpa archive (empty by default).
IREE_FLAG(string, output, "", "Output .irpa file path.");
| 166 | |
| 167 | static void iree_io_file_handle_release_mapping( |
| 168 | void* user_data, iree_io_file_handle_primitive_t handle_primitive) { |
| 169 | iree_file_contents_free((iree_file_contents_t*)user_data); |
| 170 | } |
| 171 | |
| 172 | typedef struct { |
| 173 | iree_allocator_t host_allocator; |
| 174 | const char* path; |
| 175 | } iree_tooling_open_params_t; |
| 176 | static iree_status_t iree_tooling_open_output_parameter_file( |
| 177 | void* user_data, iree_io_physical_offset_t archive_offset, |
| 178 | iree_io_physical_size_t archive_length, |
| 179 | iree_io_file_handle_t** out_file_handle) { |
| 180 | iree_tooling_open_params_t* params = (iree_tooling_open_params_t*)user_data; |
| 181 | iree_file_contents_t* file_contents = NULL; |
| 182 | IREE_RETURN_IF_ERROR( |
| 183 | iree_file_create_mapped(params->path, archive_offset + archive_length, |
| 184 | archive_offset, (iree_host_size_t)archive_length, |
| 185 | params->host_allocator, &file_contents)); |
| 186 | iree_io_file_handle_release_callback_t release_callback = { |
| 187 | .fn = iree_io_file_handle_release_mapping, |
| 188 | .user_data = file_contents, |
| 189 | }; |
| 190 | iree_status_t status = iree_io_file_handle_wrap_host_allocation( |
| 191 | IREE_IO_FILE_ACCESS_WRITE, file_contents->buffer, release_callback, |
| 192 | params->host_allocator, out_file_handle); |
| 193 | if (!iree_status_is_ok(status)) { |
| 194 | iree_file_contents_free(file_contents); |
| 195 | } |
| 196 | return status; |
| 197 | } |
| 198 | |
| 199 | int main(int argc, char** argv) { |
Ben Vanik | 23f2828 | 2024-02-23 11:14:25 -0800 | [diff] [blame] | 200 | IREE_TRACE_APP_ENTER(); |
Ben Vanik | fce839f | 2023-11-28 17:12:58 -0800 | [diff] [blame] | 201 | IREE_TRACE_ZONE_BEGIN(z0); |
| 202 | |
| 203 | iree_allocator_t host_allocator = iree_allocator_system(); |
| 204 | int exit_code = EXIT_SUCCESS; |
| 205 | |
| 206 | // Parse command line flags. |
| 207 | iree_flags_set_usage( |
| 208 | "iree-convert-parameters", |
| 209 | "Converts supported parameter file formats into IREE Parameter Archives\n" |
| 210 | "(.irpa) files. Provide one or more input parameter files in the same\n" |
| 211 | "form as expected by the iree-run-module tool (`--parameters=foo.gguf`)\n" |
| 212 | "and an output file with `--output=file.irpa`.\n" |
| 213 | "\n" |
| 214 | "Example converting from safetensors to IRPA:\n" |
| 215 | " iree-convert-parameters \\\n" |
| 216 | " --parameters=input.safetensors \\\n" |
| 217 | " --output=output.irpa\n" |
| 218 | "\n" |
| 219 | "Example mutating parameters:\n" |
| 220 | " iree-convert-parameters \\\n" |
| 221 | " --parameters=a.gguf \\\n" |
| 222 | " --parameters=b.safetensors \\\n" |
| 223 | " --exclude=unneeded_param \\\n" |
| 224 | " --rename=old_name=new_name \\\n" |
| 225 | " --splat=some_name=f32=4.2 \\\n" |
| 226 | " --output=ab.irpa\n" |
| 227 | "\n" |
| 228 | "Example stripping all parameters and replacing them with zeros except\n" |
| 229 | "for one that needs special handling:\n" |
| 230 | " iree-convert-parameters \\\n" |
| 231 | " --parameters=input.irpa \\\n" |
| 232 | " --strip \\\n" |
| 233 | " --splat=special_param=f32=1.0 \\\n" |
| 234 | " --output=output.irpa\n"); |
| 235 | iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); |
| 236 | |
| 237 | // Load parameter indices as specified by command line flags. |
| 238 | iree_io_scope_map_t scope_map = {0}; |
| 239 | iree_io_scope_map_initialize(host_allocator, &scope_map); |
| 240 | iree_status_t status = |
| 241 | iree_tooling_build_parameter_indices_from_flags(&scope_map); |
| 242 | |
| 243 | // Build the new combined/modified index in memory based on the inputs. |
| 244 | iree_io_parameter_index_t* new_index = NULL; |
| 245 | if (iree_status_is_ok(status)) { |
| 246 | status = iree_io_parameter_index_create(host_allocator, &new_index); |
| 247 | } |
| 248 | if (iree_status_is_ok(status)) { |
| 249 | status = |
| 250 | iree_tooling_convert_parameters(&scope_map, new_index, host_allocator); |
| 251 | } |
| 252 | iree_io_scope_map_deinitialize(&scope_map); |
| 253 | |
| 254 | iree_io_parameter_index_t* built_index = NULL; |
| 255 | if (iree_status_is_ok(status)) { |
| 256 | status = iree_io_parameter_index_create(host_allocator, &built_index); |
| 257 | } |
| 258 | |
| 259 | // Write out the new archive. |
| 260 | if (iree_status_is_ok(status)) { |
| 261 | iree_tooling_open_params_t open_params = { |
| 262 | .host_allocator = host_allocator, |
| 263 | .path = FLAG_output, |
| 264 | }; |
| 265 | iree_io_parameter_archive_file_open_callback_t open_callback = { |
| 266 | .fn = iree_tooling_open_output_parameter_file, |
| 267 | .user_data = &open_params, |
| 268 | }; |
| 269 | status = iree_io_build_parameter_archive( |
| 270 | new_index, built_index, open_callback, |
| 271 | /*target_file_offset=*/0, host_allocator); |
| 272 | } |
| 273 | |
| 274 | // Dump the new index ala iree-dump-parameters to show the final file. |
| 275 | if (iree_status_is_ok(status) && !FLAG_quiet) { |
| 276 | status = iree_io_parameter_index_fprint(stdout, iree_string_view_empty(), |
| 277 | built_index); |
| 278 | } |
| 279 | |
| 280 | iree_io_parameter_index_release(built_index); |
| 281 | iree_io_parameter_index_release(new_index); |
| 282 | |
| 283 | fflush(stdout); |
| 284 | if (!iree_status_is_ok(status)) { |
| 285 | iree_status_fprint(stderr, status); |
| 286 | iree_status_free(status); |
| 287 | exit_code = EXIT_FAILURE; |
| 288 | } |
| 289 | fflush(stderr); |
| 290 | |
| 291 | IREE_TRACE_ZONE_END(z0); |
Ben Vanik | 23f2828 | 2024-02-23 11:14:25 -0800 | [diff] [blame] | 292 | IREE_TRACE_APP_EXIT(exit_code); |
Ben Vanik | fce839f | 2023-11-28 17:12:58 -0800 | [diff] [blame] | 293 | return exit_code; |
| 294 | } |