# Copyright 2019 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# pylint: disable=unused-variable

import logging
import numpy as np
import unittest

import iree.compiler
import iree.runtime


def create_add_scalar_module():
  """Builds the VM module used by the scalar-add test.

  The inline program defines one function, `add_scalar`, that adds two i32
  arguments. It is compiled with `iree.compiler.compile_str` for the default
  testing backends, and the compiled output is passed to
  `iree.runtime.VmModule.from_flatbuffer`.

  Returns:
    The module object produced by `iree.runtime.VmModule.from_flatbuffer`.
  """
  asm = """
      func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
        %0 = arith.addi %arg0, %arg1 : i32
        return %0 : i32
      }
      """
  flatbuffer = iree.compiler.compile_str(
      asm,
      input_type="mhlo",
      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS)
  return iree.runtime.VmModule.from_flatbuffer(flatbuffer)


def create_simple_static_mul_module():
  """Builds the VM module used by the statically shaped multiply tests.

  The inline program defines one function, `simple_mul`, that applies
  `mhlo.multiply` to two `tensor<4xf32>` arguments. It is compiled with
  `iree.compiler.compile_str` for the default testing backends, and the
  compiled output is passed to `iree.runtime.VmModule.from_flatbuffer`.

  Returns:
    The module object produced by `iree.runtime.VmModule.from_flatbuffer`.
  """
  asm = """
      func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
          %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
          return %0 : tensor<4xf32>
      }
      """
  flatbuffer = iree.compiler.compile_str(
      asm,
      input_type="mhlo",
      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS)
  return iree.runtime.VmModule.from_flatbuffer(flatbuffer)


def create_simple_dynamic_abs_module():
  """Builds the VM module used by the dynamically shaped abs tests.

  The inline program applies `mhlo.abs` to a `tensor<?x?xf32>` argument. Note
  that the function is exported under the name `simple_mul` even though it
  computes an absolute value; the tests in this file look it up by that name,
  so it is intentionally left unchanged here.

  Returns:
    The module object produced by `iree.runtime.VmModule.from_flatbuffer`.
  """
  # TODO(laurenzo): Compile for more backends as dynamic shapes come online.
  # Read the constant through `iree.compiler.core`, like every other use in
  # this file, so all modules are built from the same backend list.
  target_backends = iree.compiler.core.DEFAULT_TESTING_BACKENDS
  binary = iree.compiler.compile_str(
      """
      func.func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
          %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
          return %0 : tensor<?x?xf32>
      }
      """,
      input_type="mhlo",
      target_backends=target_backends,
  )
  return iree.runtime.VmModule.from_flatbuffer(binary)


