# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import array
import logging
import numpy as np
from pathlib import Path
import tempfile
import unittest

import iree.compiler
import iree.runtime as rt


# Cache for the compiled form of MM_TEST_ASM. It starts out empty and is
# filled in by compile_mm_test() so that repeated calls reuse the same binary.
MM_TEST_COMPILED = None
# MLIR source of the module under test. It declares two globals backed by
# named parameters in the "params" scope: "weight" (tensor<30x20xf32>) and
# "bias" (tensor<30xf32>). The public @run entry point forwards to @forward,
# which transposes the weight to 20x30, multiplies the 128x20 input by it and
# adds the bias to every row, producing a 128x30 result.
MM_TEST_ASM = r"""
  #map = affine_map<(d0, d1) -> (d0, d1)>
  #map1 = affine_map<(d0, d1) -> (d1, d0)>
  #map2 = affine_map<(d0, d1) -> (d1)>
  module @main {
    util.global private @_params.classifier.weight {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"weight"> : tensor<30x20xf32>
    util.global private @_params.classifier.bias {inlining_policy = #util.inline.never} = #stream.parameter.named<"params"::"bias"> : tensor<30xf32>
  func.func @run(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> {
    %0 = call @forward(%arg0) : (tensor<128x20xf32>) -> tensor<128x30xf32>
    return %0 : tensor<128x30xf32>
  }
  func.func private @forward(%arg0: tensor<128x20xf32>) -> tensor<128x30xf32> attributes {torch.assume_strict_symbolic_shapes} {
    %cst = arith.constant 0.000000e+00 : f32
    %_params.classifier.weight = util.global.load @_params.classifier.weight : tensor<30x20xf32>
    %0 = tensor.empty() : tensor<20x30xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%_params.classifier.weight : tensor<30x20xf32>) outs(%0 : tensor<20x30xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<20x30xf32>
    %2 = tensor.empty() : tensor<128x30xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<128x30xf32>) -> tensor<128x30xf32>
    %4 = linalg.matmul ins(%arg0, %1 : tensor<128x20xf32>, tensor<20x30xf32>) outs(%3 : tensor<128x30xf32>) -> tensor<128x30xf32>
    %_params.classifier.bias = util.global.load @_params.classifier.bias : tensor<30xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map2, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %_params.classifier.bias : tensor<128x30xf32>, tensor<30xf32>) outs(%2 : tensor<128x30xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %6 = arith.addf %in, %in_0 : f32
      linalg.yield %6 : f32
    } -> tensor<128x30xf32>
    return %5 : tensor<128x30xf32>
  }
}
"""


def compile_mm_test():
    """Return the compiled MM_TEST_ASM module, compiling it on first use.

    The result is stored in the module-level MM_TEST_COMPILED so later calls
    hand back the cached value instead of invoking the compiler again.
    """
    global MM_TEST_COMPILED
    if MM_TEST_COMPILED:
        # Already compiled by an earlier call; reuse it.
        return MM_TEST_COMPILED

    # TODO(#16098): re-enable const eval once parameters are supported.
    compiler_flags = ["--iree-opt-const-eval=false"]
    MM_TEST_COMPILED = iree.compiler.compile_str(
        MM_TEST_ASM,
        target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
        extra_args=compiler_flags,
    )
    return MM_TEST_COMPILED


def create_mm_test_module(instance):
    """Build a VmModule for *instance* from the compiled matmul test binary.

    The binary comes from compile_mm_test() and is passed to
    rt.VmModule.copy_buffer together with the given VM instance.
    """
    return rt.VmModule.copy_buffer(instance, compile_mm_test())


def _float_constant(val: float) -> array.array:
    """Wrap *val* in a single-element array of typecode "f"."""
    constant = array.array("f")
    constant.append(val)
    return constant


class ParameterArchiveTest(unittest.TestCase):
    """Checks that parameter archive (.irpa) files can be written to disk."""

    def testCreateArchiveFile(self):
        """Write an index holding two splats with create_archive_file."""
        splat_index = rt.ParameterIndex()
        # NOTE(review): the last argument looks like a total byte length
        # (element count * 4 bytes per f32, matching the tensor<30x20xf32> and
        # tensor<30xf32> globals in MM_TEST_ASM) -- confirm against the API.
        splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
        splat_index.add_splat("bias", _float_constant(1.0), 30 * 4)

        with tempfile.TemporaryDirectory() as td:
            file_path = Path(td) / "archive.irpa"
            target_index = splat_index.create_archive_file(str(file_path))
            print(target_index)
            # The archive must exist and be non-empty.
            self.assertTrue(file_path.exists())
            self.assertGreater(file_path.stat().st_size, 0)

    def testSaveArchiveFile(self):
        """Write a mix of splat values and a numpy array with save_archive_file."""
        with tempfile.TemporaryDirectory() as td:
            file_path = Path(td) / "archive.irpa"
            # Covers several accepted value forms: a SplatValue built from a
            # numpy scalar with a list shape, a SplatValue built from an
            # array.array with a scalar count, and a plain numpy array.
            rt.save_archive_file(
                {
                    "weight": rt.SplatValue(np.float32(2.0), [30, 20]),
                    "bias": rt.SplatValue(array.array("f", [1.0]), 30),
                    "array": np.asarray([1, 2, 3]),
                },
                file_path,
            )
            # The archive must exist and be non-empty.
            self.assertTrue(file_path.exists())
            self.assertGreater(file_path.stat().st_size, 0)


