blob: 0a7840437f899929d59aff08a29003239ae4bdcc [file] [log] [blame]
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# pylint: disable=unused-variable
import gc
import logging
import os
import re
import tempfile
import unittest
import iree.compiler
import iree.runtime
import numpy as np
# Lazily-populated cache for the compiled `arithmetic` module; stays None
# until compile_simple_mul_binary() runs for the first time.
_SIMPLE_MUL_BINARY = None


def compile_simple_mul_binary():
    """Return the compiled `arithmetic.simple_mul` module, compiling it once.

    The first call compiles the MLIR below with `iree.compiler.compile_str`
    for `iree.compiler.core.DEFAULT_TESTING_BACKENDS` and stores the result
    in the module-level `_SIMPLE_MUL_BINARY`. Later calls return that cached
    value without recompiling.

    Returns:
        The value produced by `iree.compiler.compile_str`. Elsewhere in this
        file it is passed to `VmModule.from_flatbuffer` and written to a
        file as bytes.
    """
    global _SIMPLE_MUL_BINARY
    # Compare against None rather than relying on truthiness, so that
    # "not compiled yet" is never confused with a falsy compiled result.
    if _SIMPLE_MUL_BINARY is None:
        _SIMPLE_MUL_BINARY = iree.compiler.compile_str(
            """
module @arithmetic {
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
}
""",
            target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
        )
    return _SIMPLE_MUL_BINARY
def create_simple_mul_module(instance):
    """Build a `VmModule` for `instance` from the cached simple_mul binary.

    Args:
        instance: The runtime instance the module is created against
            (callers in this file pass `SystemContext.instance`).

    Returns:
        The `iree.runtime.VmModule` produced by `VmModule.from_flatbuffer`.
    """
    return iree.runtime.VmModule.from_flatbuffer(
        instance, compile_simple_mul_binary()
    )
class SystemApiTest(unittest.TestCase):
    """Tests for `iree.runtime` configs, system contexts and module loading.

    Most tests exercise the `arithmetic.simple_mul` module produced by
    `compile_simple_mul_binary()`.
    """

    def test_non_existing_driver(self):
        """A Config naming only unknown drivers raises ValueError."""
        with self.assertRaisesRegex(ValueError, "No device found from list"):
            config = iree.runtime.Config("nothere1,nothere2")

    def test_subsequent_driver(self):
        """An unknown driver followed by `local-task` constructs without error."""
        config = iree.runtime.Config("nothere1,local-task")

    def test_multi_config_caches(self):
        """Two Configs built from the same driver list share one device."""
        config1 = iree.runtime.Config("nothere1,local-sync")
        config2 = iree.runtime.Config("nothere1,local-sync")
        self.assertIs(config1.device, config2.device)

    def test_empty_dynamic(self):
        """A default SystemContext is dynamic and exposes the `hal` module."""
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        self.assertIn("hal", ctx.modules)
        self.assertEqual(ctx.modules.hal.name, "hal")

    def test_empty_static(self):
        """A SystemContext given `vm_modules=()` is static but still has `hal`."""
        ctx = iree.runtime.SystemContext(vm_modules=())
        self.assertFalse(ctx.is_dynamic)
        self.assertIn("hal", ctx.modules)
        self.assertEqual(ctx.modules.hal.name, "hal")

    def test_custom_dynamic(self):
        """A module added to a dynamic context is reachable by name."""
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
        f = ctx.modules.arithmetic["simple_mul"]
        f_repr = repr(f)
        logging.info("f_repr: %s", f_repr)
        self.assertIn("simple_mul(0rr_r)", f_repr)

    def test_duplicate_module(self):
        """Adding a second module named `arithmetic` raises ValueError."""
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        with self.assertRaisesRegex(ValueError, "arithmetic"):
            ctx.add_vm_module(create_simple_mul_module(ctx.instance))

    def test_static_invoke(self):
        """Calling simple_mul on numpy arrays returns the elementwise product."""
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
        f = ctx.modules.arithmetic["simple_mul"]
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
        results = f(arg0, arg1)
        np.testing.assert_allclose(results, [4.0, 10.0, 18.0, 28.0])

    def test_chained_invoke(self):
        """Results of one call can be passed straight back into another call."""
        # This ensures that everything works if DeviceArrays are returned
        # and input to functions.
        ctx = iree.runtime.SystemContext()
        self.assertTrue(ctx.is_dynamic)
        ctx.add_vm_module(create_simple_mul_module(ctx.instance))
        self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
        f = ctx.modules.arithmetic["simple_mul"]
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
        results = f(arg0, arg1)
        results2 = f(results, results)
        np.testing.assert_allclose(results2, [16.0, 100.0, 324.0, 784.0])

    def test_load_vm_module(self):
        """`load_vm_module` exposes module functions as callable attributes."""
        ctx = iree.runtime.SystemContext()
        arithmetic = iree.runtime.load_vm_module(create_simple_mul_module(ctx.instance))
        arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
        results = arithmetic.simple_mul(arg0, arg1)
        # Log (rather than print) for consistency with the other tests.
        logging.info("SIMPLE_MUL RESULTS: %s", results)
        np.testing.assert_allclose(results, [4.0, 10.0, 18.0, 28.0])

    def test_load_multiple_modules(self):
        """Loading the same VmModule twice via `load_vm_module` is valid."""
        # Doing default device configuration multiple times should be valid
        # (if this were instantiating drivers multiple times, it can trigger
        # a crash, depending on whether the driver supports multi-instantiation).
        ctx = iree.runtime.SystemContext()
        m = create_simple_mul_module(ctx.instance)
        m1 = iree.runtime.load_vm_module(m)
        m2 = iree.runtime.load_vm_module(m)

    def test_load_vm_flatbuffer(self):
        """`load_vm_flatbuffer` accepts either a `driver` or a `backend`."""
        # This API is old and not highly recommended but testing as-is.
        m = iree.runtime.load_vm_flatbuffer(
            compile_simple_mul_binary(), driver="local-sync"
        )
        m = iree.runtime.load_vm_flatbuffer(
            compile_simple_mul_binary(), backend="llvm-cpu"
        )

    def test_load_vm_flatbuffer_file(self):
        """The `destroy_callback` has run once the loaded module is released."""
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(compile_simple_mul_binary())

        def _remove_if_left_behind():
            # The file is created with delete=False and is normally removed
            # only by the destroy callback. If the test fails before that
            # happens, remove it here so the temp file does not leak. On
            # success it is already gone, which is fine.
            try:
                os.unlink(f.name)
            except FileNotFoundError:
                pass

        self.addCleanup(_remove_if_left_behind)

        def _cleanup():
            os.unlink(f.name)

        m = iree.runtime.load_vm_flatbuffer_file(
            f.name, driver="local-sync", destroy_callback=_cleanup
        )
        # Drop the last reference and force a collection. The assertion below
        # expects destroy_callback to have deleted the file by then.
        del m
        gc.collect()
        self.assertFalse(os.path.exists(f.name))
if __name__ == "__main__":
    # Use verbose root logging so the tests' logging.info output is shown.
    logging.basicConfig(level="DEBUG")
    unittest.main(module="__main__")