# blob: 3db2671c3401bc8df15b526142dca908be2f84c7 [file] [log] [blame]
# Lint as: python3
# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import json
from absl.testing import absltest
from iree import runtime as rt
from iree.runtime.function import FunctionInvoker
from iree.runtime.binding import VmVariantList
class MockVmContext:
  """Stand-in for a VM context that records every invocation it receives.

  Attributes:
    invocations: List of `(vm_function, arg_list, ret_list)` tuples, one per
      call to `invoke`, in call order.
  """

  def __init__(self, invoke_callback):
    """Creates the mock.

    Args:
      invoke_callback: Callable taking `(arg_list, ret_list)`. It is run on
        every `invoke` call and is how a test populates the results.
    """
    self._invoke_callback = invoke_callback
    self.invocations = []

  def invoke(self, vm_function, arg_list, ret_list):
    """Runs the callback, then records and logs the invocation."""
    # The callback runs first so the recorded/printed ret_list reflects
    # whatever results it pushed.
    self._invoke_callback(arg_list, ret_list)
    self.invocations.append((vm_function, arg_list, ret_list))
    print(f"INVOKE: {arg_list} -> {ret_list}")

  @property
  def mock_arg_reprs(self):
    """The `repr` of the list of argument lists seen so far, in call order."""
    return repr([arg_list for _, arg_list, _ in self.invocations])
class MockVmFunction:
  """Stand-in for a VM function that only carries reflection metadata.

  Attributes:
    reflection: The mapping passed to the constructor, stored as-is. The tests
      in this file either leave it empty or set an "iree.abi" JSON string.
  """

  def __init__(self, reflection):
    """Stores `reflection` unchanged on the instance."""
    self.reflection = reflection
class FunctionTest(absltest.TestCase):
  """Tests FunctionInvoker argument/result handling using mocked VM objects."""

  @staticmethod
  def _keyword_abi_function():
    """Builds the mock function shared by the keyword-argument tests.

    Its "iree.abi" reflection declares one i32 argument followed by the named
    i32 arguments "a" and "b", and a single i32 result.
    """
    return MockVmFunction(
        reflection={
            "iree.abi":
                json.dumps({
                    "a": [
                        "i32",
                        ["named", "a", "i32"],
                        ["named", "b", "i32"],
                    ],
                    "r": ["i32",],
                })
        })

  def setUp(self):
    # Doesn't matter what device. We just need one.
    config = rt.Config("vmvx")
    self.device = config.device

  def testNoReflectionScalars(self):
    """With empty reflection, scalars pass straight through in both directions."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      ret_list.push_int(4)

    vm_context = MockVmContext(invoke)
    vm_function = MockVmFunction(reflection={})
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    result = invoker(1, 2)
    self.assertEqual("[<VmVariantList(2): [1, 2]>]", vm_context.mock_arg_reprs)
    self.assertEqual((3, 4), result)

  def testKeywordArgs(self):
    """Keyword arguments are appended after positionals in ABI order."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)

    vm_context = MockVmContext(invoke)
    vm_function = self._keyword_abi_function()
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    result = invoker(-1, a=1, b=2)
    self.assertEqual("[<VmVariantList(3): [-1, 1, 2]>]",
                     vm_context.mock_arg_reprs)
    self.assertEqual(3, result)

  def testInlinedResults(self):
    """Two flat i32 results described by one "slist" come back as a list."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      ret_list.push_int(4)

    vm_context = MockVmContext(invoke)
    vm_function = MockVmFunction(reflection={
        "iree.abi": json.dumps({
            "a": [],
            "r": [["slist", "i32", "i32"]],
        })
    })
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    result = invoker()
    self.assertEqual([3, 4], result)

  def testNestedResults(self):
    """Nested variant lists are rebuilt as the list/dict shape in the ABI."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      sub_list = VmVariantList(2)
      sub_dict = VmVariantList(2)
      sub_dict.push_int(100)
      sub_dict.push_int(200)
      sub_list.push_list(sub_dict)
      sub_list.push_int(6)
      ret_list.push_list(sub_list)

    vm_context = MockVmContext(invoke)
    vm_function = MockVmFunction(
        reflection={
            "iree.abi":
                json.dumps({
                    "a": [],
                    "r": [
                        "i32",
                        [
                            "slist",
                            ["sdict", ["bar", "i32"], ["foo", "i32"]],
                            "i64",
                        ]
                    ],
                })
        })
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    result = invoker()
    self.assertEqual((3, [{'bar': 100, 'foo': 200}, 6]), result)

  def testMissingPositional(self):
    """Omitting the positional argument raises ValueError."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)

    vm_context = MockVmContext(invoke)
    vm_function = self._keyword_abi_function()
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    # assertRaisesRegex replaces the assertRaisesRegexp alias, which was
    # removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "a required argument was not specified"):
      invoker(a=1, b=2)

  def testMissingKeyword(self):
    """Omitting the named argument "b" raises ValueError."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)

    vm_context = MockVmContext(invoke)
    vm_function = self._keyword_abi_function()
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    with self.assertRaisesRegex(ValueError,
                                "a required argument was not specified"):
      invoker(-1, a=1)

  def testExtraKeyword(self):
    """Passing a keyword the ABI does not declare raises ValueError."""

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)

    vm_context = MockVmContext(invoke)
    vm_function = self._keyword_abi_function()
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"):
      invoker(-1, a=1, b=2, c=3)
# Run the absltest runner only when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()