[python] Convert python io tests to unit tests. (#16984)
These were overly relying on obtuse integration tests. Now that the read
API is present, we can just test most things directly and have one
integration test to make sure runtime integration functions. That
integration test is copied from the simplified test in tools/test and
eliminates complicated IR.
diff --git a/runtime/bindings/python/tests/io_runtime_test.py b/runtime/bindings/python/tests/io_runtime_test.py
index 7fe1379..417b99d 100644
--- a/runtime/bindings/python/tests/io_runtime_test.py
+++ b/runtime/bindings/python/tests/io_runtime_test.py
@@ -15,51 +15,30 @@
import iree.runtime as rt
-MM_TEST_COMPILED = None
-MM_TEST_ASM = r"""
- #map = affine_map<(d0, d1) -> (d0, d1)>
- #map1 = affine_map<(d0, d1) -> (d1, d0)>
- #map2 = affine_map<(d0, d1) -> (d1)>
- module @main {
- util.global private @_params.classifier.weight {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"weight"> : tensor<30x20xf32>
- util.global private @_params.classifier.bias {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"bias"> : tensor<30xf32>
- func.func @run(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> {
- %0 = call @forward(%arg0) : (tensor<128x20xf32>) -> tensor<128x30xf32>
- return %0 : tensor<128x30xf32>
- }
- func.func private @forward(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> attributes {torch.assume_strict_symbolic_shapes} {
- %cst = arith.constant 0.000000e+00 : f32
- %_params.classifier.weight = util.global.load @_params.classifier.weight : tensor<30x20xf32>
- %0 = tensor.empty() : tensor<20x30xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%_params.classifier.weight : tensor<30x20xf32>) outs(%0 : tensor<20x30xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<20x30xf32>
- %2 = tensor.empty() : tensor<128x30xf32>
- %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x30xf32>) -> tensor<128x30xf32>
- %4 = linalg.matmul ins(%arg0, %1 : tensor<128x20xf32>, tensor<20x30xf32>) outs(%3 : tensor<128x30xf32>) -> tensor<128x30xf32>
- %_params.classifier.bias = util.global.load @_params.classifier.bias : tensor<30xf32>
- %5 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %_params.classifier.bias : tensor<128x30xf32>, tensor<30xf32>) outs(%2 : tensor<128x30xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %6 = arith.addf %in, %in_0 : f32
- linalg.yield %6 : f32
- } -> tensor<128x30xf32>
- return %5 : tensor<128x30xf32>
- }
+TEST_COMPILED = None
+TEST_ASM = r"""
+util.global private @a0 = #stream.parameter.named<"a"::"a0"> : tensor<4xi64>
+util.global private @a1 = #stream.parameter.named<"a"::"a1"> : tensor<4xi64>
+util.global private @b0 = #stream.parameter.named<"b"::"b0"> : tensor<8xi64>
+util.global private @b1 = #stream.parameter.named<"b"::"b1"> : tensor<8xi64>
+func.func @echo() -> (tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>) {
+ %a0 = util.global.load @a0 : tensor<4xi64>
+ %a1 = util.global.load @a1 : tensor<4xi64>
+ %b0 = util.global.load @b0 : tensor<8xi64>
+ %b1 = util.global.load @b1 : tensor<8xi64>
+ return %a0, %a1, %b0, %b1 : tensor<4xi64>, tensor<4xi64>, tensor<8xi64>, tensor<8xi64>
}
"""
def compile_mm_test():
- global MM_TEST_COMPILED
- if not MM_TEST_COMPILED:
- MM_TEST_COMPILED = iree.compiler.compile_str(
- MM_TEST_ASM,
+ global TEST_COMPILED
+ if not TEST_COMPILED:
+ TEST_COMPILED = iree.compiler.compile_str(
+ TEST_ASM,
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
- # TODO(#16098): re-enable const eval once parameters are supported.
- extra_args=["--iree-opt-const-eval=false"],
)
- return MM_TEST_COMPILED
+ return TEST_COMPILED
def create_mm_test_module(instance):
@@ -67,8 +46,11 @@
return rt.VmModule.copy_buffer(instance, binary)
-def _float_constant(val: float) -> array.array:
- return array.array("f", [val])
+def create_index_from_arrays(**kwargs) -> rt.ParameterIndex:
+ idx = rt.ParameterIndex()
+ for key, value in kwargs.items():
+ idx.add_buffer(key, value)
+ return idx
class ParameterTest(unittest.TestCase):
@@ -77,130 +59,29 @@
self.device = rt.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
self.config = rt.Config(device=self.device)
- def testParameterIndex(self):
- index = rt.ParameterIndex()
- self.assertEqual(len(index), 0)
- index.reserve(25)
- self.assertEqual(len(index), 0)
- provider = index.create_provider()
- rt.create_io_parameters_module(self.instance, provider)
-
- def testSplats(self):
- splat_index = rt.ParameterIndex()
- splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
- splat_index.add_splat("bias", _float_constant(1.0), 30 * 4)
+ def test_index_provider_module(self):
+ a0 = np.asarray([1] * 4, dtype=np.int64)
+ a1 = np.asarray([2] * 4, dtype=np.int64)
+ b0 = np.asarray([3] * 8, dtype=np.int64)
+ b1 = np.asarray([4] * 8, dtype=np.int64)
+ idx_a = create_index_from_arrays(a0=a0, a1=a1)
+ idx_b = create_index_from_arrays(b0=b0, b1=b1)
modules = rt.load_vm_modules(
rt.create_io_parameters_module(
- self.instance, splat_index.create_provider(scope="params")
+ self.instance,
+ idx_a.create_provider(scope="a"),
+ idx_b.create_provider(scope="b"),
),
rt.create_hal_module(self.instance, self.device),
create_mm_test_module(self.instance),
config=self.config,
)
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- # TODO: Fix splat in the parameter code so it is not all zeros.
- # expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- # np.testing.assert_array_almost_equal(result, expected_result)
-
- def testSplatsFromBuiltIrpaFile(self):
- with tempfile.TemporaryDirectory() as td:
- file_path = Path(td) / "archive.irpa"
- rt.save_archive_file(
- {
- "weight": rt.SplatValue(np.float32(2.0), 30 * 20),
- "bias": rt.SplatValue(np.float32(1.0), 30),
- },
- file_path,
- )
-
- index = rt.ParameterIndex()
- index.load(str(file_path))
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
- def testBuffers(self):
- index = rt.ParameterIndex()
- weight = np.zeros([30, 20], dtype=np.float32) + 2.0
- bias = np.zeros([30], dtype=np.float32) + 1.0
- index.add_buffer("weight", weight)
- index.add_buffer("bias", bias)
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
- def testGguf(self):
- index = rt.ParameterIndex()
- index.load(
- str(
- Path(__file__).resolve().parent
- / "testdata"
- / "parameter_weight_bias_1.gguf"
- )
- )
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
-
- def testSafetensors(self):
- index = rt.ParameterIndex()
- index.load(
- str(
- Path(__file__).resolve().parent
- / "testdata"
- / "parameter_weight_bias_1.safetensors"
- )
- )
- modules = rt.load_vm_modules(
- rt.create_io_parameters_module(
- self.instance, index.create_provider(scope="params")
- ),
- rt.create_hal_module(self.instance, self.device),
- create_mm_test_module(self.instance),
- config=self.config,
- )
- main = modules[-1]
- input = np.zeros([128, 20], dtype=np.float32) + 2.0
- result = main.run(input)
- print(result.to_host())
- expected_result = np.zeros([128, 30], dtype=np.float32) + 81.0
- np.testing.assert_array_almost_equal(result, expected_result)
+ m = modules[-1]
+ a0_actual, a1_actual, b0_actual, b1_actual = m.echo()
+ np.testing.assert_array_equal(a0, a0_actual)
+ np.testing.assert_array_equal(a1, a1_actual)
+ np.testing.assert_array_equal(b0, b0_actual)
+ np.testing.assert_array_equal(b1, b1_actual)
if __name__ == "__main__":
diff --git a/runtime/bindings/python/tests/io_test.py b/runtime/bindings/python/tests/io_test.py
index e8ead9f..026c9b6 100644
--- a/runtime/bindings/python/tests/io_test.py
+++ b/runtime/bindings/python/tests/io_test.py
@@ -19,6 +19,10 @@
return array.array("f", [val])
+def index_entry_as_array(entry, like_array):
+ return np.asarray(entry.file_view).view(like_array.dtype).reshape(like_array.shape)
+
+
class ParameterApiTest(unittest.TestCase):
def testCreateArchiveFile(self):
splat_index = rt.ParameterIndex()
@@ -110,6 +114,40 @@
"weight", array.array("f", [1.0, 2.0, 3.0, 4.0, 5.0]), 30 * 20 * 4
)
+ def testGguf(self):
+ index = rt.ParameterIndex()
+ index.load(
+ str(
+ Path(__file__).resolve().parent
+ / "testdata"
+ / "parameter_weight_bias_1.gguf"
+ )
+ )
+ expected_weight = np.zeros([30, 20], dtype=np.float32) + 2.0
+ expected_bias = np.zeros([30], dtype=np.float32) + 1.0
+ entries = dict(index.items())
+ weight = index_entry_as_array(entries["weight"], expected_weight)
+ bias = index_entry_as_array(entries["bias"], expected_bias)
+ np.testing.assert_array_equal(weight, expected_weight)
+ np.testing.assert_array_equal(bias, expected_bias)
+
+ def testSafetensors(self):
+ index = rt.ParameterIndex()
+ index.load(
+ str(
+ Path(__file__).resolve().parent
+ / "testdata"
+ / "parameter_weight_bias_1.safetensors"
+ )
+ )
+ expected_weight = np.zeros([30, 20], dtype=np.float32) + 2.0
+ expected_bias = np.zeros([30], dtype=np.float32) + 1.0
+ entries = dict(index.items())
+ weight = index_entry_as_array(entries["weight"], expected_weight)
+ bias = index_entry_as_array(entries["bias"], expected_bias)
+ np.testing.assert_array_equal(weight, expected_weight)
+ np.testing.assert_array_equal(bias, expected_bias)
+
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)