Python implementation of the low level io parameters API. (#15457)
Adds Python support for:
* Creating a ParameterIndex and ParameterProvider
* Populating an index explicitly from a splat, FileHandle, or buffer
* Loading gguf files
* Loading safetensors files
* FileHandle from host memory
* Creating an io_parameters module from a list of providers
This is a relatively low-level interface. Future work may add high-level
helpers.
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index ff8f7d9..69c84b5 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -64,6 +64,8 @@
"initialize_module.cc"
"invoke.h"
"invoke.cc"
+ "io.h"
+ "io.cc"
"hal.h"
"hal.cc"
"numpy_interop.h"
@@ -83,6 +85,16 @@
iree::hal
iree::hal::drivers
iree::hal::utils::allocators
+ iree::base::internal::file_io
+ iree::base::internal::path
+ iree::io::file_handle
+ iree::io::formats::gguf
+ iree::io::formats::safetensors
+ iree::io::parameter_index
+ iree::io::parameter_index_provider
+ iree::io::parameter_provider
+ iree::io::scope_map
+ iree::modules::io::parameters
iree::modules::hal
iree::tooling::device_util
iree::tooling::modules
@@ -213,13 +225,6 @@
iree_py_test(
NAME
- py_module_test
- SRCS
- "tests/py_module_test.py"
-)
-
-iree_py_test(
- NAME
system_setup_test
SRCS
"tests/system_setup_test.py"
@@ -237,6 +242,13 @@
iree_py_test(
NAME
+ io_test
+ SRCS
+ "tests/io_test.py"
+ )
+
+ iree_py_test(
+ NAME
system_api_test
SRCS
"tests/system_api_test.py"
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index 8f4dd8d..bac89bd 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -7,6 +7,7 @@
#include "./binding.h"
#include "./hal.h"
#include "./invoke.h"
+#include "./io.h"
#include "./numpy_interop.h"
#include "./py_module.h"
#include "./status_utils.h"
@@ -26,6 +27,7 @@
m.doc() = "IREE Binding Backend Helpers";
SetupHalBindings(m);
SetupInvokeBindings(m);
+ SetupIoBindings(m);
SetupPyModuleBindings(m);
SetupVmBindings(m);
diff --git a/runtime/bindings/python/io.cc b/runtime/bindings/python/io.cc
new file mode 100644
index 0000000..0d8f8d1
--- /dev/null
+++ b/runtime/bindings/python/io.cc
@@ -0,0 +1,267 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./io.h"
+
+#include <iostream>
+#include <string_view>
+#include <unordered_map>
+
+#include "./buffer_interop.h"
+#include "./vm.h"
+#include "iree/base/internal/file_io.h"
+#include "iree/base/internal/path.h"
+#include "iree/io/formats/gguf/gguf_format.h"
+#include "iree/io/formats/safetensors/safetensors_format.h"
+#include "iree/io/parameter_index_provider.h"
+#include "iree/modules/io/parameters/module.h"
+
+namespace iree::python {
+
+namespace {
+
+VmModule CreateIoParametersModule(VmInstance &instance, py::args providers) {
+ iree_vm_module_t *module = nullptr;
+ std::vector<iree_io_parameter_provider_t *> c_providers;
+ iree_host_size_t size = providers.size();
+ c_providers.resize(size);
+ for (iree_host_size_t i = 0; i < size; ++i) {
+ ParameterProvider *provider = py::cast<ParameterProvider *>(providers[i]);
+ c_providers[i] = provider->raw_ptr();
+ }
+ CheckApiStatus(iree_io_parameters_module_create(
+ instance.raw_ptr(), size, c_providers.data(),
+ iree_allocator_system(), &module),
+ "Error creating io_parameters module");
+ return VmModule::StealFromRawPtr(module);
+}
+
+FileHandle FileHandleWrapMemory(py::object host_buffer, bool readable,
+ bool writable, size_t &out_buffer_size) {
+ struct Retained {
+ Retained(py::object host_buffer)
+ : buffer_request(host_buffer, PyBUF_SIMPLE),
+ host_buffer(std::move(host_buffer)) {}
+ PyBufferRequest buffer_request;
+ py::object host_buffer;
+ };
+ std::unique_ptr<Retained> outer_retained =
+ std::make_unique<Retained>(std::move(host_buffer));
+ iree_io_file_access_t access = 0;
+ if (readable) access |= IREE_IO_FILE_ACCESS_READ;
+ if (writable) access |= IREE_IO_FILE_ACCESS_WRITE;
+ iree_io_file_handle_t *created_handle;
+ out_buffer_size = outer_retained->buffer_request.view().len;
+ CheckApiStatus(
+ iree_io_file_handle_wrap_host_allocation(
+ access,
+ iree_byte_span_t{
+ static_cast<uint8_t *>(outer_retained->buffer_request.view().buf),
+ static_cast<iree_host_size_t>(
+ outer_retained->buffer_request.view().len)},
+ iree_io_file_handle_release_callback_t{
+ +[](void *user_data, iree_io_file_handle_primitive_t primitive) {
+ Retained *inner_retained = static_cast<Retained *>(user_data);
+ delete inner_retained;
+ },
+ (void *)outer_retained.get(),
+ },
+ iree_allocator_system(), &created_handle),
+ "Could not wrap host memory into a file handle");
+ outer_retained.release();
+ return FileHandle::StealFromRawPtr(created_handle);
+}
+
+void ParameterIndexAddFromFileHandle(ParameterIndex &self, std::string &key,
+ FileHandle &file_handle, uint64_t length,
+ uint64_t offset,
+ std::optional<std::string> metadata) {
+ iree_io_parameter_index_entry_t entry;
+ memset(&entry, 0, sizeof(entry));
+ entry.key = iree_make_string_view(key.data(), key.size());
+ if (metadata) {
+ entry.metadata.data = reinterpret_cast<const uint8_t *>(metadata->data());
+ entry.metadata.data_length = metadata->size();
+ }
+ entry.length = length;
+ entry.type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE;
+ entry.storage.file.handle = file_handle.raw_ptr();
+ entry.storage.file.offset = offset;
+ CheckApiStatus(iree_io_parameter_index_add(self.raw_ptr(), &entry),
+ "Could not add parameter index entry");
+}
+
+void ParameterIndexParseFileHandle(ParameterIndex &self,
+ FileHandle &file_handle,
+ std::string &format) {
+ if (format == "gguf") {
+ CheckApiStatus(
+ iree_io_parse_gguf_index(file_handle.raw_ptr(), self.raw_ptr()),
+ "Could not parse gguf file into index");
+ } else if (format == "safetensors") {
+ CheckApiStatus(
+ iree_io_parse_safetensors_index(file_handle.raw_ptr(), self.raw_ptr()),
+ "Could not parse safetensors file into index");
+ } else {
+ throw std::invalid_argument(
+ "Unrecognized file format. Expected one of: 'gguf', 'safetensors'");
+ }
+}
+
+void ParameterIndexLoadFile(ParameterIndex &self, std::string &file_path,
+ std::optional<std::string> format, bool readable,
+ bool writable, bool mmap) {
+ // Default format from extension.
+ if (!format) {
+ iree_string_view_t path_ext = iree_file_path_extension(
+ iree_make_string_view(file_path.data(), file_path.size()));
+ format.emplace(path_ext.data, path_ext.size);
+ }
+
+ // Open file.
+ iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_DEFAULT;
+ if (mmap) {
+ read_flags = IREE_FILE_READ_FLAG_MMAP;
+ } else {
+ read_flags = IREE_FILE_READ_FLAG_PRELOAD;
+ }
+ iree_file_contents_t *file_contents = nullptr;
+ CheckApiStatus(
+ iree_file_read_contents(file_path.c_str(), read_flags,
+ iree_allocator_system(), &file_contents),
+ "Error opening parameter file");
+ iree_io_file_handle_release_callback_t release_callback = {
+ +[](void *user_data, iree_io_file_handle_primitive_t handle_primitive) {
+ iree_file_contents_t *file_contents = (iree_file_contents_t *)user_data;
+ iree_file_contents_free(file_contents);
+ },
+ file_contents,
+ };
+
+ // Wrap contents.
+ iree_io_file_handle_t *raw_file_handle = nullptr;
+ iree_status_t status = iree_io_file_handle_wrap_host_allocation(
+ IREE_IO_FILE_ACCESS_READ, file_contents->buffer, release_callback,
+ iree_allocator_system(), &raw_file_handle);
+ if (!iree_status_is_ok(status)) {
+ iree_file_contents_free(file_contents);
+ CheckApiStatus(status, "Error accessing parameter memory");
+ }
+
+ // Parse.
+ FileHandle file_handle = FileHandle::StealFromRawPtr(raw_file_handle);
+ ParameterIndexParseFileHandle(self, file_handle, *format);
+}
+
+} // namespace
+
+void SetupIoBindings(py::module_ &m) {
+ m.def("create_io_parameters_module", &CreateIoParametersModule);
+
+ py::class_<FileHandle>(m, "FileHandle")
+ .def_static(
+ "wrap_memory",
+ [](py::object host_buffer, bool readable, bool writable) {
+ size_t unused_len;
+ return FileHandleWrapMemory(std::move(host_buffer), readable,
+ writable, unused_len);
+ },
+ py::arg("host_buffer"), py::arg("readable") = true,
+ py::arg("writable") = false);
+ py::class_<ParameterProvider>(m, "ParameterProvider");
+ py::class_<ParameterIndex>(m, "ParameterIndex")
+ .def("__init__",
+ [](ParameterIndex *new_self) {
+ iree_io_parameter_index_t *created;
+ CheckApiStatus(iree_io_parameter_index_create(
+ iree_allocator_system(), &created),
+ "Could not create IO parameter index");
+ new (new_self) ParameterIndex();
+ *new_self = ParameterIndex::StealFromRawPtr(created);
+ })
+ .def("__len__",
+ [](ParameterIndex &self) {
+ return iree_io_parameter_index_count(self.raw_ptr());
+ })
+ .def(
+ "reserve",
+ [](ParameterIndex &self, iree_host_size_t new_capacity) {
+ CheckApiStatus(
+ iree_io_parameter_index_reserve(self.raw_ptr(), new_capacity),
+ "Could not reserve capacity");
+ },
+ py::arg("new_capacity"))
+ .def(
+ "add_splat",
+ [](ParameterIndex &self, std::string key, py::object pattern,
+ uint64_t total_length, std::optional<std::string> metadata) {
+ iree_io_parameter_index_entry_t entry;
+ memset(&entry, 0, sizeof(entry));
+ entry.key = iree_make_string_view(key.data(), key.size());
+ if (metadata) {
+ entry.metadata.data =
+ reinterpret_cast<const uint8_t *>(metadata->data());
+ entry.metadata.data_length = metadata->size();
+ }
+ entry.length = total_length;
+ entry.type = IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT;
+ PyBufferRequest pattern_info(pattern, PyBUF_SIMPLE);
+ auto pattern_size = pattern_info.view().len;
+ if (pattern_size > sizeof(entry.storage.splat.pattern)) {
+ throw std::invalid_argument(
+ "pattern must be limited to 16 bytes");
+ }
+ entry.storage.splat.pattern_length = pattern_size;
+ std::memcpy(entry.storage.splat.pattern, pattern_info.view().buf,
+ pattern_size);
+ CheckApiStatus(iree_io_parameter_index_add(self.raw_ptr(), &entry),
+ "Could not add parameter index entry");
+ },
+ py::arg("key"), py::arg("pattern"), py::arg("total_length"),
+ py::arg("metadata") = py::none())
+ .def("add_from_file_handle", ParameterIndexAddFromFileHandle,
+ py::arg("key"), py::arg("file_handle"), py::arg("length"),
+ py::arg("offset") = 0, py::arg("metadata") = py::none())
+ .def(
+ "add_buffer",
+ [](ParameterIndex &self, std::string key, py::object buffer,
+ bool readable, bool writable,
+ std::optional<std::string> metadata) {
+ size_t buffer_size;
+ FileHandle file_handle = FileHandleWrapMemory(
+ std::move(buffer), readable, writable, buffer_size);
+ ParameterIndexAddFromFileHandle(self, key, file_handle, buffer_size,
+ /*offset=*/0, std::move(metadata));
+ },
+ py::arg("key"), py::arg("buffer"), py::arg("readable") = true,
+ py::arg("writable") = false, py::arg("metadata") = py::none())
+ .def("load_from_file_handle", ParameterIndexParseFileHandle,
+ py::arg("file_handle"), py::arg("format"))
+ .def("load", ParameterIndexLoadFile, py::arg("file_path"),
+ py::arg("format") = py::none(), py::arg("readable") = true,
+ py::arg("writable") = false, py::arg("mmap") = true)
+ .def(
+ "create_provider",
+ [](ParameterIndex &self, std::string scope,
+ std::optional<iree_host_size_t> max_concurrent_operations) {
+ if (!max_concurrent_operations) {
+ max_concurrent_operations =
+ IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS;
+ }
+ iree_io_parameter_provider_t *created;
+ CheckApiStatus(
+ iree_io_parameter_index_provider_create(
+ iree_make_string_view(scope.data(), scope.size()),
+ self.raw_ptr(), *max_concurrent_operations,
+ iree_allocator_system(), &created),
+ "Could not create parameter provider from index");
+ return ParameterProvider::StealFromRawPtr(created);
+ },
+ py::arg("scope") = std::string(),
+ py::arg("max_concurrent_operations") = py::none());
+}
+
+} // namespace iree::python
diff --git a/runtime/bindings/python/io.h b/runtime/bindings/python/io.h
new file mode 100644
index 0000000..6381c73
--- /dev/null
+++ b/runtime/bindings/python/io.h
@@ -0,0 +1,61 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_IO_PARAMETERS_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_IO_PARAMETERS_H_
+
+#include <vector>
+
+#include "./binding.h"
+#include "iree/io/file_handle.h"
+#include "iree/io/parameter_index.h"
+#include "iree/io/parameter_provider.h"
+
+namespace iree::python {
+
+template <>
+struct ApiPtrAdapter<iree_io_file_handle_t> {
+ static void Retain(iree_io_file_handle_t *v) {
+ iree_io_file_handle_retain(v);
+ }
+ static void Release(iree_io_file_handle_t *v) {
+ iree_io_file_handle_release(v);
+ }
+};
+
+template <>
+struct ApiPtrAdapter<iree_io_parameter_provider_t> {
+ static void Retain(iree_io_parameter_provider_t *v) {
+ iree_io_parameter_provider_retain(v);
+ }
+ static void Release(iree_io_parameter_provider_t *v) {
+ iree_io_parameter_provider_release(v);
+ }
+};
+
+template <>
+struct ApiPtrAdapter<iree_io_parameter_index_t> {
+ static void Retain(iree_io_parameter_index_t *v) {
+ iree_io_parameter_index_retain(v);
+ }
+ static void Release(iree_io_parameter_index_t *v) {
+ iree_io_parameter_index_release(v);
+ }
+};
+
+class FileHandle : public ApiRefCounted<FileHandle, iree_io_file_handle_t> {};
+
+class ParameterProvider
+ : public ApiRefCounted<ParameterProvider, iree_io_parameter_provider_t> {};
+
+class ParameterIndex
+ : public ApiRefCounted<ParameterIndex, iree_io_parameter_index_t> {};
+
+void SetupIoBindings(py::module_ &m);
+
+} // namespace iree::python
+
+#endif // IREE_BINDINGS_PYTHON_IREE_RT_IO_PARAMETERS_H_
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index 140de10..b161d47 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -13,6 +13,14 @@
from . import _binding
# Pull some of the native symbols into the public API.
+# Io imports
+from ._binding import (
+ FileHandle,
+ ParameterIndex,
+ ParameterProvider,
+ create_io_parameters_module,
+)
+
# Hal imports
from ._binding import (
BufferCompatibility,
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index 858ef7a..faa42ee 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -3,6 +3,9 @@
from typing import overload
def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ...
+def create_io_parameters_module(
+ instance: VmInstance, *providers: ParameterProvider
+) -> VmModule: ...
def disable_leak_checker(): ...
def get_cached_hal_driver(device_uri: str) -> HalDriver: ...
def parse_flags(*flag: str): ...
@@ -46,6 +49,12 @@
def __and__(self, other: BufferUsage) -> int: ...
def __or__(self, other: BufferUsage) -> int: ...
+class FileHandle:
+ @staticmethod
+ def wrap_memory(
+ host_buffer: Any, readable: bool = True, writable: bool = False
+ ) -> FileHandle: ...
+
class HalAllocator:
def allocate_buffer(
self,
@@ -275,6 +284,52 @@
def __and__(self, other: MemoryType) -> int: ...
def __or__(self, other: MemoryType) -> int: ...
+class ParameterIndex:
+ def __init__() -> None: ...
+ def __len__(self) -> int: ...
+ def reserve(self, new_capacity: int) -> None: ...
+ def add_splat(
+ self,
+ key: str,
+ pattern: Any,
+ total_length: int,
+ *,
+ metadata: Optional[Union[bytes, str]] = None
+ ) -> None: ...
+ def add_from_file_handle(
+ self,
+ key: str,
+ file_handle: FileHandle,
+ length: int,
+ *,
+ offset: int = 0,
+ metadata: Optional[Union[bytes, str]] = None
+ ) -> None: ...
+ def add_buffer(
+ self,
+ key: str,
+ buffer: Any,
+ *,
+ readable: bool = True,
+ writable: bool = False,
+ metadata: Optional[Union[bytes, str]] = None
+ ) -> None: ...
+ def load_from_file_handle(self, file_handle: FileHandle, format: str) -> None: ...
+ def load(
+ self,
+ file_path: str,
+ *,
+ format: Optional[str] = None,
+ readable: bool = True,
+ writable: bool = False,
+ mmap: bool = True
+ ) -> None: ...
+ def create_provider(
+ self, *, scope: str = "", max_concurrent_operations: Optional[int] = None
+ ) -> ParameterProvider: ...
+
+class ParameterProvider: ...
+
class PyModuleInterface:
def __init__(self, module_name: str, ctor: object) -> None: ...
def create(self) -> VmModule: ...
diff --git a/runtime/bindings/python/tests/io_test.py b/runtime/bindings/python/tests/io_test.py
new file mode 100644
index 0000000..ad95a8b
--- /dev/null
+++ b/runtime/bindings/python/tests/io_test.py
@@ -0,0 +1,192 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import array
+import logging
+import numpy as np
+from pathlib import Path
+import unittest
+
+import iree.compiler
+import iree.runtime as rt
+
+
+MM_TEST_COMPILED = None
+MM_TEST_ASM = r"""
+ #map = affine_map<(d0, d1) -> (d0, d1)>
+ #map1 = affine_map<(d0, d1) -> (d1, d0)>
+ #map2 = affine_map<(d0, d1) -> (d1)>
+ module @main {
+ util.global private @_params.classifier.weight {noinline} = #stream.parameter.named<"params"::"weight"> : tensor<30x20xf32>
+ util.global private @_params.classifier.bias {noinline} = #stream.parameter.named<"params"::"bias"> : tensor<30xf32>
+ func.func @run(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> {
+ %0 = call @forward(%arg0) : (tensor<128x20xf32>) -> tensor<128x30xf32>
+ return %0 : tensor<128x30xf32>
+ }
+ func.func private @forward(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> attributes {torch.assume_strict_symbolic_shapes} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %_params.classifier.weight = util.global.load @_params.classifier.weight : tensor<30x20xf32>
+ %0 = tensor.empty() : tensor<20x30xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%_params.classifier.weight : tensor<30x20xf32>) outs(%0 : tensor<20x30xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<20x30xf32>
+ %2 = tensor.empty() : tensor<128x30xf32>
+ %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x30xf32>) -> tensor<128x30xf32>
+ %4 = linalg.matmul ins(%arg0, %1 : tensor<128x20xf32>, tensor<20x30xf32>) outs(%3 : tensor<128x30xf32>) -> tensor<128x30xf32>
+ %_params.classifier.bias = util.global.load @_params.classifier.bias : tensor<30xf32>
+ %5 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %_params.classifier.bias : tensor<128x30xf32>, tensor<30xf32>) outs(%2 : tensor<128x30xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %6 = arith.addf %in, %in_0 : f32
+ linalg.yield %6 : f32
+ } -> tensor<128x30xf32>
+ return %5 : tensor<128x30xf32>
+ }
+}
+"""
+
+
+def compile_mm_test():
+ global MM_TEST_COMPILED
+ if not MM_TEST_COMPILED:
+ MM_TEST_COMPILED = iree.compiler.compile_str(
+ MM_TEST_ASM, target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS
+ )
+ return MM_TEST_COMPILED
+
+
+def create_mm_test_module(instance):
+ binary = compile_mm_test()
+ return rt.VmModule.copy_buffer(instance, binary)
+
+
+def _float_constant(val: float) -> array.array:
+ return array.array("f", [val])
+
+
+class ParameterTest(unittest.TestCase):
+ def setUp(self):
+ self.instance = rt.VmInstance()
+ self.device = rt.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+ self.config = rt.Config(device=self.device)
+
+ def testParameterIndex(self):
+ index = rt.ParameterIndex()
+ self.assertEqual(len(index), 0)
+ index.reserve(25)
+ self.assertEqual(len(index), 0)
+ provider = index.create_provider()
+ rt.create_io_parameters_module(self.instance, provider)
+
+ def testFileHandleWrap(self):
+ fh = rt.FileHandle.wrap_memory(b"foobar")
+ del fh
+
+ def testParameterIndexAddFromFile(self):
+ splat_index = rt.ParameterIndex()
+ fh = rt.FileHandle.wrap_memory(b"foobar")
+ splat_index.add_from_file_handle("data", fh, length=3, offset=3)
+
+ def testSplats(self):
+ splat_index = rt.ParameterIndex()
+ splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
+ splat_index.add_splat("bias", _float_constant(1.0), 30 * 4)
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, splat_index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ # TODO: Fix splat in the parameter code so it is not all zeros.
+ # expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ # np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testBuffers(self):
+ index = rt.ParameterIndex()
+ weight = np.zeros([30, 20], dtype=np.float32) + 2.0
+ bias = np.zeros([30], dtype=np.float32) + 1.0
+ index.add_buffer("weight", weight)
+ index.add_buffer("bias", bias)
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testGguf(self):
+ index = rt.ParameterIndex()
+ index.load(
+ str(
+ Path(__file__).resolve().parent
+ / "testdata"
+ / "parameter_weight_bias_1.gguf"
+ )
+ )
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testSafetensors(self):
+ index = rt.ParameterIndex()
+ index.load(
+ str(
+ Path(__file__).resolve().parent
+ / "testdata"
+ / "parameter_weight_bias_1.safetensors"
+ )
+ )
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testSplatTooBig(self):
+ splat_index = rt.ParameterIndex()
+ with self.assertRaises(ValueError):
+ splat_index.add_splat(
+ "weight", array.array("f", [1.0, 2.0, 3.0, 4.0, 5.0]), 30 * 20 * 4
+ )
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/runtime/bindings/python/tests/testdata/generate_parameter_gguf.py b/runtime/bindings/python/tests/testdata/generate_parameter_gguf.py
new file mode 100755
index 0000000..607f57e
--- /dev/null
+++ b/runtime/bindings/python/tests/testdata/generate_parameter_gguf.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
+#
+# To regenerate:
+# $ pip install gguf
+# $ ./runtime/bindings/python/tests/testdata/generate_parameter_gguf.py
+
+from pathlib import Path
+import numpy as np
+from gguf import GGUFWriter
+
+
+def save_file(tensors, path):
+ writer = GGUFWriter(str(path), "generic")
+
+ writer.add_architecture()
+ writer.add_custom_alignment(64)
+
+ writer.add_uint32("metadata_uint32", 42)
+ writer.add_string("metadata_str", "hello")
+ writer.add_array("metadata_strs", ["a", "b", "c"])
+
+ for key, value in tensors.items():
+ writer.add_tensor(key, value)
+
+ writer.write_header_to_file()
+ writer.write_kv_data_to_file()
+ writer.write_tensors_to_file()
+
+ writer.close()
+
+
+# multiple tensors
+save_file(
+ {
+ "weight": np.zeros([30, 20], dtype=np.float32) + 2.0,
+ "bias": np.zeros([30], dtype=np.float32) + 1.0,
+ },
+ Path(__file__).resolve().parent / "parameter_weight_bias_1.gguf",
+)
diff --git a/runtime/bindings/python/tests/testdata/generate_parameter_safetensors.py b/runtime/bindings/python/tests/testdata/generate_parameter_safetensors.py
new file mode 100755
index 0000000..43a282a
--- /dev/null
+++ b/runtime/bindings/python/tests/testdata/generate_parameter_safetensors.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# https://huggingface.co/docs/safetensors/index
+#
+# To regenerate:
+# $ pip install safetensors
+# $ ./runtime/bindings/python/tests/testdata/generate_parameter_safetensors.py
+
+from pathlib import Path
+import numpy as np
+from safetensors.numpy import save_file
+
+# multiple tensors
+save_file(
+ {
+ "weight": np.zeros([30, 20], dtype=np.float32) + 2.0,
+ "bias": np.zeros([30], dtype=np.float32) + 1.0,
+ },
+ Path(__file__).resolve().parent / "parameter_weight_bias_1.safetensors",
+)
diff --git a/runtime/bindings/python/tests/testdata/parameter_weight_bias_1.gguf b/runtime/bindings/python/tests/testdata/parameter_weight_bias_1.gguf
new file mode 100644
index 0000000..f27b773
--- /dev/null
+++ b/runtime/bindings/python/tests/testdata/parameter_weight_bias_1.gguf
Binary files differ
diff --git a/runtime/bindings/python/tests/testdata/parameter_weight_bias_1.safetensors b/runtime/bindings/python/tests/testdata/parameter_weight_bias_1.safetensors
new file mode 100644
index 0000000..a498050
--- /dev/null
+++ b/runtime/bindings/python/tests/testdata/parameter_weight_bias_1.safetensors
Binary files differ
diff --git a/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py b/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py
index 31430a1..e4609bb 100644
--- a/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py
+++ b/runtime/src/iree/io/formats/gguf/testdata/generate_gguf_files.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.