class ParameterTest(unittest.TestCase):
    """Exercises rt.ParameterIndex and runs the matmul test module against it.

    Most tests populate a ParameterIndex from a different source (splats, a
    freshly written .irpa archive, in-memory buffers, GGUF and safetensors
    files), expose it under the "params" scope used by MM_TEST_ASM and invoke
    the module's `run` function.
    """

    def setUp(self):
        """Create the VM instance, device and config shared by each test."""
        self.instance = rt.VmInstance()
        self.device = rt.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
        self.config = rt.Config(device=self.device)

    def _run_mm_module(self, index):
        """Run the matmul test module with parameters served from *index*.

        Loads an io_parameters module (provider scope "params"), a HAL module
        and the compiled test module, then calls `run` on a 128x20 float32
        input whose elements are all 2.0.

        Returns:
            The value returned by the module's `run` function.
        """
        modules = rt.load_vm_modules(
            rt.create_io_parameters_module(
                self.instance, index.create_provider(scope="params")
            ),
            rt.create_hal_module(self.instance, self.device),
            create_mm_test_module(self.instance),
            config=self.config,
        )
        # The test module is the last one passed to load_vm_modules.
        main = modules[-1]
        mm_input = np.full([128, 20], 2.0, dtype=np.float32)
        result = main.run(mm_input)
        print(result.to_host())
        return result

    def _assert_mm_result(self, result):
        """Assert that *result* is a 128x30 array whose elements are all 81.0.

        With input 2.0, weight 2.0, bias 1.0 and a reduction dimension of 20,
        each output element is 2.0 * 2.0 * 20 + 1.0 = 81.0.
        """
        expected_result = np.full([128, 30], 81.0, dtype=np.float32)
        np.testing.assert_array_almost_equal(result, expected_result)

    def testParameterIndex(self):
        """An empty index stays empty after reserve() and can back a module."""
        index = rt.ParameterIndex()
        self.assertEqual(len(index), 0)
        index.reserve(25)
        # Reserving capacity must not change the number of entries.
        self.assertEqual(len(index), 0)
        provider = index.create_provider()
        rt.create_io_parameters_module(self.instance, provider)

    def testFileHandleWrap(self):
        """A FileHandle wrapping in-memory bytes can be created and released."""
        fh = rt.FileHandle.wrap_memory(b"foobar")
        del fh

    def testParameterIndexAddFromFile(self):
        """A sub-range of a wrapped memory file handle can be added to an index."""
        splat_index = rt.ParameterIndex()
        fh = rt.FileHandle.wrap_memory(b"foobar")
        splat_index.add_from_file_handle("data", fh, length=3, offset=3)

    def testSplats(self):
        """Run the module with both parameters supplied as splats."""
        splat_index = rt.ParameterIndex()
        splat_index.add_splat("weight", _float_constant(2.0), 30 * 20 * 4)
        splat_index.add_splat("bias", _float_constant(1.0), 30 * 4)
        self._run_mm_module(splat_index)
        # TODO: Fix splat in the parameter code so it is not all zeros, then
        # verify the result here with self._assert_mm_result(...).

    def testSplatsFromBuiltIrpaFile(self):
        """Run the module with splats loaded back from a saved .irpa archive."""
        with tempfile.TemporaryDirectory() as td:
            file_path = Path(td) / "archive.irpa"
            rt.save_archive_file(
                {
                    "weight": rt.SplatValue(np.float32(2.0), 30 * 20),
                    "bias": rt.SplatValue(np.float32(1.0), 30),
                },
                file_path,
            )

            index = rt.ParameterIndex()
            index.load(str(file_path))
        # The module is run after the temporary directory has been removed.
        self._assert_mm_result(self._run_mm_module(index))

    def testBuffers(self):
        """Run the module with parameters supplied as in-memory numpy buffers."""
        index = rt.ParameterIndex()
        weight = np.full([30, 20], 2.0, dtype=np.float32)
        bias = np.full([30], 1.0, dtype=np.float32)
        index.add_buffer("weight", weight)
        index.add_buffer("bias", bias)
        self._assert_mm_result(self._run_mm_module(index))

    def testGguf(self):
        """Run the module with parameters loaded from the GGUF test file."""
        index = rt.ParameterIndex()
        index.load(
            str(
                Path(__file__).resolve().parent
                / "testdata"
                / "parameter_weight_bias_1.gguf"
            )
        )
        self._assert_mm_result(self._run_mm_module(index))

    def testSafetensors(self):
        """Run the module with parameters loaded from the safetensors test file."""
        index = rt.ParameterIndex()
        index.load(
            str(
                Path(__file__).resolve().parent
                / "testdata"
                / "parameter_weight_bias_1.safetensors"
            )
        )
        self._assert_mm_result(self._run_mm_module(index))

    def testSplatTooBig(self):
        """add_splat raises ValueError for a five-element float pattern."""
        splat_index = rt.ParameterIndex()
        with self.assertRaises(ValueError):
            splat_index.add_splat(
                "weight", array.array("f", [1.0, 2.0, 3.0, 4.0, 5.0]), 30 * 20 * 4
            )


# When executed directly as a script, enable DEBUG-level logging and hand
# control to the unittest runner.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
