blob: 460a61c32031e8202826a6482697c69031f424d0 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Converts one or more parameter files into a single IREE Parameter Archive.
// Allows for stripping and renaming parameters as basic editing features.
9
10#include <ctype.h>
11#include <stdio.h>
12
13#include "iree/base/api.h"
14#include "iree/base/internal/file_io.h"
15#include "iree/base/internal/flags.h"
16#include "iree/hal/api.h"
17#include "iree/io/formats/irpa/irpa_builder.h"
18#include "iree/io/parameter_index.h"
19#include "iree/io/scope_map.h"
20#include "iree/tooling/parameter_util.h"
21
//===----------------------------------------------------------------------===//
// Parameter index logic
//===----------------------------------------------------------------------===//
25
26IREE_FLAG_LIST(string, exclude,
27 "Excludes a named parameter from the resulting file.");
28IREE_FLAG_LIST(string, rename,
29 "Renames a parameter when adding to the resulting file in the "
30 "form of `--rename=old=new`.");
31IREE_FLAG(bool, strip, false,
32 "Strips all parameters by replacing them with zeros.");
33IREE_FLAG_LIST(
34 string, splat,
35 "Turns a parameter into a splat of 0 (`--splat=name`) or a specific\n"
36 "sequence of typed values (`--splat=name=i8=123`, `--splat=name=f32=4.5`,\n"
37 "`--splat=name=x32=CAFEF00D`).");
38
39static bool iree_tooling_is_parameter_excluded(iree_string_view_t name) {
40 for (iree_host_size_t i = 0; i < FLAG_exclude_list().count; ++i) {
41 if (iree_string_view_equal(FLAG_exclude_list().values[i], name)) {
42 return true;
43 }
44 }
45 return false;
46}
47
48static iree_string_view_t iree_tooling_get_renamed_parameter_name(
49 iree_string_view_t name) {
50 for (iree_host_size_t i = 0; i < FLAG_rename_list().count; ++i) {
51 iree_string_view_t old_name, new_name;
52 iree_string_view_split(FLAG_rename_list().values[i], '=', &old_name,
53 &new_name);
54 if (iree_string_view_equal(old_name, name)) {
55 return new_name;
56 }
57 }
58 return name;
59}
60
61// Expects `type=value` consistent with the HAL.
62static iree_status_t iree_tooling_parse_splat(iree_string_view_t splat_value,
63 uint8_t* out_pattern_length,
64 uint8_t* out_pattern) {
65 if (iree_string_view_is_empty(splat_value)) {
66 *out_pattern_length = 1;
67 out_pattern[0] = 0;
68 return iree_ok_status();
69 }
70
71 iree_string_view_t type_str, value_str;
72 iree_string_view_split(splat_value, '=', &type_str, &value_str);
73
74 iree_hal_element_type_t type = IREE_HAL_ELEMENT_TYPE_NONE;
75 IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, &type));
76
77 iree_device_size_t byte_count = iree_hal_element_dense_byte_count(type);
78 if (byte_count > 16) {
79 return iree_make_status(
80 IREE_STATUS_OUT_OF_RANGE,
81 "element type size for %.*s out of range of splat patterns",
82 (int)type_str.size, type_str.data);
83 }
84 *out_pattern_length = (uint8_t)byte_count;
85
86 return iree_hal_parse_element(value_str, type,
87 iree_make_byte_span(out_pattern, 16));
88}
89
90static iree_status_t iree_tooling_replace_splatted_parameter(
91 iree_io_parameter_index_entry_t* entry) {
92 // Always favor specific splat values.
93 for (iree_host_size_t i = 0; i < FLAG_splat_list().count; ++i) {
94 iree_string_view_t name, splat_value;
95 iree_string_view_split(FLAG_splat_list().values[i], '=', &name,
96 &splat_value);
97 if (iree_string_view_equal(name, entry->key)) {
98 entry->type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT;
99 memset(&entry->storage, 0, sizeof(entry->storage));
100 return iree_tooling_parse_splat(splat_value,
101 &entry->storage.splat.pattern_length,
102 entry->storage.splat.pattern);
103 }
104 }
105
106 // If not specifically splatted then see if we are stripping and use that.
107 if (FLAG_strip) {
108 entry->type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT;
109 memset(&entry->storage, 0, sizeof(entry->storage));
110 entry->storage.splat.pattern_length = 1;
111 entry->storage.splat.pattern[0] = 0;
112 return iree_ok_status();
113 }
114
115 return iree_ok_status();
116}
117
118static iree_status_t iree_tooling_convert_parameter_index(
119 iree_io_parameter_index_t* source_index,
120 iree_io_parameter_index_t* target_index) {
121 for (iree_host_size_t i = 0; i < iree_io_parameter_index_count(source_index);
122 ++i) {
123 // Get the existing entry we'll use as a template.
124 const iree_io_parameter_index_entry_t* source_entry = NULL;
125 IREE_RETURN_IF_ERROR(
126 iree_io_parameter_index_get(source_index, i, &source_entry));
127 iree_io_parameter_index_entry_t target_entry = *source_entry;
128
129 // If the parameter is in the exclude list then we just skip it.
130 if (iree_tooling_is_parameter_excluded(source_entry->key)) continue;
131
132 // If the parameter is in the rename list we'll add it with the new name.
133 target_entry.key =
134 iree_tooling_get_renamed_parameter_name(source_entry->key);
135
136 // If the parameter is turned into a splat we change its type. Note that it
137 // may have already been a splat but the user may want to change the value.
138 IREE_RETURN_IF_ERROR(
139 iree_tooling_replace_splatted_parameter(&target_entry));
140
141 // Add the entry (potentially modified) to the new index.
142 IREE_RETURN_IF_ERROR(
143 iree_io_parameter_index_add(target_index, &target_entry));
144 }
145 return iree_ok_status();
146}
147
148static iree_status_t iree_tooling_convert_parameters(
149 iree_io_scope_map_t* scope_map, iree_io_parameter_index_t* target_index,
150 iree_allocator_t host_allocator) {
151 for (iree_host_size_t i = 0; i < scope_map->count; ++i) {
152 IREE_RETURN_IF_ERROR(iree_tooling_convert_parameter_index(
153 scope_map->entries[i]->index, target_index));
154 }
155 return iree_ok_status();
156}
157
//===----------------------------------------------------------------------===//
// main
//===----------------------------------------------------------------------===//
161
162IREE_FLAG(bool, quiet, false,
163 "Silences additional stdout output when not needed.");
164
165IREE_FLAG(string, output, "", "Output .irpa file path.");
166
167static void iree_io_file_handle_release_mapping(
168 void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
169 iree_file_contents_free((iree_file_contents_t*)user_data);
170}
171
172typedef struct {
173 iree_allocator_t host_allocator;
174 const char* path;
175} iree_tooling_open_params_t;
176static iree_status_t iree_tooling_open_output_parameter_file(
177 void* user_data, iree_io_physical_offset_t archive_offset,
178 iree_io_physical_size_t archive_length,
179 iree_io_file_handle_t** out_file_handle) {
180 iree_tooling_open_params_t* params = (iree_tooling_open_params_t*)user_data;
181 iree_file_contents_t* file_contents = NULL;
182 IREE_RETURN_IF_ERROR(
183 iree_file_create_mapped(params->path, archive_offset + archive_length,
184 archive_offset, (iree_host_size_t)archive_length,
185 params->host_allocator, &file_contents));
186 iree_io_file_handle_release_callback_t release_callback = {
187 .fn = iree_io_file_handle_release_mapping,
188 .user_data = file_contents,
189 };
190 iree_status_t status = iree_io_file_handle_wrap_host_allocation(
191 IREE_IO_FILE_ACCESS_WRITE, file_contents->buffer, release_callback,
192 params->host_allocator, out_file_handle);
193 if (!iree_status_is_ok(status)) {
194 iree_file_contents_free(file_contents);
195 }
196 return status;
197}
198
199int main(int argc, char** argv) {
Ben Vanik23f28282024-02-23 11:14:25 -0800200 IREE_TRACE_APP_ENTER();
Ben Vanikfce839f2023-11-28 17:12:58 -0800201 IREE_TRACE_ZONE_BEGIN(z0);
202
203 iree_allocator_t host_allocator = iree_allocator_system();
204 int exit_code = EXIT_SUCCESS;
205
206 // Parse command line flags.
207 iree_flags_set_usage(
208 "iree-convert-parameters",
209 "Converts supported parameter file formats into IREE Parameter Archives\n"
210 "(.irpa) files. Provide one or more input parameter files in the same\n"
211 "form as expected by the iree-run-module tool (`--parameters=foo.gguf`)\n"
212 "and an output file with `--output=file.irpa`.\n"
213 "\n"
214 "Example converting from safetensors to IRPA:\n"
215 " iree-convert-parameters \\\n"
216 " --parameters=input.safetensors \\\n"
217 " --output=output.irpa\n"
218 "\n"
219 "Example mutating parameters:\n"
220 " iree-convert-parameters \\\n"
221 " --parameters=a.gguf \\\n"
222 " --parameters=b.safetensors \\\n"
223 " --exclude=unneeded_param \\\n"
224 " --rename=old_name=new_name \\\n"
225 " --splat=some_name=f32=4.2 \\\n"
226 " --output=ab.irpa\n"
227 "\n"
228 "Example stripping all parameters and replacing them with zeros except\n"
229 "for one that needs special handling:\n"
230 " iree-convert-parameters \\\n"
231 " --parameters=input.irpa \\\n"
232 " --strip \\\n"
233 " --splat=special_param=f32=1.0 \\\n"
234 " --output=output.irpa\n");
235 iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
236
237 // Load parameter indices as specified by command line flags.
238 iree_io_scope_map_t scope_map = {0};
239 iree_io_scope_map_initialize(host_allocator, &scope_map);
240 iree_status_t status =
241 iree_tooling_build_parameter_indices_from_flags(&scope_map);
242
243 // Build the new combined/modified index in memory based on the inputs.
244 iree_io_parameter_index_t* new_index = NULL;
245 if (iree_status_is_ok(status)) {
246 status = iree_io_parameter_index_create(host_allocator, &new_index);
247 }
248 if (iree_status_is_ok(status)) {
249 status =
250 iree_tooling_convert_parameters(&scope_map, new_index, host_allocator);
251 }
252 iree_io_scope_map_deinitialize(&scope_map);
253
254 iree_io_parameter_index_t* built_index = NULL;
255 if (iree_status_is_ok(status)) {
256 status = iree_io_parameter_index_create(host_allocator, &built_index);
257 }
258
259 // Write out the new archive.
260 if (iree_status_is_ok(status)) {
261 iree_tooling_open_params_t open_params = {
262 .host_allocator = host_allocator,
263 .path = FLAG_output,
264 };
265 iree_io_parameter_archive_file_open_callback_t open_callback = {
266 .fn = iree_tooling_open_output_parameter_file,
267 .user_data = &open_params,
268 };
269 status = iree_io_build_parameter_archive(
270 new_index, built_index, open_callback,
271 /*target_file_offset=*/0, host_allocator);
272 }
273
274 // Dump the new index ala iree-dump-parameters to show the final file.
275 if (iree_status_is_ok(status) && !FLAG_quiet) {
276 status = iree_io_parameter_index_fprint(stdout, iree_string_view_empty(),
277 built_index);
278 }
279
280 iree_io_parameter_index_release(built_index);
281 iree_io_parameter_index_release(new_index);
282
283 fflush(stdout);
284 if (!iree_status_is_ok(status)) {
285 iree_status_fprint(stderr, status);
286 iree_status_free(status);
287 exit_code = EXIT_FAILURE;
288 }
289 fflush(stderr);
290
291 IREE_TRACE_ZONE_END(z0);
Ben Vanik23f28282024-02-23 11:14:25 -0800292 IREE_TRACE_APP_EXIT(exit_code);
Ben Vanikfce839f2023-11-28 17:12:58 -0800293 return exit_code;
294}