# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import numpy as np
import unittest

from iree import runtime as rt
from iree.runtime.function import (
    FunctionInvoker,
    IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
    IMPLICIT_BUFFER_ARG_USAGE,
)
from iree.runtime._binding import VmVariantList


class MockVmContext:
  """Stand-in VM context that forwards invocations to a test callback.

  Each call to `invoke` is handed to the callback supplied at construction
  time and then remembered in `invocations` so tests can inspect what was
  passed.
  """

  def __init__(self, invoke_callback):
    # Chronological record of (vm_function, arg_list, ret_list) tuples.
    self.invocations = []
    # Callable taking (arg_list, ret_list); it stands in for the real call.
    self._invoke_callback = invoke_callback

  def invoke(self, vm_function, arg_list, ret_list):
    """Runs the callback, then records and prints the invocation.

    The callback runs first, so an invocation whose callback raises is
    never recorded.
    """
    self._invoke_callback(arg_list, ret_list)
    record = (vm_function, arg_list, ret_list)
    self.invocations.append(record)
    print(f"INVOKE: {arg_list} -> {ret_list}")

  @property
  def mock_arg_reprs(self):
    """The repr() of the list of argument lists from all invocations."""
    recorded_arg_lists = [record[1] for record in self.invocations]
    return repr(recorded_arg_lists)


class MockVmFunction:
  """Stand-in VM function that only carries a `reflection` attribute."""

  def __init__(self, reflection):
    # Stored as given; the tests in this file pass a dict that is either
    # empty or holds a JSON string under the "iree.abi" key.
    self.reflection = reflection


