# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import numpy as np
import unittest

from iree import runtime as rt
from iree.runtime.function import (
    FunctionInvoker,
    IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
    IMPLICIT_BUFFER_ARG_USAGE,
)
from iree.runtime._binding import VmVariantList


class MockVmContext:
    """Fake VM context that delegates execution to a callback and logs calls.

    Instead of running a real VM function, ``invoke`` passes the argument and
    result lists to a test-supplied callback (which typically pushes result
    values), then records the call so tests can inspect it afterwards.
    """

    def __init__(self, invoke_callback):
        # Called as ``invoke_callback(arg_list, ret_list)`` on every invoke.
        self._invoke_callback = invoke_callback
        # ``(vm_function, arg_list, ret_list)`` per invoke whose callback
        # returned normally.
        self.invocations = []

    def invoke(self, vm_function, arg_list, ret_list):
        """Runs the callback, then records the call and echoes it to stdout."""
        callback = self._invoke_callback
        callback(arg_list, ret_list)
        # Recorded only after the callback succeeds, so ret_list is populated.
        record = (vm_function, arg_list, ret_list)
        self.invocations.append(record)
        print(f"INVOKE: {arg_list} -> {ret_list}")

    @property
    def mock_arg_reprs(self):
        """``repr`` of the list of argument lists, one per recorded invocation."""
        collected = []
        for _vm_function, args, _results in self.invocations:
            collected.append(args)
        return repr(collected)


class MockVmFunction:
    """Fake VM function that exposes only its ``reflection`` metadata.

    Args:
      reflection: Mapping of reflection attribute names to values. The tests
        in this file pass either ``{}`` or a dict whose ``"iree.abi"`` entry
        holds a JSON-encoded signature.
    """

    def __init__(self, reflection):
        # Stored as-is (no copy); callers read it via ``.reflection``.
        self.reflection = reflection


