# blob: 95625e906b9a54ec07b91a9be2a847f11e076142
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import json
import numpy as np
import unittest
from iree import runtime as rt
from iree.runtime.function import (
FunctionInvoker,
IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
IMPLICIT_BUFFER_ARG_USAGE,
)
from iree.runtime._binding import VmVariantList
class MockVmContext:
    """Test double for a VM context that records every invocation.

    The supplied ``invoke_callback`` is called as
    ``invoke_callback(arg_list, ret_list)`` on each ``invoke`` so a test can
    populate the result list; the call is then appended to ``invocations``.
    """

    def __init__(self, invoke_callback):
        # Each entry is a (vm_function, arg_list, ret_list) tuple, in call order.
        self.invocations = []
        self._invoke_callback = invoke_callback

    def invoke(self, vm_function, arg_list, ret_list):
        """Runs the callback, records the call, and prints a trace line.

        The call is only recorded if the callback returns normally.
        """
        callback = self._invoke_callback
        callback(arg_list, ret_list)
        record = (vm_function, arg_list, ret_list)
        self.invocations.append(record)
        print(f"INVOKE: {arg_list} -> {ret_list}")

    @property
    def mock_arg_reprs(self):
        """The ``repr`` of the list of argument lists seen so far."""
        seen_arg_lists = []
        for record in self.invocations:
            seen_arg_lists.append(record[1])
        return repr(seen_arg_lists)
class MockVmFunction:
    """Minimal stand-in for a VM function that only carries ``reflection``.

    The tests in this file pass either an empty dict or a dict whose
    "iree.abi" entry is a JSON-encoded signature description.
    """

    def __init__(self, reflection):
        # Stored as-is (no copy), so the caller's object is what gets exposed.
        self.reflection = reflection
