| # Copyright 2019 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| # pylint: disable=unused-variable |
| |
| import gc |
| import logging |
| import os |
| import re |
| import tempfile |
| import unittest |
| |
| import iree.compiler |
| import iree.runtime |
| import numpy as np |
| |
| |
# Lazily-populated cache for the compiled `@arithmetic` module; None until the
# first call to compile_simple_mul_binary().
_SIMPLE_MUL_BINARY = None


def compile_simple_mul_binary():
    """Compile the `@arithmetic` test module once and return the result.

    The module defines a single function, `simple_mul`, which multiplies two
    `tensor<4xf32>` operands with `arith.mulf`. It is compiled for
    `iree.compiler.core.DEFAULT_TESTING_BACKENDS`.

    The compiled output is memoized in the module-level `_SIMPLE_MUL_BINARY`,
    so the compiler is invoked only on the first successful call.

    Returns:
      The value returned by `iree.compiler.compile_str` (the tests write it
      directly to a binary file, so it is presumably a bytes-like flatbuffer).
    """
    global _SIMPLE_MUL_BINARY
    # Use an explicit None sentinel rather than truthiness, so a cached result
    # is never mistaken for "not yet compiled".
    if _SIMPLE_MUL_BINARY is None:
        _SIMPLE_MUL_BINARY = iree.compiler.compile_str(
            """
module @arithmetic {
  func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
    %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
    return %0 : tensor<4xf32>
  }
}
""",
            target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
        )
    return _SIMPLE_MUL_BINARY
| |
| |
def create_simple_mul_module(instance):
    """Load the compiled `@arithmetic` test module against a VM instance.

    Args:
      instance: The VM instance to create the module in (the tests pass
        `SystemContext.instance`).

    Returns:
      The `iree.runtime.VmModule` built from `compile_simple_mul_binary()`.
    """
    flatbuffer = compile_simple_mul_binary()
    return iree.runtime.VmModule.from_flatbuffer(instance, flatbuffer)
| |
| |
class SystemApiTest(unittest.TestCase):
    """Tests for the `iree.runtime` system-level API.

    Covers driver/device selection via `Config`, module registration and
    invocation through `SystemContext`, and the `load_vm_module` /
    `load_vm_flatbuffer` / `load_vm_flatbuffer_file` convenience loaders.
    """

    def test_non_existing_driver(self):
        # None of the listed drivers exist, so configuration must fail.
        with self.assertRaisesRegex(ValueError, "No device found from list"):
            config = iree.runtime.Config("nothere1,nothere2")

    def test_subsequent_driver(self):
        # An unknown driver listed before an available one should still yield a
        # usable configuration.
        config = iree.runtime.Config("nothere1,local-task")
        self.assertIsNotNone(config.device)

    def test_multi_config_caches(self):
        # Two configs built from the same driver list must share one device
        # object rather than instantiating the driver twice.
        config1 = iree.runtime.Config("nothere1,local-sync")
        config2 = iree.runtime.Config("nothere1,local-sync")
        self.assertIs(config1.device, config2.device)

    def test_empty_dynamic(self):
        # A default-constructed context is dynamic and always exposes "hal".
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        self.assertIn("hal", ctx.modules)
        self.assertEqual(ctx.modules.hal.name, "hal")

    def test_empty_static(self):
        # Passing an explicit (empty) module list makes the context static;
        # "hal" is still present.
        ctx = iree.runtime.SystemContext(vm_modules=())
        self.assertFalse(ctx.is_dynamic)
        self.assertIn("hal", ctx.modules)
        self.assertEqual(ctx.modules.hal.name, "hal")

    def test_custom_dynamic(self):
        # A module added to a dynamic context becomes reachable by name, and its
        # functions are retrievable by subscripting.
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
        f = ctx.modules.arithmetic["simple_mul"]
        f_repr = repr(f)
        logging.info("f_repr: %s", f_repr)
        # The repr carries the function name plus what is presumably its VM
        # signature string ("0rr_r"); confirm against the VM docs if it changes.
        self.assertIn("simple_mul(0rr_r)", f_repr)

    def test_duplicate_module(self):
        # Registering a second module with the same name must be rejected, and
        # the error message must mention the offending module name.
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        with self.assertRaisesRegex(ValueError, "arithmetic"):
            ctx.add_vm_module(create_simple_mul_module(ctx.instance))

    def test_static_invoke(self):
        # Calling a module function with numpy inputs returns the elementwise
        # product.
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
        f = ctx.modules.arithmetic["simple_mul"]
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
        results = f(arg0, arg1)
        np.testing.assert_allclose(results, [4.0, 10.0, 18.0, 28.0])

    def test_chained_invoke(self):
        # This ensures that everything works if DeviceArrays are returned
        # and input to functions.
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
        f = ctx.modules.arithmetic["simple_mul"]
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
        results = f(arg0, arg1)
        # Feed the previous result back in: results * results.
        results2 = f(results, results)
        np.testing.assert_allclose(results2, [16.0, 100.0, 324.0, 784.0])

    def test_load_vm_module(self):
        # load_vm_module exposes module functions as attributes.
        ctx = iree.runtime.SystemContext()
        arithmetic = iree.runtime.load_vm_module(create_simple_mul_module(ctx.instance))
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
        results = arithmetic.simple_mul(arg0, arg1)
        # Route diagnostics through logging (as elsewhere in this file) rather
        # than printing unconditionally to stdout.
        logging.info("SIMPLE_MUL RESULTS: %s", results)
        np.testing.assert_allclose(results, [4.0, 10.0, 18.0, 28.0])

    def test_load_multiple_modules(self):
        # Doing default device configuration multiple times should be valid
        # (if this were instantiating drivers multiple times, it can trigger
        # a crash, depending on whether the driver supports multi-instantiation).
        ctx = iree.runtime.SystemContext()
        m = create_simple_mul_module(ctx.instance)
        m1 = iree.runtime.load_vm_module(m)
        m2 = iree.runtime.load_vm_module(m)

    def test_load_vm_flatbuffer(self):
        # This API is old and not highly recommended but testing as-is.
        # Exercise both the `driver=` and the `backend=` selection keywords.
        m = iree.runtime.load_vm_flatbuffer(
            compile_simple_mul_binary(), driver="local-sync"
        )
        m = iree.runtime.load_vm_flatbuffer(
            compile_simple_mul_binary(), backend="llvm-cpu"
        )

    def test_load_vm_flatbuffer_file(self):
        # The destroy_callback passed to load_vm_flatbuffer_file must run once
        # the loaded module is garbage collected; here it deletes the file.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(compile_simple_mul_binary())
        path = f.name

        def _remove_if_present():
            # Safety net: if loading fails before the destroy callback fires,
            # still remove the delete=False temp file so it is not leaked.
            if os.path.exists(path):
                os.unlink(path)

        self.addCleanup(_remove_if_present)

        def _cleanup():
            os.unlink(path)

        m = iree.runtime.load_vm_flatbuffer_file(
            path, driver="local-sync", destroy_callback=_cleanup
        )
        del m
        gc.collect()
        self.assertFalse(os.path.exists(path))
| |
| |
if __name__ == "__main__":
    # Enable DEBUG-level root logging so messages logged during the tests are
    # visible, then hand control to the unittest runner.
    logging.basicConfig(level="DEBUG")
    unittest.main()