class FunctionTest(unittest.TestCase):
  """Exercises FunctionInvoker argument/result handling via mock VM objects.

  Every test builds a FunctionInvoker around a MockVmContext (which routes
  the call to a per-test callback) and a MockVmFunction whose reflection is
  either empty or carries an "iree.abi" JSON description.
  """

  # ABI: one positional i32 plus named i32 arguments "a" and "b".
  _KEYWORD_ABI = {
      "a": [
          "i32",
          ["named", "a", "i32"],
          ["named", "b", "i32"],
      ],
      "r": ["i32"],
  }

  # Same layout as _KEYWORD_ABI, but every argument is an ndarray.
  _KEYWORD_NDARRAY_ABI = {
      "a": [
          ["ndarray", "i32", 1, 1],
          ["named", "a", ["ndarray", "i32", 1, 1]],
          ["named", "b", ["ndarray", "i32", 1, 1]],
      ],
      "r": ["i32"],
  }

  # ABI: a single "slist" argument with two i32 entries.
  _SLIST_ABI = {
      "a": [["slist", "i32", "i32"]],
      "r": ["i32"],
  }

  # ABI: a single "sdict" argument with i32 entries keyed "a" and "b".
  _SDICT_ABI = {
      "a": [["sdict", ["a", "i32"], ["b", "i32"]]],
      "r": ["i32"],
  }

  # ABI: a single ndarray argument and no results.
  _NDARRAY_ARG_ABI = {
      "a": [["ndarray", "i32", 1, 2]],
      "r": [],
  }

  # ABI: no arguments and a single ndarray result.
  _NDARRAY_RESULT_ABI = {
      "a": [],
      "r": [["ndarray", "i32", 1, 2]],
  }

  # Expected MockVmContext.mock_arg_reprs when one nested list [2, 3] is sent.
  _PAIR_LIST_ARGS_REPR = "[<VmVariantList(1): [List[2, 3]]>]"

  # Expected repr of an argument list holding one buffer view.
  _BUFFER_VIEW_ARG_LIST_REPR = (
      "<VmVariantList(1): [HalBufferView(2:0x20000011)]>")

  @classmethod
  def setUpClass(cls):
    # Doesn't matter what device. We just need one.
    config = rt.Config("local-task")
    cls.device = config.device

  # ---------------------------------------------------------------------------
  # Helpers
  # ---------------------------------------------------------------------------

  def _make_invoker(self, invoke_callback, abi=None):
    """Builds a FunctionInvoker wired to mock VM objects.

    Args:
      invoke_callback: Callable taking (arg_list, ret_list) that stands in
        for the actual function call.
      abi: Optional ABI description. It is JSON-encoded into the function's
        "iree.abi" reflection attribute. When None, the reflection dict is
        empty.

    Returns:
      A (vm_context, invoker) tuple.
    """
    reflection = {} if abi is None else {"iree.abi": json.dumps(abi)}
    vm_context = MockVmContext(invoke_callback)
    vm_function = MockVmFunction(reflection=reflection)
    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
    return vm_context, invoker

  @staticmethod
  def _return_three(arg_list, ret_list):
    """Invoke callback that ignores its arguments and returns the int 3."""
    ret_list.push_int(3)

  def _allocate_buffer_view(self, array, element_type):
    """Copies `array` into a device buffer view of the given element type."""
    return self.device.allocator.allocate_buffer_copy(
        memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
        allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
        buffer=array,
        element_type=element_type)

  def _invoke_and_capture_args(self, arg, abi=None):
    """Invokes with the single argument `arg` and returns the VM arg list.

    Returns None if the invoke callback was never reached.
    """
    invoked_arg_list = None

    def invoke(arg_list, ret_list):
      nonlocal invoked_arg_list
      invoked_arg_list = arg_list

    _, invoker = self._make_invoker(invoke, abi)
    invoker(arg)
    return invoked_arg_list

  def _buffer_view_result_callback(self, array, element_type):
    """Returns an invoke callback that yields `array` as a buffer view.

    The buffer view is allocated when the callback runs, not when this
    helper is called.
    """

    def invoke(arg_list, ret_list):
      ret_list.push_ref(self._allocate_buffer_view(array, element_type))

    return invoke

  # ---------------------------------------------------------------------------
  # Scalar, sequence and dict arguments
  # ---------------------------------------------------------------------------

  def testNoReflectionScalars(self):

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      ret_list.push_int(4)

    vm_context, invoker = self._make_invoker(invoke)
    result = invoker(1, 2)
    self.assertEqual("[<VmVariantList(2): [1, 2]>]", vm_context.mock_arg_reprs)
    self.assertEqual((3, 4), result)

  def testKeywordArgs(self):
    vm_context, invoker = self._make_invoker(self._return_three,
                                             self._KEYWORD_ABI)
    result = invoker(-1, a=1, b=2)
    self.assertEqual("[<VmVariantList(3): [-1, 1, 2]>]",
                     vm_context.mock_arg_reprs)
    self.assertEqual(3, result)

  def testListArg(self):
    vm_context, invoker = self._make_invoker(self._return_three,
                                             self._SLIST_ABI)
    invoker([2, 3])
    self.assertEqual(self._PAIR_LIST_ARGS_REPR, vm_context.mock_arg_reprs)

  def testListArgNoReflection(self):
    vm_context, invoker = self._make_invoker(self._return_three)
    invoker([2, 3])
    self.assertEqual(self._PAIR_LIST_ARGS_REPR, vm_context.mock_arg_reprs)

  def testListArgArityMismatch(self):
    _, invoker = self._make_invoker(self._return_three, self._SLIST_ABI)
    with self.assertRaisesRegex(ValueError,
                                "expected a sequence with 2 values. got:"):
      invoker([2, 3, 4])

  def testTupleArg(self):
    abi = {
        "a": [["stuple", "i32", "i32"]],
        "r": ["i32"],
    }
    vm_context, invoker = self._make_invoker(self._return_three, abi)
    invoker((2, 3))
    self.assertEqual(self._PAIR_LIST_ARGS_REPR, vm_context.mock_arg_reprs)

  def testDictArg(self):
    vm_context, invoker = self._make_invoker(self._return_three,
                                             self._SDICT_ABI)
    # Keys are deliberately given out of order relative to the ABI.
    invoker({"b": 3, "a": 2})
    self.assertEqual(self._PAIR_LIST_ARGS_REPR, vm_context.mock_arg_reprs)

  def testDictArgArityMismatch(self):
    _, invoker = self._make_invoker(self._return_three, self._SDICT_ABI)
    with self.assertRaisesRegex(ValueError,
                                "expected a dict with 2 values. got:"):
      invoker({"a": 2, "b": 3, "c": 4})

  def testDictArgKeyError(self):
    _, invoker = self._make_invoker(self._return_three, self._SDICT_ABI)
    with self.assertRaisesRegex(ValueError, "could not get item 'b' from: "):
      invoker({"a": 2, "c": 3})

  def testDictArgNoReflection(self):
    vm_context, invoker = self._make_invoker(self._return_three)
    invoker({"b": 3, "a": 2})
    self.assertEqual(self._PAIR_LIST_ARGS_REPR, vm_context.mock_arg_reprs)

  # ---------------------------------------------------------------------------
  # Structured results
  # ---------------------------------------------------------------------------

  def testInlinedResults(self):

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      ret_list.push_int(4)

    abi = {
        "a": [],
        "r": [["slist", "i32", "i32"]],
    }
    _, invoker = self._make_invoker(invoke, abi)
    result = invoker()
    self.assertEqual([3, 4], result)

  def testNestedResults(self):

    def invoke(arg_list, ret_list):
      ret_list.push_int(3)
      sub_list = VmVariantList(2)
      sub_dict = VmVariantList(2)
      sub_dict.push_int(100)
      sub_dict.push_int(200)
      sub_list.push_list(sub_dict)
      sub_list.push_int(6)
      ret_list.push_list(sub_list)

    abi = {
        "a": [],
        "r": [
            "i32",
            [
                "slist",
                ["sdict", ["bar", "i32"], ["foo", "i32"]],
                "i64",
            ],
        ],
    }
    _, invoker = self._make_invoker(invoke, abi)
    result = invoker()
    self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result)

  # ---------------------------------------------------------------------------
  # Call arity errors
  # ---------------------------------------------------------------------------

  def testMissingPositional(self):
    _, invoker = self._make_invoker(self._return_three, self._KEYWORD_ABI)
    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
      invoker(a=1, b=1)

  def testMissingPositionalNdarray(self):
    _, invoker = self._make_invoker(self._return_three,
                                    self._KEYWORD_NDARRAY_ABI)
    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
      invoker(a=1, b=1)

  def testMissingKeyword(self):
    _, invoker = self._make_invoker(self._return_three, self._KEYWORD_ABI)
    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
      invoker(-1, a=1)

  def testMissingKeywordNdArray(self):
    _, invoker = self._make_invoker(self._return_three,
                                    self._KEYWORD_NDARRAY_ABI)
    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
      invoker(-1, a=1)

  def testExtraKeyword(self):
    _, invoker = self._make_invoker(self._return_three, self._KEYWORD_ABI)
    with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"):
      invoker(-1, a=1, b=2, c=3)

  # ---------------------------------------------------------------------------
  # Array-like arguments
  # ---------------------------------------------------------------------------

  def testNdarrayArg(self):
    arg_array = np.asarray([1, 0], dtype=np.int32)
    invoked_arg_list = self._invoke_and_capture_args(arg_array,
                                                     self._NDARRAY_ARG_ABI)
    self.assertEqual(self._BUFFER_VIEW_ARG_LIST_REPR, repr(invoked_arg_list))

  def testDeviceArrayArg(self):
    # Note that since the device array is set up to disallow implicit host
    # transfers, this also verifies that no accidental/automatic transfers
    # are done as part of marshalling the array to the function.
    arg_array = rt.asdevicearray(self.device,
                                 np.asarray([1, 0], dtype=np.int32),
                                 implicit_host_transfer=False)
    invoked_arg_list = self._invoke_and_capture_args(arg_array,
                                                     self._NDARRAY_ARG_ABI)
    self.assertEqual(self._BUFFER_VIEW_ARG_LIST_REPR, repr(invoked_arg_list))

  def testBufferViewArg(self):
    arg_buffer_view = self._allocate_buffer_view(
        np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32)
    invoked_arg_list = self._invoke_and_capture_args(arg_buffer_view,
                                                     self._NDARRAY_ARG_ABI)
    self.assertEqual(self._BUFFER_VIEW_ARG_LIST_REPR, repr(invoked_arg_list))

  def testNdarrayArgNoReflection(self):
    arg_array = np.asarray([1, 0], dtype=np.int32)
    invoked_arg_list = self._invoke_and_capture_args(arg_array)
    self.assertEqual(self._BUFFER_VIEW_ARG_LIST_REPR, repr(invoked_arg_list))

  def testDeviceArrayArgNoReflection(self):
    # Note that since the device array is set up to disallow implicit host
    # transfers, this also verifies that no accidental/automatic transfers
    # are done as part of marshalling the array to the function.
    arg_array = rt.asdevicearray(self.device,
                                 np.asarray([1, 0], dtype=np.int32),
                                 implicit_host_transfer=False)
    invoked_arg_list = self._invoke_and_capture_args(arg_array)
    self.assertEqual(self._BUFFER_VIEW_ARG_LIST_REPR, repr(invoked_arg_list))

  def testBufferViewArgNoReflection(self):
    arg_buffer_view = self._allocate_buffer_view(
        np.asarray([1, 0], dtype=np.int32), rt.HalElementType.SINT_32)
    invoked_arg_list = self._invoke_and_capture_args(arg_buffer_view)
    self.assertEqual(self._BUFFER_VIEW_ARG_LIST_REPR, repr(invoked_arg_list))

  # ---------------------------------------------------------------------------
  # Array and list results
  # ---------------------------------------------------------------------------

  def testReturnBufferView(self):
    result_array = np.asarray([1, 0], dtype=np.int32)
    invoke = self._buffer_view_result_callback(result_array,
                                               rt.HalElementType.SINT_32)
    _, invoker = self._make_invoker(invoke, self._NDARRAY_RESULT_ABI)
    result = invoker()
    np.testing.assert_array_equal([1, 0], result)

  def testReturnBufferViewNoReflection(self):
    result_array = np.asarray([1, 0], dtype=np.int32)
    invoke = self._buffer_view_result_callback(result_array,
                                               rt.HalElementType.SINT_32)
    _, invoker = self._make_invoker(invoke)
    result = invoker()
    np.testing.assert_array_equal([1, 0], result)

  # TODO: Fill out all return types.
  def testReturnTypeNdArrayBool(self):
    result_array = np.asarray([1, 0], dtype=np.int8)
    invoke = self._buffer_view_result_callback(result_array,
                                               rt.HalElementType.UINT_8)
    abi = {
        "a": [],
        "r": [["ndarray", "i1", 1, 2]],
    }
    _, invoker = self._make_invoker(invoke, abi)
    result = invoker()
    # assertEqual on bool arrays is fraught for... reasons.
    np.testing.assert_array_equal([True, False], result)

  def testReturnTypeList(self):
    vm_list = VmVariantList(2)
    vm_list.push_int(1)
    vm_list.push_int(2)

    def invoke(arg_list, ret_list):
      ret_list.push_list(vm_list)

    abi = {
        "a": [],
        "r": [["py_homogeneous_list", "i64"]],
    }
    _, invoker = self._make_invoker(invoke, abi)
    result = invoker()
    self.assertEqual("[1, 2]", repr(result))


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
