blob: f412c7b77c3912943395d95714522cc05cc4adb5 [file] [log] [blame]
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# pylint: disable=unused-variable
import gc
import logging
import numpy as np
import os
import sys
import tempfile
import traceback
import unittest
import iree.compiler
import iree.runtime
COMPILED_ADD_SCALAR = None
def compile_add_scalar():
global COMPILED_ADD_SCALAR
if not COMPILED_ADD_SCALAR:
COMPILED_ADD_SCALAR = iree.compiler.compile_str(
"""
func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
%0 = arith.addi %arg0, %arg1 : i32
return %0 : i32
}
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
return COMPILED_ADD_SCALAR
def create_add_scalar_module(instance):
binary = compile_add_scalar()
m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
def create_simple_static_mul_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
def create_simple_dynamic_abs_module(instance):
binary = iree.compiler.compile_str(
"""
func.func @dynamic_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = math.absf %arg0 : tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
""",
target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
return m
class VmTest(unittest.TestCase):
@classmethod
def setUp(self):
self.instance = iree.runtime.VmInstance()
self.device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
self.hal_module = iree.runtime.create_hal_module(self.instance, self.device)
def test_context_id(self):
context1 = iree.runtime.VmContext(self.instance)
context2 = iree.runtime.VmContext(self.instance)
self.assertNotEqual(context2.context_id, context1.context_id)
def test_module_basics(self):
m = create_simple_static_mul_module(self.instance)
f = m.lookup_function("simple_mul")
self.assertGreaterEqual(f.ordinal, 0)
notfound = m.lookup_function("notfound")
self.assertIs(notfound, None)
def test_dynamic_module_context(self):
context = iree.runtime.VmContext(self.instance)
m = create_simple_static_mul_module(self.instance)
context.register_modules([self.hal_module, m])
def test_static_module_context(self):
m = create_simple_static_mul_module(self.instance)
logging.info("module: %s", m)
context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m])
logging.info("context: %s", context)
def test_dynamic_shape_compile(self):
m = create_simple_dynamic_abs_module(self.instance)
logging.info("module: %s", m)
context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m])
logging.info("context: %s", context)
def test_add_scalar_new_abi(self):
m = create_add_scalar_module(self.instance)
context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)
def test_unaligned_buffer_error(self):
buffer = memoryview(b"foobar")
with self.assertRaisesRegex(ValueError, "unaligned buffer"):
# One byte into a heap buffer will never satisfy alignment
# constraints.
iree.runtime.VmModule.wrap_buffer(self.instance, buffer[1:])
def test_from_buffer_unaligned_warns(self):
binary = compile_add_scalar()
# One byte into a heap buffer will never satisfy alignment
# constraints.
unaligned_binary = memoryview(b"1" + binary)[1:]
with self.assertWarnsRegex(
UserWarning, "Making copy of unaligned VmModule buffer"
):
iree.runtime.VmModule.from_buffer(self.instance, unaligned_binary)
def test_mmap_implicit_unmap(self):
binary = compile_add_scalar()
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(binary)
tf.flush()
vmfb_file_path = tf.name
# Note that on Windows, an open file cannot be mapped.
try:
m = iree.runtime.VmModule.mmap(self.instance, vmfb_file_path)
context = iree.runtime.VmContext(
self.instance, modules=[self.hal_module, m]
)
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)
del finv
del f
del context
del m
gc.collect()
finally:
# On Windows, a mapped file cannot be deleted and this will fail if
# the mapping was not cleaned up properly.
os.unlink(vmfb_file_path)
def test_mmap_destroy_callback(self):
binary = compile_add_scalar()
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(binary)
tf.flush()
vmfb_file_path = tf.name
destroyed = [False]
def on_destroy():
print("on_destroy callback")
try:
os.unlink(vmfb_file_path)
except:
print("exception while unlinking mapped file")
traceback.print_exc(file=sys.stdout)
raise
destroyed[0] = True
# Note that on Windows, an open file cannot be mapped.
m = iree.runtime.VmModule.mmap(
self.instance, vmfb_file_path, destroy_callback=on_destroy
)
context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)
del finv
del f
del context
del m
gc.collect()
self.assertTrue(destroyed[0])
def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
m = create_simple_dynamic_abs_module(self.instance)
context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m])
f = m.lookup_function("dynamic_abs")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32)
result = finv(arg0)
logging.info("result: %s", result)
np.testing.assert_allclose(result, [[1.0, 2.0], [3.0, 4.0]])
def test_synchronous_invoke_function_new_abi(self):
m = create_simple_static_mul_module(self.instance)
context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32)
result = finv(arg0, arg1)
logging.info("result: %s", result)
np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0])
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()