[python] Flesh out more of the python parameters API. (#16957)
The existing Python parameters API was sufficient to construct parameter
indexes/archives but did not fully support introspecting them from
Python.
New APIs:
* `FileHandle`:
  * Implements the buffer protocol for accessing host allocations.
  * `is_host_allocation -> bool` property
  * `host_allocation -> memoryview` property
  * `__repr__`
* `ParameterIndexEntry`:
  * Class added.
  * Properties: `key`, `length`, `metadata`, `is_file`, `is_splat`,
    `file_storage`, `file_view`, `splat_pattern`; plus `__repr__`
* `ParameterIndex`:
  * `__getitem__` for index-based iteration
  * `items()` returning a list of `(key, entry)` tuples for `dict`-like access
  * `__repr__`
Also upgrades nanobind to 1.9.2, since some newer features were needed (this
simply bumps to the current release rather than finding the minimum required
version). That in turn requires patching the Linux CI to set up a venv and
install the build requirements, as all of the other CI jobs already do, instead
of relying on whatever was packaged in the base Docker image.
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index ddf8200..c787624 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -244,6 +244,13 @@
iree_py_test(
NAME
+ io_test  # Pure parameter API tests (no iree.compiler import).
+ SRCS
+ "tests/io_test.py"
+)
+
+iree_py_test(
+ NAME
system_setup_test
SRCS
"tests/system_setup_test.py"
@@ -261,9 +268,9 @@
iree_py_test(
NAME
- io_test
+ io_runtime_test  # Compiles and runs a test module via iree.compiler.
SRCS
- "tests/io_test.py"
+ "tests/io_runtime_test.py"
)
iree_py_test(
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index c9a3a30..5a2bb53 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -128,6 +128,63 @@
return py::steal(py::handle(PyTuple_New(0)));
}
+// For a bound class, binds the buffer protocol. Each buffer request results
+// in a call on the CppType to:
+//   int HandleBufferProtocol(Py_buffer *view, int flags)
+// which must fill in every field of `view` except `view->obj` (managed here).
+// This is a low level callback and must not throw C++ exceptions. If error
+// conditions are warranted the usual PyErr_SetString approach must be used
+// (and -1 returned). Return 0 on success.
+//
+// As the buffer protocol requires, a failed request leaves `view->obj` NULL
+// and retains no reference to the exporter, and a PyBUF_WRITABLE request for
+// a view the handler marked read-only fails with BufferError.
+template <typename CppType>
+void BindBufferProtocol(py::handle clazz) {
+  PyBufferProcs buffer_procs{};
+  buffer_procs.bf_getbuffer =
+      // It is not legal to raise C++ exceptions from these callbacks.
+      +[](PyObject* raw_self, Py_buffer* view, int flags) -> int {
+    if (view == NULL) {
+      PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer");
+      return -1;
+    }
+    view->obj = NULL;
+
+    // Expected to succeed due to type invariants, but nanobind reports cast
+    // failures by throwing, which must not unwind through this C callback.
+    CppType* self = nullptr;
+    try {
+      self = py::cast<CppType*>(py::handle(raw_self));
+    } catch (...) {
+      PyErr_SetString(PyExc_BufferError,
+                      "Object cannot export a buffer in its current state");
+      return -1;
+    }
+
+    if (self->HandleBufferProtocol(view, flags) != 0) {
+      view->obj = NULL;
+      return -1;
+    }
+    if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && view->readonly) {
+      PyErr_SetString(PyExc_BufferError, "Object is not writable.");
+      return -1;
+    }
+
+    // Take the exporter reference only once the export has succeeded; it is
+    // dropped by PyBuffer_Release.
+    Py_INCREF(raw_self);
+    view->obj = raw_self;
+    return 0;
+  };
+  buffer_procs.bf_releasebuffer =
+      +[](PyObject* raw_self, Py_buffer* view) -> void {};
+  auto heap_type = reinterpret_cast<PyHeapTypeObject*>(clazz.ptr());
+  assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
+         "must be heap type");
+  heap_type->as_buffer = buffer_procs;
+}
+
} // namespace python
} // namespace iree
diff --git a/runtime/bindings/python/io.cc b/runtime/bindings/python/io.cc
index 44a77ec..3305698 100644
--- a/runtime/bindings/python/io.cc
+++ b/runtime/bindings/python/io.cc
@@ -164,12 +164,50 @@
ParameterIndexParseFileHandle(self, file_handle, *format);
}
+// Wraps an index and an entry, extending lifetime of the index.
+struct ParameterIndexEntryWrapper {
+  explicit ParameterIndexEntryWrapper(ParameterIndex index)
+      : index(std::move(index)) {}
+
+  ParameterIndex index;
+  // Obtained from `index` via iree_io_parameter_index_get; holding `index`
+  // above is what keeps this pointer usable.
+  const iree_io_parameter_index_entry_t *entry = nullptr;
+};
+
} // namespace
+// Exports the host allocation backing this handle as a 1-D byte buffer.
+// Fails with ValueError if the handle is not a host allocation, and with
+// BufferError if a writable buffer is requested but the handle lacks
+// IREE_IO_FILE_ACCESS_WRITE. Optional fields (format, shape, strides) are
+// filled only when `flags` requests them, as PyBuffer_FillInfo does.
+// `view->obj` is not touched here.
+int FileHandle::HandleBufferProtocol(Py_buffer *view, int flags) {
+  if (view == nullptr) {
+    PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer");
+    return -1;
+  }
+  auto primitive = iree_io_file_handle_primitive(raw_ptr());
+  if (primitive.type != IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) {
+    PyErr_SetString(PyExc_ValueError,
+                    "FileHandle is not based on a host allocation and "
+                    "cannot be mapped");
+    return -1;
+  }
+  const bool is_writable =
+      iree_io_file_handle_access(raw_ptr()) & IREE_IO_FILE_ACCESS_WRITE;
+  if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && !is_writable) {
+    PyErr_SetString(PyExc_BufferError, "FileHandle is not writable");
+    return -1;
+  }
+
+  view->buf = primitive.value.host_allocation.data;
+  view->len = primitive.value.host_allocation.data_length;
+  view->readonly = !is_writable;
+  view->itemsize = 1;
+  view->format = (flags & PyBUF_FORMAT) == PyBUF_FORMAT
+                     ? const_cast<char *>("B")  // Byte
+                     : nullptr;
+  view->ndim = 1;
+  view->shape = (flags & PyBUF_ND) == PyBUF_ND ? &view->len : nullptr;
+  view->strides =
+      (flags & PyBUF_STRIDES) == PyBUF_STRIDES ? &view->itemsize : nullptr;
+  view->suboffsets = nullptr;
+  view->internal = nullptr;
+  return 0;
+}
+
void SetupIoBindings(py::module_ &m) {
m.def("create_io_parameters_module", &CreateIoParametersModule);
- py::class_<FileHandle>(m, "FileHandle")
+ auto file_handle = py::class_<FileHandle>(m, "FileHandle");
+ // Buffer requests on FileHandle are served by
+ // FileHandle::HandleBufferProtocol (host allocations only).
+ BindBufferProtocol<FileHandle>(file_handle);
+ file_handle
.def_static(
"wrap_memory",
[](py::object host_buffer, bool readable, bool writable) {
@@ -178,8 +216,98 @@
writable, unused_len);
},
py::arg("host_buffer"), py::arg("readable") = true,
- py::arg("writable") = false);
+ py::arg("writable") = false)
+ .def_prop_ro(
+ "is_host_allocation",
+ [](FileHandle &self) {
+ auto primitive = iree_io_file_handle_primitive(self.raw_ptr());
+ return primitive.type == IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION;
+ })
+ // memoryview over the handle's bytes, created through the buffer protocol
+ // bound above, so the view holds a reference to this FileHandle. For
+ // handles that are not host allocations PyMemoryView_FromObject fails with
+ // the ValueError set by HandleBufferProtocol.
+ .def_prop_ro(
+ "host_allocation",
+ [](py::handle self) {
+ return py::steal<py::object>(PyMemoryView_FromObject(self.ptr()));
+ })
+ .def("__repr__", [](py::handle self_object) {
+ if (py::cast<py::bool_>(self_object.attr("is_host_allocation"))) {
+ return py::str("FileHandle<host_allocation({})>")
+ .format(self_object.attr("host_allocation"));
+ } else {
+ return py::str("<FileHandle unknown>");
+ }
+ });
+
py::class_<ParameterProvider>(m, "ParameterProvider");
+  // ParameterIndexEntry is a read-only view of one index entry. Each
+  // ParameterIndexEntryWrapper holds a reference to its ParameterIndex so
+  // that `entry` remains usable for the lifetime of the Python object.
+  py::class_<ParameterIndexEntryWrapper>(m, "ParameterIndexEntry")
+      .def_prop_ro("key",
+                   [](ParameterIndexEntryWrapper &self) {
+                     return py::str(self.entry->key.data,
+                                    self.entry->key.size);
+                   })
+      .def_prop_ro(
+          "length",
+          [](ParameterIndexEntryWrapper &self) { return self.entry->length; })
+      .def_prop_ro("metadata",
+                   [](ParameterIndexEntryWrapper &self) {
+                     return py::bytes(reinterpret_cast<const char *>(
+                                          self.entry->metadata.data),
+                                      self.entry->metadata.data_length);
+                   })
+      .def_prop_ro("is_file",
+                   [](ParameterIndexEntryWrapper &self) {
+                     return self.entry->type ==
+                            IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE;
+                   })
+      .def_prop_ro("is_splat",
+                   [](ParameterIndexEntryWrapper &self) {
+                     return self.entry->type ==
+                            IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT;
+                   })
+      // Returns (FileHandle, offset); raises ValueError for non-file entries.
+      .def_prop_ro(
+          "file_storage",
+          [](ParameterIndexEntryWrapper &self) {
+            if (self.entry->type !=
+                IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_FILE) {
+              throw std::invalid_argument("Entry is not file storage based");
+            }
+            return py::make_tuple(
+                FileHandle::BorrowFromRawPtr(self.entry->storage.file.handle),
+                self.entry->storage.file.offset);
+          })
+      // Slice [offset, offset + length) of the backing handle's
+      // host_allocation; fails for non-file entries or for file handles that
+      // are not host allocations.
+      .def_prop_ro("file_view",
+                   [](py::handle self_object) {
+                     auto file_storage = self_object.attr("file_storage");
+                     py::handle file_handle = file_storage[0];
+                     auto offset = py::cast<iree_host_size_t>(file_storage[1]);
+                     auto length =
+                         py::cast<iree_host_size_t>(self_object.attr("length"));
+                     py::object memview = file_handle.attr("host_allocation");
+                     py::slice slice(offset, offset + length);
+                     return memview.attr("__getitem__")(slice);
+                   })
+      // Raises ValueError for non-splat entries.
+      .def_prop_ro("splat_pattern",
+                   [](ParameterIndexEntryWrapper &self) {
+                     if (self.entry->type !=
+                         IREE_IO_PARAMETER_INDEX_ENTRY_STORAGE_TYPE_SPLAT) {
+                       throw std::invalid_argument("Entry is not splat");
+                     }
+                     return py::bytes(reinterpret_cast<const char *>(
+                                          self.entry->storage.splat.pattern),
+                                      self.entry->storage.splat.pattern_length);
+                   })
+      .def("__repr__", [](py::handle self_object) {
+        if (py::cast<py::bool_>(self_object.attr("is_splat"))) {
+          return py::str("<ParameterIndexEntry '{}' splat {}:{}>")
+              .format(self_object.attr("key"),
+                      self_object.attr("splat_pattern"),
+                      self_object.attr("length"));
+        } else if (py::cast<py::bool_>(self_object.attr("is_file"))) {
+          py::object file_storage = self_object.attr("file_storage");
+          // <ParameterIndexEntry 'key' FileHandle:offset:length>
+          return py::str("<ParameterIndexEntry '{}' {}:{}:{}>")
+              .format(self_object.attr("key"), file_storage[0],
+                      file_storage[1], self_object.attr("length"));
+        } else {
+          return py::str("<ParameterIndexEntry unknown>");
+        }
+      });
py::class_<ParameterIndex>(m, "ParameterIndex")
.def("__init__",
[](ParameterIndex *new_self) {
@@ -195,6 +323,44 @@
return iree_io_parameter_index_count(self.raw_ptr());
})
.def(
+          "__getitem__",
+          [](ParameterIndex &self, Py_ssize_t i) {
+            // Python sequence semantics: negative indices count from the end
+            // and out-of-range indices raise IndexError (nanobind translates
+            // std::out_of_range). IndexError is also what terminates
+            // sequence-protocol iteration (`for entry in index`).
+            const auto count = static_cast<Py_ssize_t>(
+                iree_io_parameter_index_count(self.raw_ptr()));
+            if (i < 0) i += count;
+            if (i < 0 || i >= count) {
+              throw std::out_of_range("ParameterIndex index out of range");
+            }
+            ParameterIndexEntryWrapper entry_wrapper(self);
+            CheckApiStatus(
+                iree_io_parameter_index_get(self.raw_ptr(),
+                                            static_cast<iree_host_size_t>(i),
+                                            &entry_wrapper.entry),
+                "Could not enumerate parameter index");
+            return entry_wrapper;
+          },
+          py::arg("i"))
+      // Returns [(key, ParameterIndexEntry), ...] in index order. Keys may
+      // repeat, so this is a list rather than a dict.
+      .def("items",
+           [](ParameterIndex &self) {
+             py::list items;
+             const iree_host_size_t count =
+                 iree_io_parameter_index_count(self.raw_ptr());
+             for (iree_host_size_t i = 0; i < count; ++i) {
+               ParameterIndexEntryWrapper entry_wrapper(self);
+               CheckApiStatus(iree_io_parameter_index_get(self.raw_ptr(), i,
+                                                          &entry_wrapper.entry),
+                              "Could not enumerate parameter index");
+               py::str key(entry_wrapper.entry->key.data,
+                           entry_wrapper.entry->key.size);
+               py::object value = py::cast(std::move(entry_wrapper));
+               items.append(py::make_tuple(key, value));
+             }
+             return items;
+           })
+      .def("__repr__",
+           [](ParameterIndex &self) {
+             iree_string_builder_t b;
+             iree_string_builder_initialize(iree_allocator_system(), &b);
+             iree_status_t status = iree_io_parameter_index_dump(
+                 iree_string_view_empty(), self.raw_ptr(), &b);
+             // Copy out and release the builder before anything that can
+             // throw (CheckApiStatus, Python string construction) so the
+             // builder never leaks.
+             iree_string_view_t sv = iree_string_builder_view(&b);
+             std::string text(sv.data, sv.size);
+             iree_string_builder_deinitialize(&b);
+             CheckApiStatus(status, "Failed to dump parameter index");
+             return py::str(text.data(), text.size());
+           })
+      .def(
"reserve",
[](ParameterIndex &self, iree_host_size_t new_capacity) {
CheckApiStatus(
diff --git a/runtime/bindings/python/io.h b/runtime/bindings/python/io.h
index 6381c73..1595b04 100644
--- a/runtime/bindings/python/io.h
+++ b/runtime/bindings/python/io.h
@@ -46,7 +46,10 @@
}
};
-class FileHandle : public ApiRefCounted<FileHandle, iree_io_file_handle_t> {};
+class FileHandle : public ApiRefCounted<FileHandle, iree_io_file_handle_t> {
+ public:
+ // Buffer protocol callback used by BindBufferProtocol<FileHandle>; exports
+ // host-allocation-backed handles only. Returns 0 on success, or -1 with a
+ // Python error set.
+ int HandleBufferProtocol(Py_buffer *view, int flags);
+};
class ParameterProvider
: public ApiRefCounted<ParameterProvider, iree_io_parameter_provider_t> {};
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index 3c44e88..90653f0 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -56,6 +56,14 @@
def wrap_memory(
host_buffer: Any, readable: bool = True, writable: bool = False
) -> FileHandle: ...
+ def host_allocation(self) -> memoryview:
+ """Access the raw view of the allocated host memory.
+
+ Requires is_host_allocation.
+ """
+ ...
+ @property
+ def is_host_allocation(self) -> bool: ...
class HalAllocator:
def allocate_buffer(
@@ -308,9 +316,40 @@
def __and__(self, other: MemoryType) -> int: ...
def __or__(self, other: MemoryType) -> int: ...
+class ParameterIndexEntry:
+ @property
+ def key(self) -> str: ...
+ @property
+ def length(self) -> int: ...
+ @property
+ def metadata(self) -> bytes: ...
+ @property
+ def is_file(self) -> bool: ...
+ @property
+ def is_splat(self) -> bool: ...
+ @property
+ def file_storage(self) -> Tuple[FileHandle, int]:
+ """Accesses the underlying storage (if is_file).
+
+ Only valid if is_file. Returns the backing FileHandle and offset.
+ """
+ ...
+ @property
+ def file_view(self) -> memoryview:
+ """Accesses a memoryview of the file contents.
+
+ Only valid if is_file and the file has host accessible storage.
+ """
+ ...
+ @property
+ def splat_pattern(self) -> bytes:
+ """Accesses the splat pattern (if is_splat)."""
+ ...
+
class ParameterIndex:
def __init__() -> None: ...
def __len__(self) -> int: ...
+ def __getitem__(self, i) -> ParameterIndexEntry: ...
def reserve(self, new_capacity: int) -> None: ...
def add_splat(
self,
@@ -357,6 +396,13 @@
def create_provider(
self, *, scope: str = "", max_concurrent_operations: Optional[int] = None
) -> ParameterProvider: ...
+ def items(self) -> List[Tuple[str, ParameterIndexEntry]]:
+ """Accesses the items as a tuple(str, entry).
+
+ Note that the index may contain duplicates, so loading into a dict
+ is up to the user, as only they can know if this is legal.
+ """
+ ...
class ParameterProvider: ...
diff --git a/runtime/bindings/python/tests/io_runtime_test.py b/runtime/bindings/python/tests/io_runtime_test.py
new file mode 100644
index 0000000..7fe1379
--- /dev/null
+++ b/runtime/bindings/python/tests/io_runtime_test.py
@@ -0,0 +1,208 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import array
+import logging
+import numpy as np
+from pathlib import Path
+import tempfile
+import unittest
+
+import iree.compiler
+import iree.runtime as rt
+
+
+MM_TEST_COMPILED = None
+MM_TEST_ASM = r"""
+ #map = affine_map<(d0, d1) -> (d0, d1)>
+ #map1 = affine_map<(d0, d1) -> (d1, d0)>
+ #map2 = affine_map<(d0, d1) -> (d1)>
+ module @main {
+ util.global private @_params.classifier.weight {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"weight"> : tensor<30x20xf32>
+ util.global private @_params.classifier.bias {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"bias"> : tensor<30xf32>
+ func.func @run(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> {
+ %0 = call @forward(%arg0) : (tensor<128x20xf32>) -> tensor<128x30xf32>
+ return %0 : tensor<128x30xf32>
+ }
+ func.func private @forward(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> attributes {torch.assume_strict_symbolic_shapes} {
+ %cst = arith.constant 0.000000e+00 : f32
+ %_params.classifier.weight = util.global.load @_params.classifier.weight : tensor<30x20xf32>
+ %0 = tensor.empty() : tensor<20x30xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%_params.classifier.weight : tensor<30x20xf32>) outs(%0 : tensor<20x30xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<20x30xf32>
+ %2 = tensor.empty() : tensor<128x30xf32>
+ %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x30xf32>) -> tensor<128x30xf32>
+ %4 = linalg.matmul ins(%arg0, %1 : tensor<128x20xf32>, tensor<20x30xf32>) outs(%3 : tensor<128x30xf32>) -> tensor<128x30xf32>
+ %_params.classifier.bias = util.global.load @_params.classifier.bias : tensor<30xf32>
+ %5 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %_params.classifier.bias : tensor<128x30xf32>, tensor<30xf32>) outs(%2 : tensor<128x30xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %6 = arith.addf %in, %in_0 : f32
+ linalg.yield %6 : f32
+ } -> tensor<128x30xf32>
+ return %5 : tensor<128x30xf32>
+ }
+}
+"""
+
+
+def compile_mm_test():
+ global MM_TEST_COMPILED
+ if not MM_TEST_COMPILED:
+ MM_TEST_COMPILED = iree.compiler.compile_str(
+ MM_TEST_ASM,
+ target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+ # TODO(#16098): re-enable const eval once parameters are supported.
+ extra_args=["--iree-opt-const-eval=false"],
+ )
+ return MM_TEST_COMPILED
+
+
+def create_mm_test_module(instance):
+ binary = compile_mm_test()
+ return rt.VmModule.copy_buffer(instance, binary)
+
+
+def _float_constant(val: float) -> array.array:
+ return array.array("f", [val])
+
+
+class ParameterTest(unittest.TestCase):
+ def setUp(self):
+ self.instance = rt.VmInstance()
+ self.device = rt.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+ self.config = rt.Config(device=self.device)
+
+ def testParameterIndex(self):
+ index = rt.ParameterIndex()
+ self.assertEqual(len(index), 0)
+ index.reserve(25)
+ self.assertEqual(len(index), 0)
+ provider = index.create_provider()
+ rt.create_io_parameters_module(self.instance, provider)
+
+ def testSplats(self):
+ splat_index = rt.ParameterIndex()
+ splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
+ splat_index.add_splat("bias", _float_constant(1.0), 30 * 4)
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, splat_index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ # TODO: Fix splat in the parameter code so it is not all zeros.
+ # expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ # np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testSplatsFromBuiltIrpaFile(self):
+ with tempfile.TemporaryDirectory() as td:
+ file_path = Path(td) / "archive.irpa"
+ rt.save_archive_file(
+ {
+ "weight": rt.SplatValue(np.float32(2.0), 30 * 20),
+ "bias": rt.SplatValue(np.float32(1.0), 30),
+ },
+ file_path,
+ )
+
+ index = rt.ParameterIndex()
+ index.load(str(file_path))
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testBuffers(self):
+ index = rt.ParameterIndex()
+ weight = np.zeros([30, 20], dtype=np.float32) + 2.0
+ bias = np.zeros([30], dtype=np.float32) + 1.0
+ index.add_buffer("weight", weight)
+ index.add_buffer("bias", bias)
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testGguf(self):
+ index = rt.ParameterIndex()
+ index.load(
+ str(
+ Path(__file__).resolve().parent
+ / "testdata"
+ / "parameter_weight_bias_1.gguf"
+ )
+ )
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+ def testSafetensors(self):
+ index = rt.ParameterIndex()
+ index.load(
+ str(
+ Path(__file__).resolve().parent
+ / "testdata"
+ / "parameter_weight_bias_1.safetensors"
+ )
+ )
+ modules = rt.load_vm_modules(
+ rt.create_io_parameters_module(
+ self.instance, index.create_provider(scope="params")
+ ),
+ rt.create_hal_module(self.instance, self.device),
+ create_mm_test_module(self.instance),
+ config=self.config,
+ )
+ main = modules[-1]
+ input = np.zeros([128, 20], dtype=np.float32) + 2.0
+ result = main.run(input)
+ print(result.to_host())
+ expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
+ np.testing.assert_array_almost_equal(result, expected_result)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/runtime/bindings/python/tests/io_test.py b/runtime/bindings/python/tests/io_test.py
index 88a1523..6557c4a 100644
--- a/runtime/bindings/python/tests/io_test.py
+++ b/runtime/bindings/python/tests/io_test.py
@@ -11,67 +11,14 @@
import tempfile
import unittest
-import iree.compiler
import iree.runtime as rt
-MM_TEST_COMPILED = None
-MM_TEST_ASM = r"""
- #map = affine_map<(d0, d1) -> (d0, d1)>
- #map1 = affine_map<(d0, d1) -> (d1, d0)>
- #map2 = affine_map<(d0, d1) -> (d1)>
- module @main {
- util.global private @_params.classifier.weight {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"weight"> : tensor<30x20xf32>
- util.global private @_params.classifier.bias {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"bias"> : tensor<30xf32>
- func.func @run(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> {
- %0 = call @forward(%arg0) : (tensor<128x20xf32>) -> tensor<128x30xf32>
- return %0 : tensor<128x30xf32>
- }
- func.func private @forward(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> attributes {torch.assume_strict_symbolic_shapes} {
- %cst = arith.constant 0.000000e+00 : f32
- %_params.classifier.weight = util.global.load @_params.classifier.weight : tensor<30x20xf32>
- %0 = tensor.empty() : tensor<20x30xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%_params.classifier.weight : tensor<30x20xf32>) outs(%0 : tensor<20x30xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<20x30xf32>
- %2 = tensor.empty() : tensor<128x30xf32>
- %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x30xf32>) -> tensor<128x30xf32>
- %4 = linalg.matmul ins(%arg0, %1 : tensor<128x20xf32>, tensor<20x30xf32>) outs(%3 : tensor<128x30xf32>) -> tensor<128x30xf32>
- %_params.classifier.bias = util.global.load @_params.classifier.bias : tensor<30xf32>
- %5 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %_params.classifier.bias : tensor<128x30xf32>, tensor<30xf32>) outs(%2 : tensor<128x30xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %6 = arith.addf %in, %in_0 : f32
- linalg.yield %6 : f32
- } -> tensor<128x30xf32>
- return %5 : tensor<128x30xf32>
- }
-}
-"""
-
-
-def compile_mm_test():
- global MM_TEST_COMPILED
- if not MM_TEST_COMPILED:
- MM_TEST_COMPILED = iree.compiler.compile_str(
- MM_TEST_ASM,
- target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
- # TODO(#16098): re-enable const eval once parameters are supported.
- extra_args=["--iree-opt-const-eval=false"],
- )
- return MM_TEST_COMPILED
-
-
-def create_mm_test_module(instance):
- binary = compile_mm_test()
- return rt.VmModule.copy_buffer(instance, binary)
-
-
def _float_constant(val: float) -> array.array:
return array.array("f", [val])
-class ParameterArchiveTest(unittest.TestCase):
+class ParameterApiTest(unittest.TestCase):
def testCreateArchiveFile(self):
splat_index = rt.ParameterIndex()
splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
@@ -84,162 +31,69 @@
self.assertTrue(file_path.exists())
self.assertGreater(file_path.stat().st_size, 0)
- def testSaveArchiveFile(self):
- index = rt.ParameterIndex()
+ def testArchiveFileRoundtrip(self):
with tempfile.TemporaryDirectory() as td:
file_path = Path(td) / "archive.irpa"
+ orig_array = np.asarray([[1], [2], [3]], dtype=np.int64)
rt.save_archive_file(
{
- "weight": rt.SplatValue(np.float32(2.0), [30, 20]),
- "bias": rt.SplatValue(array.array("f", [1.0]), 30),
- "array": np.asarray([1, 2, 3]),
+ "weight": rt.SplatValue(np.int8(2), [30, 20]),
+ "bias": rt.SplatValue(array.array("b", [32]), 30),
+ "array": orig_array,
},
file_path,
)
self.assertTrue(file_path.exists())
self.assertGreater(file_path.stat().st_size, 0)
+ # Load and verify.
+ index = rt.ParameterIndex()
+ index.load(str(file_path))
+ self.assertEqual(len(index), 3)
-class ParameterTest(unittest.TestCase):
- def setUp(self):
- self.instance = rt.VmInstance()
- self.device = rt.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
- self.config = rt.Config(device=self.device)
+ # Note that the happy path of most properties are verified via
+ # the repr (as they are called internal to that).
+ entries = dict(index.items())
+ # Splat lengths are in bytes: 30 * 20 int8 values == 600; 30 int8 == 30.
+ self.assertEqual(
+ repr(entries["weight"]),
+ "<ParameterIndexEntry 'weight' splat b'\\x02':600>",
+ )
+ # array("b", [32]) is the single byte 0x20, which bytes repr shows as b' '.
+ self.assertEqual(
+ repr(entries["bias"]),
+ "<ParameterIndexEntry 'bias' splat b' ':30>",
+ )
+ # File-backed entry repr is FileHandle:offset:length; 3 int64 == 24 bytes.
+ self.assertRegex(
+ repr(entries["array"]),
+ r"<ParameterIndexEntry 'array' FileHandle<host_allocation\(.*\)>:384:24",
+ )
- def testParameterIndex(self):
- index = rt.ParameterIndex()
- self.assertEqual(len(index), 0)
- index.reserve(25)
- self.assertEqual(len(index), 0)
- provider = index.create_provider()
- rt.create_io_parameters_module(self.instance, provider)
+ # Verify some non-happy paths.
+ with self.assertRaisesRegex(ValueError, "Entry is not file storage based"):
+ entries["weight"].file_storage
+ with self.assertRaisesRegex(ValueError, "Entry is not splat"):
+ entries["array"].splat_pattern
+
+ # Verify that the repr of the index itself is sensical.
+ index_repr = repr(index)
+ self.assertIn("Parameter scope <global> (3 entries", index_repr)
+
+ # Get the array contents and verify against original.
+ array_view = entries["array"].file_view
+ self.assertEqual(len(array_view), 24)
+ array_back = np.asarray(array_view).view(np.int64).reshape(orig_array.shape)
+ np.testing.assert_array_equal(array_back, orig_array)
def testFileHandleWrap(self):
fh = rt.FileHandle.wrap_memory(b"foobar")
+ # The memoryview holds a reference to the FileHandle, so it outlives `del fh`.
+ view = fh.host_allocation
del fh
+ self.assertEqual(bytes(view), b"foobar")
def testParameterIndexAddFromFile(self):
splat_index = rt.ParameterIndex()
fh = rt.FileHandle.wrap_memory(b"foobar")
splat_index.add_from_file_handle("data", fh, length=3, offset=3)
- def testSplats(self):
- splat_index = rt.ParameterIndex()
- splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
- splat_index.add_splat("bias", _float_constant(1.0), 30 * 4)
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, splat_index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- # TODO: Fix splat in the parameter code so it is not all zeros.
- # expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- # np.testing.assert_array_almost_equal(result, expected_result)
-
- def testSplatsFromBuiltIrpaFile(self):
- with tempfile.TemporaryDirectory() as td:
- file_path = Path(td) / "archive.irpa"
- rt.save_archive_file(
- {
- "weight": rt.SplatValue(np.float32(2.0), 30 * 20),
- "bias": rt.SplatValue(np.float32(1.0), 30),
- },
- file_path,
- )
-
- index = rt.ParameterIndex()
- index.load(str(file_path))
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
- def testBuffers(self):
- index = rt.ParameterIndex()
- weight = np.zeros([30, 20], dtype=np.float32) + 2.0
- bias = np.zeros([30], dtype=np.float32) + 1.0
- index.add_buffer("weight", weight)
- index.add_buffer("bias", bias)
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
- def testGguf(self):
- index = rt.ParameterIndex()
- index.load(
- str(
- Path(__file__).resolve().parent
- / "testdata"
- / "parameter_weight_bias_1.gguf"
- )
- )
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
- def testSafetensors(self):
- index = rt.ParameterIndex()
- index.load(
- str(
- Path(__file__).resolve().parent
- / "testdata"
- / "parameter_weight_bias_1.safetensors"
- )
- )
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
def testSplatTooBig(self):
splat_index = rt.ParameterIndex()
with self.assertRaises(ValueError):
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 80beb5c..f8b758c 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -131,6 +131,24 @@
} // namespace
//------------------------------------------------------------------------------
+// VmBuffer
+//------------------------------------------------------------------------------
+
+// Exports the buffer contents as a 1-D byte buffer. The view is writable only
+// if the buffer has IREE_VM_BUFFER_ACCESS_MUTABLE; a PyBUF_WRITABLE request
+// on a non-mutable buffer fails with BufferError. Optional fields (format,
+// shape, strides) are filled only when `flags` requests them, as
+// PyBuffer_FillInfo does. `view->obj` is not touched here.
+int VmBuffer::HandleBufferProtocol(Py_buffer* view, int flags) {
+  const bool readonly = !(raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE);
+  if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && readonly) {
+    PyErr_SetString(PyExc_BufferError, "VmBuffer is not mutable");
+    return -1;
+  }
+  view->buf = raw_ptr()->data.data;
+  view->len = raw_ptr()->data.data_length;
+  view->readonly = readonly;
+  view->itemsize = 1;
+  view->format = (flags & PyBUF_FORMAT) == PyBUF_FORMAT
+                     ? const_cast<char*>("B")  // Byte
+                     : nullptr;
+  view->ndim = 1;
+  view->shape = (flags & PyBUF_ND) == PyBUF_ND ? &view->len : nullptr;
+  view->strides =
+      (flags & PyBUF_STRIDES) == PyBUF_STRIDES ? &view->itemsize : nullptr;
+  view->suboffsets = nullptr;
+  view->internal = nullptr;
+  return 0;
+}
+
+//------------------------------------------------------------------------------
// VmInstance
//------------------------------------------------------------------------------
@@ -807,43 +825,7 @@
VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type,
iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
iree_vm_buffer_isa);
- // Implement the buffer protocol with low-level API.
- {
- static PyBufferProcs buffer_procs = {
- // It is not legal to raise exceptions from these callbacks.
- +[](PyObject* raw_self, Py_buffer* view, int flags) -> int {
- // Cast must succeed due to invariants.
- auto self = py::cast<VmBuffer*>(py::handle(raw_self));
- if (view == NULL) {
- PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer");
- return -1;
- }
-
- Py_INCREF(raw_self);
- view->obj = raw_self;
- view->buf = self->raw_ptr()->data.data;
- view->len = self->raw_ptr()->data.data_length;
- view->readonly =
- !(self->raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE);
- view->itemsize = 1;
- view->format = (char*)"B"; // Byte
- view->ndim = 1;
- view->shape = nullptr;
- view->strides = nullptr;
- view->suboffsets = nullptr;
- view->internal = nullptr;
- return 0;
- },
- +[](PyObject* self_obj, Py_buffer* view) -> void {
-
- },
- };
- auto heap_type = reinterpret_cast<PyHeapTypeObject*>(vm_buffer.ptr());
- assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
- "must be heap type");
- heap_type->as_buffer = buffer_procs;
- }
-
+ // Buffer requests on VmBuffer are served by VmBuffer::HandleBufferProtocol.
+ BindBufferProtocol<VmBuffer>(vm_buffer);
vm_buffer
.def(
"__init__",
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index 80b8607..da83c69 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -76,7 +76,10 @@
// VmBuffer
//------------------------------------------------------------------------------
-class VmBuffer : public ApiRefCounted<VmBuffer, iree_vm_buffer_t> {};
+class VmBuffer : public ApiRefCounted<VmBuffer, iree_vm_buffer_t> {
+ public:
+ // Buffer protocol callback used by BindBufferProtocol<VmBuffer>. Returns 0
+ // on success, or -1 with a Python error set.
+ int HandleBufferProtocol(Py_buffer* view, int flags);
+};
//------------------------------------------------------------------------------
// VmVariantList