class FunctionTest(unittest.TestCase):
    """Tests FunctionInvoker argument/result marshalling against mock VM objects.

    Every test builds a FunctionInvoker around a MockVmContext (which records
    the argument list it is invoked with) and a MockVmFunction (which supplies
    the reflection metadata). Tests come in pairs where useful: one with an
    "iree.abi" reflection entry and one with empty reflection.
    """

    # Expected repr of the marshalled argument list in the array/buffer-view
    # argument tests below.
    _BUFFER_VIEW_ARGS_REPR = "<VmVariantList(1): [HalBufferView(2:0x20000011)]>"

    # ABI descriptions shared by several tests. They are only ever read
    # (serialized with json.dumps), never mutated.
    _NAMED_SCALARS_ABI = {
        "a": ["i32", ["named", "a", "i32"], ["named", "b", "i32"]],
        "r": ["i32"],
    }
    _NAMED_NDARRAYS_ABI = {
        "a": [
            ["ndarray", "i32", 1, 1],
            ["named", "a", ["ndarray", "i32", 1, 1]],
            ["named", "b", ["ndarray", "i32", 1, 1]],
        ],
        "r": ["i32"],
    }
    _SLIST_ARG_ABI = {"a": [["slist", "i32", "i32"]], "r": ["i32"]}
    _STUPLE_ARG_ABI = {"a": [["stuple", "i32", "i32"]], "r": ["i32"]}
    _SDICT_ARG_ABI = {
        "a": [["sdict", ["a", "i32"], ["b", "i32"]]],
        "r": ["i32"],
    }
    _NDARRAY_ARG_ABI = {"a": [["ndarray", "i32", 1, 2]], "r": []}
    _NDARRAY_RESULT_ABI = {"a": [], "r": [["ndarray", "i32", 1, 2]]}

    @classmethod
    def setUpClass(cls):
        # Doesn't matter what device. We just need one.
        config = rt.Config("local-task")
        cls.device = config.device

    # ------------------------------------------------------------------
    # Helpers (underscore-prefixed so unittest does not collect them).
    # ------------------------------------------------------------------

    def _make_invoker(self, invoke_callback, abi=None):
        """Builds a FunctionInvoker wired to fresh mock VM objects.

        Args:
          invoke_callback: Called as ``invoke_callback(arg_list, ret_list)``
            by the mock context on every invocation.
          abi: Optional JSON-serializable ABI description stored under the
            "iree.abi" reflection key. When None, reflection is empty.

        Returns:
          A ``(vm_context, invoker)`` tuple.
        """
        reflection = {} if abi is None else {"iree.abi": json.dumps(abi)}
        vm_context = MockVmContext(invoke_callback)
        vm_function = MockVmFunction(reflection=reflection)
        invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
        return vm_context, invoker

    @staticmethod
    def _returning_ints(*values):
        """Returns an invoke callback that pushes each of `values` as an int."""

        def invoke(arg_list, ret_list):
            for value in values:
                ret_list.push_int(value)

        return invoke

    @staticmethod
    def _no_results(arg_list, ret_list):
        """Invoke callback that leaves the result list untouched."""

    def _allocate_buffer_view(self, array, element_type):
        """Copies `array` into a buffer view allocated on the test device."""
        return self.device.allocator.allocate_buffer_copy(
            memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
            allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
            buffer=array,
            element_type=element_type,
        )

    def _returning_buffer_view(self, array, element_type):
        """Returns an invoke callback that pushes a buffer view of `array`.

        The buffer view is allocated inside the callback, i.e. at invocation
        time, not when this helper is called.
        """

        def invoke(arg_list, ret_list):
            ret_list.push_ref(self._allocate_buffer_view(array, element_type))

        return invoke

    def _make_device_array(self):
        """Returns a device array of int32 [1, 0] with host transfers disabled."""
        return rt.asdevicearray(
            self.device,
            np.asarray([1, 0], dtype=np.int32),
            implicit_host_transfer=False,
        )

    def _check_buffer_view_arg(self, arg, abi=None):
        """Invokes with `arg` and checks the repr of the argument list received."""
        vm_context, invoker = self._make_invoker(self._no_results, abi)
        invoker(arg)
        # The mock context records (vm_function, arg_list, ret_list) per call.
        invoked_arg_list = vm_context.invocations[-1][1]
        self.assertEqual(self._BUFFER_VIEW_ARGS_REPR, repr(invoked_arg_list))

    def _check_raises(self, abi, message_regex, *args, **kwargs):
        """Checks that invoking with the given arguments raises ValueError."""
        _, invoker = self._make_invoker(self._returning_ints(3), abi)
        with self.assertRaisesRegex(ValueError, message_regex):
            invoker(*args, **kwargs)

    # ------------------------------------------------------------------
    # Scalar, keyword and structured arguments.
    # ------------------------------------------------------------------

    def testNoReflectionScalars(self):
        vm_context, invoker = self._make_invoker(self._returning_ints(3, 4))
        result = invoker(1, 2)
        self.assertEqual("[<VmVariantList(2): [1, 2]>]", vm_context.mock_arg_reprs)
        self.assertEqual((3, 4), result)

    def testKeywordArgs(self):
        vm_context, invoker = self._make_invoker(
            self._returning_ints(3), self._NAMED_SCALARS_ABI
        )
        result = invoker(-1, a=1, b=2)
        self.assertEqual("[<VmVariantList(3): [-1, 1, 2]>]", vm_context.mock_arg_reprs)
        self.assertEqual(3, result)

    def testListArg(self):
        vm_context, invoker = self._make_invoker(
            self._returning_ints(3), self._SLIST_ARG_ABI
        )
        invoker([2, 3])
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testListArgNoReflection(self):
        vm_context, invoker = self._make_invoker(self._returning_ints(3))
        invoker([2, 3])
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testListArgArityMismatch(self):
        self._check_raises(
            self._SLIST_ARG_ABI,
            "expected a sequence with 2 values. got:",
            [2, 3, 4],
        )

    def testTupleArg(self):
        vm_context, invoker = self._make_invoker(
            self._returning_ints(3), self._STUPLE_ARG_ABI
        )
        invoker((2, 3))
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testDictArg(self):
        vm_context, invoker = self._make_invoker(
            self._returning_ints(3), self._SDICT_ARG_ABI
        )
        # Keys are passed out of ABI order on purpose; the expected list is
        # in ABI order ("a" then "b").
        invoker({"b": 3, "a": 2})
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testDictArgArityMismatch(self):
        self._check_raises(
            self._SDICT_ARG_ABI,
            "expected a dict with 2 values. got:",
            {"a": 2, "b": 3, "c": 4},
        )

    def testDictArgKeyError(self):
        self._check_raises(
            self._SDICT_ARG_ABI,
            "could not get item 'b' from: ",
            {"a": 2, "c": 3},
        )

    def testDictArgNoReflection(self):
        vm_context, invoker = self._make_invoker(self._returning_ints(3))
        invoker({"b": 3, "a": 2})
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    # ------------------------------------------------------------------
    # Structured results.
    # ------------------------------------------------------------------

    def testInlinedResults(self):
        _, invoker = self._make_invoker(
            self._returning_ints(3, 4),
            {"a": [], "r": [["slist", "i32", "i32"]]},
        )
        result = invoker()
        self.assertEqual([3, 4], result)

    def testNestedResults(self):
        def invoke(arg_list, ret_list):
            ret_list.push_int(3)
            sub_list = VmVariantList(2)
            sub_dict = VmVariantList(2)
            sub_dict.push_int(100)
            sub_dict.push_int(200)
            sub_list.push_list(sub_dict)
            sub_list.push_int(6)
            ret_list.push_list(sub_list)

        _, invoker = self._make_invoker(
            invoke,
            {
                "a": [],
                "r": [
                    "i32",
                    [
                        "slist",
                        ["sdict", ["bar", "i32"], ["foo", "i32"]],
                        "i64",
                    ],
                ],
            },
        )
        result = invoker()
        self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result)

    # ------------------------------------------------------------------
    # Call arity / keyword validation.
    # ------------------------------------------------------------------

    def testMissingPositional(self):
        self._check_raises(
            self._NAMED_SCALARS_ABI, "mismatched call arity:", a=1, b=1
        )

    def testMissingPositionalNdarray(self):
        self._check_raises(
            self._NAMED_NDARRAYS_ABI, "mismatched call arity:", a=1, b=1
        )

    def testMissingKeyword(self):
        self._check_raises(self._NAMED_SCALARS_ABI, "mismatched call arity:", -1, a=1)

    def testMissingKeywordNdArray(self):
        self._check_raises(
            self._NAMED_NDARRAYS_ABI, "mismatched call arity:", -1, a=1
        )

    def testExtraKeyword(self):
        self._check_raises(
            self._NAMED_SCALARS_ABI,
            "specified kwarg 'c' is unknown",
            -1,
            a=1,
            b=2,
            c=3,
        )

    # ------------------------------------------------------------------
    # Array and buffer-view arguments.
    # ------------------------------------------------------------------

    def testNdarrayArg(self):
        arg_array = np.asarray([1, 0], dtype=np.int32)
        self._check_buffer_view_arg(arg_array, self._NDARRAY_ARG_ABI)

    def testDeviceArrayArg(self):
        # Note that since the device array is set up to disallow implicit host
        # transfers, this also verifies that no accidental/automatic transfers
        # are done as part of marshalling the array to the function.
        self._check_buffer_view_arg(self._make_device_array(), self._NDARRAY_ARG_ABI)

    def testBufferViewArg(self):
        arg_buffer_view = self._allocate_buffer_view(
            np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32
        )
        self._check_buffer_view_arg(arg_buffer_view, self._NDARRAY_ARG_ABI)

    def testNdarrayArgNoReflection(self):
        arg_array = np.asarray([1, 0], dtype=np.int32)
        self._check_buffer_view_arg(arg_array)

    def testDeviceArrayArgNoReflection(self):
        # Note that since the device array is set up to disallow implicit host
        # transfers, this also verifies that no accidental/automatic transfers
        # are done as part of marshalling the array to the function.
        self._check_buffer_view_arg(self._make_device_array())

    def testBufferViewArgNoReflection(self):
        arg_buffer_view = self._allocate_buffer_view(
            np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32
        )
        self._check_buffer_view_arg(arg_buffer_view)

    # ------------------------------------------------------------------
    # Buffer-view and list results.
    # ------------------------------------------------------------------

    def testReturnBufferView(self):
        result_array = np.asarray([1, 0], dtype=np.int32)
        _, invoker = self._make_invoker(
            self._returning_buffer_view(result_array, rt.HalElementType.SINT_32),
            self._NDARRAY_RESULT_ABI,
        )
        result = invoker()
        np.testing.assert_array_equal([1, 0], result)

    def testReturnBufferViewNoReflection(self):
        result_array = np.asarray([1, 0], dtype=np.int32)
        _, invoker = self._make_invoker(
            self._returning_buffer_view(result_array, rt.HalElementType.SINT_32)
        )
        result = invoker()
        np.testing.assert_array_equal([1, 0], result)

    # TODO: Fill out all return types.
    def testReturnTypeNdArrayBool(self):
        # The data is stored as 8-bit integers while the ABI declares "i1".
        result_array = np.asarray([1, 0], dtype=np.int8)
        _, invoker = self._make_invoker(
            self._returning_buffer_view(result_array, rt.HalElementType.UINT_8),
            {"a": [], "r": [["ndarray", "i1", 1, 2]]},
        )
        result = invoker()
        # assertEqual on bool arrays is fraught for... reasons.
        np.testing.assert_array_equal([True, False], result)

    def testReturnTypeList(self):
        vm_list = VmVariantList(2)
        vm_list.push_int(1)
        vm_list.push_int(2)

        def invoke(arg_list, ret_list):
            ret_list.push_list(vm_list)

        _, invoker = self._make_invoker(
            invoke, {"a": [], "r": [["py_homogeneous_list", "i64"]]}
        )
        result = invoker()
        self.assertEqual("[1, 2]", repr(result))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()