blob: 0ad4a0e6b8bd90bc3f8393c2a7653a0c3d89c0c1 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/parameter_util.h"
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/flags.h"
#include "iree/base/internal/path.h"
#include "iree/io/formats/gguf/gguf_format.h"
#include "iree/io/formats/safetensors/safetensors_format.h"
#include "iree/io/parameter_index.h"
#include "iree/io/parameter_index_provider.h"
#include "iree/io/scope_map.h"
#include "iree/modules/io/parameters/module.h"
//===----------------------------------------------------------------------===//
// Parameter file I/O
//===----------------------------------------------------------------------===//
// --parameter_mode= flag selecting how parameter files are read from disk.
// iree_io_open_parameter_file compares the value against "mmap" and "preload"
// and fails with IREE_STATUS_INVALID_ARGUMENT for anything else.
// NOTE(review): "mmap" is the third macro argument - presumably the default
// value used when the flag is not specified; confirm against IREE_FLAG.
IREE_FLAG(
string, parameter_mode, "mmap",
"A parameter I/O mode of ['preload', 'mmap'].\n"
" preload: read entire parameter files into wired memory on startup.\n"
" mmap: maps the parameter files into discardable memory - can increase\n"
" warm-up time and variance as mapped pages are swapped\n"
" by the OS.");
// File handle release callback that frees the iree_file_contents_t passed as
// |user_data|. |handle_primitive| is not used.
static void iree_file_contents_release_callback(
    void* user_data, iree_io_file_handle_primitive_t handle_primitive) {
  (void)handle_primitive;
  iree_file_contents_free((iree_file_contents_t*)user_data);
}
// Opens the parameter file at |path| with the mode specified by the
// --parameter_mode flag and returns its handle in |out_file_handle|.
//
// The file contents are read with iree_file_read_contents and wrapped in a
// read-only file handle; iree_file_contents_release_callback is registered as
// the handle's release callback so that the contents are freed with it.
// |out_file_handle| is set to NULL on entry and only populated on success.
//
// Returns IREE_STATUS_OUT_OF_RANGE if |path| does not fit in the local
// NUL-terminated path buffer and IREE_STATUS_INVALID_ARGUMENT if the
// --parameter_mode flag value is not recognized; errors from reading or
// wrapping the file are propagated as-is.
static iree_status_t iree_io_open_parameter_file(
    iree_string_view_t path, iree_allocator_t host_allocator,
    iree_io_file_handle_t** out_file_handle) {
  IREE_ASSERT_ARGUMENT(out_file_handle);
  *out_file_handle = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size);

  // The file reading API takes a C string so the view is copied into a fixed
  // buffer. One byte is reserved for the NUL terminator; a longer path cannot
  // be represented and must be rejected rather than opened under a different
  // (shortened) name.
  char path_str[2048] = {0};
  if (path.size >= sizeof(path_str)) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "parameter file path is too long (%zu bytes; at most %zu supported)",
        (size_t)path.size, sizeof(path_str) - 1);
  }
  iree_string_view_to_cstring(path, path_str, sizeof(path_str));

  // Translate the --parameter_mode flag into file read flags.
  iree_file_read_flags_t read_flags = 0;
  if (strcmp(FLAG_parameter_mode, "mmap") == 0) {
    read_flags |= IREE_FILE_READ_FLAG_MMAP;
  } else if (strcmp(FLAG_parameter_mode, "preload") == 0) {
    read_flags |= IREE_FILE_READ_FLAG_PRELOAD;
  } else {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unrecognized --parameter_mode= value '%s'",
                            FLAG_parameter_mode);
  }

  // Read (or map) the file contents.
  iree_file_contents_t* file_contents = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_file_read_contents(path_str, read_flags, host_allocator,
                                  &file_contents));

  // Wrap the contents in a file handle that frees them via the release
  // callback.
  iree_io_file_handle_release_callback_t release_callback = {
      .fn = iree_file_contents_release_callback,
      .user_data = file_contents,
  };
  iree_io_file_handle_t* file_handle = NULL;
  iree_status_t status = iree_io_file_handle_wrap_host_allocation(
      IREE_IO_FILE_ACCESS_READ, file_contents->buffer, release_callback,
      host_allocator, &file_handle);
  if (iree_status_is_ok(status)) {
    *out_file_handle = file_handle;
  } else {
    // No handle took ownership so the contents are freed here.
    // NOTE(review): assumes the wrap call does not invoke the release
    // callback when it fails - confirm against the file handle API.
    iree_file_contents_free(file_contents);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
//===----------------------------------------------------------------------===//
// Parameter file format parsing
//===----------------------------------------------------------------------===//
// --parameters= list flag. Each value is either a bare file path (anonymous
// global scope) or `scope=path`; values are parsed in
// iree_tooling_build_parameter_indices_from_flags.
// NOTE(review): the help text lists .irpa but the format dispatch in
// iree_io_append_parameter_file_to_index only handles gguf and safetensors -
// confirm whether .irpa support is expected here.
IREE_FLAG_LIST(
    string, parameters,
    "Specifies a parameter file to make available to programs with either an\n"
    "anonymous global scope (`some_file.gguf`) or a named scope like\n"
    "`my_scope=some_file.gguf`.\n"
    "\n"
    "Supported formats:\n"
    "- .irpa (IREE parameter archive)\n"
    "- .gguf (https://github.com/ggerganov/ggml/blob/master/docs/gguf.md)\n"
    "- .safetensors (https://github.com/huggingface/safetensors)");
// Appends the parameter file located at |path| to |index|.
static iree_status_t iree_io_append_parameter_file_to_index(
iree_string_view_t path, iree_io_parameter_index_t* index,
iree_allocator_t host_allocator) {
IREE_ASSERT_ARGUMENT(index);
IREE_TRACE_ZONE_BEGIN(z0);
// Open the file.
iree_io_file_handle_t* file_handle = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_io_open_parameter_file(path, host_allocator, &file_handle));
// Index the file based on its (inferred) format.
iree_status_t status = iree_ok_status();
iree_string_view_t path_ext = iree_file_path_extension(path);
if (iree_string_view_equal_case(path_ext, IREE_SV("gguf"))) {
status = iree_io_parse_gguf_index(file_handle, index);
} else if (iree_string_view_equal_case(path_ext, IREE_SV("safetensors"))) {
status = iree_io_parse_safetensors_index(file_handle, index);
} else {
status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"unhandled parameter file format: .%.*s",
(int)path_ext.size, path_ext.data);
}
// Release our file reference - it's still retained by the index if it had any
// parameters in it.
iree_io_file_handle_release(file_handle);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Populates |scope_map| from every value of the --parameters= flag: each value
// is split into an optional scope and a file path, the index for that scope is
// looked up in the map, and the file is appended to that index.
// Stops at and returns the first failure.
iree_status_t iree_tooling_build_parameter_indices_from_flags(
    iree_io_scope_map_t* scope_map) {
  IREE_TRACE_ZONE_BEGIN(z0);

  for (iree_host_size_t flag_ordinal = 0;
       flag_ordinal < FLAG_parameters_list().count; ++flag_ordinal) {
    // Flags take the form `scope=path` or just `path`.
    iree_string_view_t flag_value = FLAG_parameters_list().values[flag_ordinal];
    iree_string_view_t scope_name;
    iree_string_view_t file_path;
    if (iree_string_view_split(flag_value, '=', &scope_name, &file_path) ==
        -1) {
      // No '=' present: the whole value is the path and the scope is empty.
      file_path = scope_name;
      scope_name = iree_string_view_empty();
    }

    // Find the index for the scope (created on demand by the scope map).
    iree_io_parameter_index_t* scope_index = NULL;  // unowned
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_io_scope_map_lookup(scope_map, scope_name, &scope_index));

    // Add the contents of the file to the scope's index.
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_io_append_parameter_file_to_index(file_path, scope_index,
                                                   scope_map->host_allocator));
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Creates a parameters module serving the files named by the --parameters=
// flags and returns it in |out_module|.
//
// All flag values are parsed into per-scope indices, one parameter provider is
// created per scope, and the module is created with that list of providers.
// |out_module| is set to NULL on entry so that it has a defined value when an
// error is returned. Returns the first failure encountered.
iree_status_t iree_tooling_create_parameters_module_from_flags(
    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
    iree_vm_module_t** out_module) {
  IREE_ASSERT_ARGUMENT(out_module);
  *out_module = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_io_scope_map_t scope_map;
  iree_io_scope_map_initialize(host_allocator, &scope_map);

  // Parse all parameter files and build out their indices.
  iree_status_t status =
      iree_tooling_build_parameter_indices_from_flags(&scope_map);

  // Create one provider per scope. The provider list lives on the stack and is
  // sized by the number of scopes in the map; |provider_count| tracks how many
  // entries were successfully created so that only those are released below.
  iree_host_size_t provider_count = 0;
  iree_io_parameter_provider_t** providers =
      (iree_io_parameter_provider_t**)iree_alloca(
          scope_map.count * sizeof(iree_io_parameter_provider_t*));
  if (iree_status_is_ok(status)) {
    for (iree_host_size_t i = 0; i < scope_map.count; ++i) {
      status = iree_io_parameter_index_provider_create(
          scope_map.entries[i]->scope, scope_map.entries[i]->index,
          IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS,
          host_allocator, &providers[i]);
      if (!iree_status_is_ok(status)) break;
      ++provider_count;
    }
  }

  // Create the module with the list of providers.
  if (iree_status_is_ok(status)) {
    status = iree_io_parameters_module_create(
        instance, provider_count, providers, host_allocator, out_module);
  }

  // Cleanup (module owns providers which own indices/etc). This runs on both
  // success and failure.
  for (iree_host_size_t i = 0; i < provider_count; ++i) {
    iree_io_parameter_provider_release(providers[i]);
  }
  iree_io_scope_map_deinitialize(&scope_map);

  IREE_TRACE_ZONE_END(z0);
  return status;
}