class FunctionTest(unittest.TestCase):
    """Tests FunctionInvoker argument packing and result unpacking.

    No compiled module is involved. Each test builds a FunctionInvoker over:

    - a MockVmContext, whose callback stands in for the VM function (it
      receives the packed ``arg_list`` and pushes values onto ``ret_list``);
    - a MockVmFunction, optionally carrying "iree.abi" reflection metadata.
    """

    # Expected ``mock_arg_reprs`` after a single invocation whose only
    # argument was the int32 array [1, 0] (as ndarray, device array or buffer
    # view).
    _SINGLE_BUFFER_VIEW_ARGS_REPR = (
        "[<VmVariantList(1): [HalBufferView(2:0x20000011)]>]"
    )

    @classmethod
    def setUpClass(cls):
        # Doesn't matter what device. We just need one.
        config = rt.Config("local-task")
        cls.device = config.device

    # Helpers ---------------------------------------------------------------

    @staticmethod
    def _abi_function(args, results):
        """Returns a MockVmFunction with "iree.abi" reflection.

        Args:
          args: JSON-serializable list of argument descriptors (the "a" key).
          results: JSON-serializable list of result descriptors (the "r" key).
        """
        abi = json.dumps({"a": args, "r": results})
        return MockVmFunction(reflection={"iree.abi": abi})

    def _positional_and_named_function(self, arg_type):
        """Returns a function taking one positional ``arg_type`` argument plus
        keyword arguments ``a`` and ``b`` of the same type, returning an i32."""
        return self._abi_function(
            [arg_type, ["named", "a", arg_type], ["named", "b", arg_type]],
            ["i32"],
        )

    @staticmethod
    def _pushes_ints(*values):
        """Returns an invoke callback that pushes ``values`` as int results."""

        def invoke(arg_list, ret_list):
            for value in values:
                ret_list.push_int(value)

        return invoke

    def _make_invoker(self, invoke, vm_function=None):
        """Wraps ``vm_function`` in a FunctionInvoker backed by a MockVmContext.

        Args:
          invoke: Callback ``invoke(arg_list, ret_list)`` standing in for the
            VM function body.
          vm_function: Function to call; defaults to one without reflection.

        Returns:
          ``(invoker, vm_context)``; the context records every invocation.
        """
        if vm_function is None:
            vm_function = MockVmFunction(reflection={})
        vm_context = MockVmContext(invoke)
        invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
        return invoker, vm_context

    def _buffer_view(self, array, element_type):
        """Copies ``array`` into a new buffer view allocated on ``self.device``
        with the implicit-argument memory type and usage."""
        return self.device.allocator.allocate_buffer_copy(
            memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
            allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
            buffer=array,
            element_type=element_type,
        )

    def _device_array_without_host_transfer(self):
        """Returns the int32 array [1, 0] as a device array created with
        ``implicit_host_transfer=False``."""
        return rt.asdevicearray(
            self.device,
            np.asarray([1, 0], dtype=np.int32),
            implicit_host_transfer=False,
        )

    # Scalar and keyword arguments -------------------------------------------

    def testNoReflectionScalars(self):
        invoker, vm_context = self._make_invoker(self._pushes_ints(3, 4))
        result = invoker(1, 2)
        self.assertEqual("[<VmVariantList(2): [1, 2]>]", vm_context.mock_arg_reprs)
        self.assertEqual((3, 4), result)

    def testKeywordArgs(self):
        vm_function = self._positional_and_named_function("i32")
        invoker, vm_context = self._make_invoker(self._pushes_ints(3), vm_function)
        result = invoker(-1, a=1, b=2)
        self.assertEqual("[<VmVariantList(3): [-1, 1, 2]>]", vm_context.mock_arg_reprs)
        self.assertEqual(3, result)

    # Structured (list / tuple / dict) arguments -----------------------------

    def testListArg(self):
        vm_function = self._abi_function([["slist", "i32", "i32"]], ["i32"])
        invoker, vm_context = self._make_invoker(self._pushes_ints(3), vm_function)
        invoker([2, 3])
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testListArgNoReflection(self):
        invoker, vm_context = self._make_invoker(self._pushes_ints(3))
        invoker([2, 3])
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testListArgArityMismatch(self):
        vm_function = self._abi_function([["slist", "i32", "i32"]], ["i32"])
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(
            ValueError, "expected a sequence with 2 values. got:"
        ):
            invoker([2, 3, 4])

    def testTupleArg(self):
        vm_function = self._abi_function([["stuple", "i32", "i32"]], ["i32"])
        invoker, vm_context = self._make_invoker(self._pushes_ints(3), vm_function)
        invoker((2, 3))
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testDictArg(self):
        vm_function = self._abi_function(
            [["sdict", ["a", "i32"], ["b", "i32"]]], ["i32"]
        )
        invoker, vm_context = self._make_invoker(self._pushes_ints(3), vm_function)
        invoker({"b": 3, "a": 2})
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    def testDictArgArityMismatch(self):
        vm_function = self._abi_function(
            [["sdict", ["a", "i32"], ["b", "i32"]]], ["i32"]
        )
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "expected a dict with 2 values. got:"):
            invoker({"a": 2, "b": 3, "c": 4})

    def testDictArgKeyError(self):
        vm_function = self._abi_function(
            [["sdict", ["a", "i32"], ["b", "i32"]]], ["i32"]
        )
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "could not get item 'b' from: "):
            invoker({"a": 2, "c": 3})

    def testDictArgNoReflection(self):
        invoker, vm_context = self._make_invoker(self._pushes_ints(3))
        invoker({"b": 3, "a": 2})
        self.assertEqual(
            "[<VmVariantList(1): [List[2, 3]]>]", vm_context.mock_arg_reprs
        )

    # Structured results -----------------------------------------------------

    def testInlinedResults(self):
        # The callback pushes the two ints directly (not wrapped in a sub-list)
        # while the signature declares a single slist result.
        vm_function = self._abi_function([], [["slist", "i32", "i32"]])
        invoker, _ = self._make_invoker(self._pushes_ints(3, 4), vm_function)
        self.assertEqual([3, 4], invoker())

    def testNestedResults(self):
        def invoke(arg_list, ret_list):
            ret_list.push_int(3)
            sub_list = VmVariantList(2)
            sub_dict = VmVariantList(2)
            sub_dict.push_int(100)
            sub_dict.push_int(200)
            sub_list.push_list(sub_dict)
            sub_list.push_int(6)
            ret_list.push_list(sub_list)

        vm_function = self._abi_function(
            [],
            [
                "i32",
                ["slist", ["sdict", ["bar", "i32"], ["foo", "i32"]], "i64"],
            ],
        )
        invoker, _ = self._make_invoker(invoke, vm_function)
        self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), invoker())

    # Arity errors -----------------------------------------------------------

    def testMissingPositional(self):
        vm_function = self._positional_and_named_function("i32")
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
            invoker(a=1, b=1)

    def testMissingPositionalNdarray(self):
        vm_function = self._positional_and_named_function(["ndarray", "i32", 1, 1])
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
            invoker(a=1, b=1)

    def testMissingKeyword(self):
        vm_function = self._positional_and_named_function("i32")
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
            invoker(-1, a=1)

    def testMissingKeywordNdArray(self):
        vm_function = self._positional_and_named_function(["ndarray", "i32", 1, 1])
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
            invoker(-1, a=1)

    def testExtraKeyword(self):
        vm_function = self._positional_and_named_function("i32")
        invoker, _ = self._make_invoker(self._pushes_ints(3), vm_function)
        with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"):
            invoker(-1, a=1, b=2, c=3)

    # Array arguments --------------------------------------------------------

    def testNdarrayArg(self):
        vm_function = self._abi_function([["ndarray", "i32", 1, 2]], [])
        invoker, vm_context = self._make_invoker(self._pushes_ints(), vm_function)
        invoker(np.asarray([1, 0], dtype=np.int32))
        self.assertEqual(
            self._SINGLE_BUFFER_VIEW_ARGS_REPR, vm_context.mock_arg_reprs
        )

    def testDeviceArrayArg(self):
        # Note that since the device array is set up to disallow implicit host
        # transfers, this also verifies that no accidental/automatic transfers
        # are done as part of marshalling the array to the function.
        arg_array = self._device_array_without_host_transfer()
        vm_function = self._abi_function([["ndarray", "i32", 1, 2]], [])
        invoker, vm_context = self._make_invoker(self._pushes_ints(), vm_function)
        invoker(arg_array)
        self.assertEqual(
            self._SINGLE_BUFFER_VIEW_ARGS_REPR, vm_context.mock_arg_reprs
        )

    def testBufferViewArg(self):
        arg_buffer_view = self._buffer_view(
            np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32
        )
        vm_function = self._abi_function([["ndarray", "i32", 1, 2]], [])
        invoker, vm_context = self._make_invoker(self._pushes_ints(), vm_function)
        invoker(arg_buffer_view)
        self.assertEqual(
            self._SINGLE_BUFFER_VIEW_ARGS_REPR, vm_context.mock_arg_reprs
        )

    def testNdarrayArgNoReflection(self):
        invoker, vm_context = self._make_invoker(self._pushes_ints())
        invoker(np.asarray([1, 0], dtype=np.int32))
        self.assertEqual(
            self._SINGLE_BUFFER_VIEW_ARGS_REPR, vm_context.mock_arg_reprs
        )

    def testDeviceArrayArgNoReflection(self):
        # Note that since the device array is set up to disallow implicit host
        # transfers, this also verifies that no accidental/automatic transfers
        # are done as part of marshalling the array to the function.
        arg_array = self._device_array_without_host_transfer()
        invoker, vm_context = self._make_invoker(self._pushes_ints())
        invoker(arg_array)
        self.assertEqual(
            self._SINGLE_BUFFER_VIEW_ARGS_REPR, vm_context.mock_arg_reprs
        )

    def testBufferViewArgNoReflection(self):
        arg_buffer_view = self._buffer_view(
            np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32
        )
        invoker, vm_context = self._make_invoker(self._pushes_ints())
        invoker(arg_buffer_view)
        self.assertEqual(
            self._SINGLE_BUFFER_VIEW_ARGS_REPR, vm_context.mock_arg_reprs
        )

    # Array / list results ---------------------------------------------------

    def testReturnBufferView(self):
        result_array = np.asarray([1, 0], dtype=np.int32)

        def invoke(arg_list, ret_list):
            ret_list.push_ref(
                self._buffer_view(result_array, rt.HalElementType.SINT_32)
            )

        vm_function = self._abi_function([], [["ndarray", "i32", 1, 2]])
        invoker, _ = self._make_invoker(invoke, vm_function)
        np.testing.assert_array_equal([1, 0], invoker())

    def testReturnBufferViewNoReflection(self):
        result_array = np.asarray([1, 0], dtype=np.int32)

        def invoke(arg_list, ret_list):
            ret_list.push_ref(
                self._buffer_view(result_array, rt.HalElementType.SINT_32)
            )

        invoker, _ = self._make_invoker(invoke)
        np.testing.assert_array_equal([1, 0], invoker())

    # TODO: Fill out all return types.
    def testReturnTypeNdArrayBool(self):
        # The i1 result is backed by a UINT_8 buffer here; build the host array
        # with the matching unsigned dtype so the two declarations agree.
        result_array = np.asarray([1, 0], dtype=np.uint8)

        def invoke(arg_list, ret_list):
            ret_list.push_ref(
                self._buffer_view(result_array, rt.HalElementType.UINT_8)
            )

        vm_function = self._abi_function([], [["ndarray", "i1", 1, 2]])
        invoker, _ = self._make_invoker(invoke, vm_function)
        result = invoker()
        # assertEqual on bool arrays is fraught for... reasons.
        np.testing.assert_array_equal([True, False], result)

    def testReturnTypeList(self):
        vm_list = VmVariantList(2)
        vm_list.push_int(1)
        vm_list.push_int(2)

        def invoke(arg_list, ret_list):
            ret_list.push_list(vm_list)

        vm_function = self._abi_function([], [["py_homogeneous_list", "i64"]])
        invoker, _ = self._make_invoker(invoke, vm_function)
        self.assertEqual("[1, 2]", repr(invoker()))


if __name__ == "__main__":
    # Run every TestCase in this file when executed directly as a script.
    unittest.main(module="__main__")
