blob: 026c9b665364e6d09e45784046d540955cc228e3 [file] [log] [blame]
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import array
import gc
import logging
import numpy as np
from pathlib import Path
import tempfile
import unittest
import iree.runtime as rt
def _float_constant(val: float) -> array.array:
    """Return a one-element ``array.array`` of typecode ``"f"`` holding ``val``."""
    packed = array.array("f")
    packed.append(val)
    return packed
def index_entry_as_array(entry, like_array):
    """Reinterpret ``entry.file_view`` with the dtype and shape of ``like_array``.

    The entry's ``file_view`` is wrapped as a NumPy array, viewed as
    ``like_array.dtype`` and reshaped to ``like_array.shape``.
    """
    raw = np.asarray(entry.file_view)
    typed = raw.view(like_array.dtype)
    return typed.reshape(like_array.shape)
class ParameterApiTest(unittest.TestCase):
    """Tests for ``rt.ParameterIndex``, ``rt.FileHandle`` and archive save/load."""

    def _check_weight_bias_testdata(self, file_name: str) -> None:
        """Load ``testdata/<file_name>`` beside this file and check its contents.

        The file is expected to hold a ``weight`` entry equal to a [30, 20]
        float32 array of 2.0 and a ``bias`` entry equal to a [30] float32
        array of 1.0.
        """
        index = rt.ParameterIndex()
        testdata_dir = Path(__file__).resolve().parent / "testdata"
        index.load(str(testdata_dir / file_name))

        expected_weight = np.full([30, 20], 2.0, dtype=np.float32)
        expected_bias = np.full([30], 1.0, dtype=np.float32)
        entries = dict(index.items())
        weight = index_entry_as_array(entries["weight"], expected_weight)
        bias = index_entry_as_array(entries["bias"], expected_bias)
        np.testing.assert_array_equal(weight, expected_weight)
        np.testing.assert_array_equal(bias, expected_bias)

    def testCreateArchiveFile(self):
        """An index holding only splats can be written to a non-empty file."""
        source_index = rt.ParameterIndex()
        for name, value, total_length in (
            ("weight", 2.0, 30 * 20 * 4),
            ("bias", 1.0, 30 * 4),
        ):
            source_index.add_splat(name, _float_constant(value), total_length)

        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = Path(tmp_dir) / "archive.irpa"
            written_index = source_index.create_archive_file(str(archive_path))
            print(written_index)
            self.assertTrue(archive_path.exists())
            self.assertGreater(archive_path.stat().st_size, 0)

    def testArchiveFileRoundtrip(self):
        """An archive saved with ``rt.save_archive_file`` loads back intact."""
        orig_array = np.asarray([[1], [2], [3]], dtype=np.int64)

        def verify_archive(archive_path: Path):
            # Load and verify.
            loaded = rt.ParameterIndex()
            # For this test, disable mmap to make temp file management on
            # Windows a bit better.
            loaded.load(str(archive_path), mmap=False)
            self.assertEqual(len(loaded), 3)

            # The happy path of most entry properties is verified via the
            # repr (they are called internally by it).
            entries = dict(loaded.items())
            expected_splat_reprs = (
                ("weight", "<ParameterIndexEntry 'weight' splat b'\\x02':600>"),
                ("bias", "<ParameterIndexEntry 'bias' splat b' ':30>"),
            )
            for key, expected_repr in expected_splat_reprs:
                self.assertEqual(repr(entries[key]), expected_repr)
            self.assertRegex(
                repr(entries["array"]),
                r"<ParameterIndexEntry 'array' FileHandle<host_allocation\(.*\)>:384:24",
            )

            # Verify some non-happy paths: a splat entry has no file storage
            # and a file-backed entry has no splat pattern.
            with self.assertRaisesRegex(
                ValueError, "Entry is not file storage based"
            ):
                entries["weight"].file_storage
            with self.assertRaisesRegex(ValueError, "Entry is not splat"):
                entries["array"].splat_pattern

            # Verify that the repr of the index itself is sensible.
            self.assertIn("Parameter scope <global> (3 entries", repr(loaded))

            # Get the array contents and verify against the original.
            array_view = entries["array"].file_view
            self.assertEqual(len(array_view), 24)
            as_int64 = np.asarray(array_view).view(np.int64)
            restored = as_int64.reshape(orig_array.shape)
            np.testing.assert_array_equal(restored, orig_array)

        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = Path(tmp_dir) / "archive.irpa"
            rt.save_archive_file(
                {
                    "weight": rt.SplatValue(np.int8(2), [30, 20]),
                    "bias": rt.SplatValue(array.array("b", [32]), 30),
                    "array": orig_array,
                },
                archive_path,
            )
            self.assertTrue(archive_path.exists())
            self.assertGreater(archive_path.stat().st_size, 0)

            # Open / verify in its own scope and collect prior to tearing
            # down the temp dir.
            verify_archive(archive_path)
            gc.collect()

    def testFileHandleWrap(self):
        """``host_allocation`` stays readable after the handle name is deleted."""
        payload = b"foobar"
        handle = rt.FileHandle.wrap_memory(payload)
        view = handle.host_allocation
        # Drop our reference to the handle before reading through the view.
        del handle
        self.assertEqual(bytes(view), payload)

    def testParameterIndexAddFromFile(self):
        """Adding an entry from a memory-backed file handle does not raise."""
        index = rt.ParameterIndex()
        handle = rt.FileHandle.wrap_memory(b"foobar")
        index.add_from_file_handle("data", handle, length=3, offset=3)

    def testSplatTooBig(self):
        """``add_splat`` raises ``ValueError`` for a five-float pattern."""
        index = rt.ParameterIndex()
        oversized_pattern = array.array("f", [1.0, 2.0, 3.0, 4.0, 5.0])
        with self.assertRaises(ValueError):
            index.add_splat("weight", oversized_pattern, 30 * 20 * 4)

    def testGguf(self):
        """The GGUF test file loads with the expected weight and bias."""
        self._check_weight_bias_testdata("parameter_weight_bias_1.gguf")

    def testSafetensors(self):
        """The safetensors test file loads with the expected weight and bias."""
        self._check_weight_bias_testdata("parameter_weight_bias_1.safetensors")
if __name__ == "__main__":
    # Enable debug-level logging when the file is run directly as a script.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()