class VmTest(unittest.TestCase):
  """Tests for the iree.runtime VM bindings.

  Covers variant lists, context creation, module lookup and function
  invocation through `FunctionInvoker`.
  """

  @classmethod
  def setUpClass(cls):
    """Creates the driver, device and HAL module shared by every test."""
    super().setUpClass()
    driver_names = iree.runtime.HalDriver.query()
    logging.info("driver_names: %s", driver_names)
    cls.driver = iree.runtime.HalDriver.create(
        iree.compiler.core.DEFAULT_TESTING_DRIVER)
    cls.device = cls.driver.create_default_device()
    cls.hal_module = iree.runtime.create_hal_module(cls.device)

  def test_variant_list(self):
    """A freshly constructed variant list reports a size of zero."""
    # NOTE(review): the constructor argument is presumably an initial
    # capacity rather than a size - confirm against the binding.
    variant_list = iree.runtime.VmVariantList(5)
    logging.info("variant_list: %s", variant_list)
    self.assertEqual(variant_list.size, 0)

  def test_variant_list_i64(self):
    """An integer wider than 32 bits survives a push unchanged."""
    variant_list = iree.runtime.VmVariantList(5)
    # Push a value that exceeds 32-bit range.
    variant_list.push_int(10 * 1000 * 1000 * 1000)
    self.assertEqual(str(variant_list), "<VmVariantList(1): [10000000000]>")

  def test_variant_list_buffers(self):
    """Buffer views of each listed element type round-trip through a list."""
    ET = iree.runtime.HalElementType
    # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
    dtype_element_type_pairs = (
        (np.int8, ET.SINT_8),
        (np.int16, ET.SINT_16),
        (np.int32, ET.SINT_32),
        (np.int64, ET.SINT_64),
        (np.uint8, ET.UINT_8),
        (np.uint16, ET.UINT_16),
        (np.uint32, ET.UINT_32),
        (np.uint64, ET.UINT_64),
        (np.float32, ET.FLOAT_32),
        (np.float64, ET.FLOAT_64),
    )
    # The usage flags do not depend on the dtype, so build them once.
    allowed_usage = (iree.runtime.BufferUsage.DEFAULT |
                     iree.runtime.BufferUsage.MAPPING)
    for dt, et in dtype_element_type_pairs:
      # subTest labels a failure with the dtype that caused it and lets the
      # remaining dtypes still run instead of stopping at the first failure.
      with self.subTest(dtype=dt.__name__):
        lst = iree.runtime.VmVariantList(5)
        ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
        bv1 = self.device.allocator.allocate_buffer_copy(
            memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
            allowed_usage=allowed_usage,
            buffer=ary1,
            element_type=et)
        lst.push_buffer_view(bv1)
        ary2 = iree.runtime.DeviceArray(self.device,
                                        lst.get_as_buffer_view(0),
                                        override_dtype=dt,
                                        implicit_host_transfer=True)
        np.testing.assert_array_equal(ary1, ary2)
        # Only index 0 was populated, so index 1 must be rejected.
        with self.assertRaises(IndexError):
          lst.get_as_buffer_view(1)

  def test_variant_list_list(self):
    """A variant list can hold another list and hand it back."""
    lst1 = iree.runtime.VmVariantList(5)
    lst2 = iree.runtime.VmVariantList(5)
    lst1.push_list(lst2)
    self.assertEqual(str(lst1), "<VmVariantList(1): [List[]]>")
    lstout = lst1.get_as_list(0)
    self.assertEqual(str(lstout), "<VmVariantList(0): []>")
    # Only index 0 was populated, so index 1 must be rejected.
    with self.assertRaises(IndexError):
      lst1.get_as_list(1)

  def test_context_id(self):
    """Two contexts created on the same instance get different ids."""
    instance = iree.runtime.VmInstance()
    context1 = iree.runtime.VmContext(instance)
    context2 = iree.runtime.VmContext(instance)
    self.assertNotEqual(context2.context_id, context1.context_id)

  def test_module_basics(self):
    """lookup_function finds an exported name and returns None otherwise."""
    m = create_simple_static_mul_module()
    f = m.lookup_function("simple_mul")
    self.assertGreaterEqual(f.ordinal, 0)
    self.assertIsNone(m.lookup_function("notfound"))

  def test_dynamic_module_context(self):
    """Modules can be registered on a context after it has been created."""
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance)
    m = create_simple_static_mul_module()
    context.register_modules([self.hal_module, m])

  def test_static_module_context(self):
    """A context can be created with its modules passed to the constructor."""
    m = create_simple_static_mul_module()
    logging.info("module: %s", m)
    instance = iree.runtime.VmInstance()
    logging.info("instance: %s", instance)
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    logging.info("context: %s", context)

  def test_dynamic_shape_compile(self):
    """The dynamically shaped module compiles and loads into a context."""
    m = create_simple_dynamic_abs_module()
    logging.info("module: %s", m)
    instance = iree.runtime.VmInstance()
    logging.info("instance: %s", instance)
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    logging.info("context: %s", context)

  def test_add_scalar_new_abi(self):
    """Invoking add_scalar with 5 and 6 returns 11."""
    m = create_add_scalar_module()
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    f = m.lookup_function("add_scalar")
    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
    result = finv(5, 6)
    logging.info("result: %s", result)
    self.assertEqual(result, 11)

  def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
    """Invoking the dynamic abs function returns element-wise magnitudes."""
    m = create_simple_dynamic_abs_module()
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    # The abs function is exported under the name "simple_mul" by
    # create_simple_dynamic_abs_module.
    f = m.lookup_function("simple_mul")
    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
    arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
    result = finv(arg0)
    logging.info("result: %s", result)
    np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])

  def test_synchronous_invoke_function_new_abi(self):
    """Invoking simple_mul returns the element-wise product of its inputs."""
    m = create_simple_static_mul_module()
    instance = iree.runtime.VmInstance()
    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
    f = m.lookup_function("simple_mul")
    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
    result = finv(arg0, arg1)
    logging.info("result: %s", result)
    np.testing.assert_allclose(result, [4., 10., 18., 28.])


if __name__ == "__main__":
  # Configure the root logger at DEBUG before running so the logging.info()
  # diagnostics emitted by the tests are shown when this file is run as a
  # script.
  logging.basicConfig(level=logging.DEBUG)
  unittest.main()
