# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" vec_test_helpers

This module is for reusable helper functions used by the templates.
"""
import collections
import enum
import re
import numpy as np


class VecTemplateHelper:
    """Helper class for providing params for use in templates.

    Given a RISC-V V op code (e.g. "vsll", "vnsrl", "vmsltu", "vmfeq") and a
    selected element width (SEW), provide the templates with the necessary
    parameters: mnemonics, C and numpy operand types, valid SEW/LMUL/immediate
    values and random test inputs.
    """

    class OperandType(enum.Enum):
        """RISC-V V operand type options."""
        VECTOR = enum.auto()
        SCALAR = enum.auto()
        IMMEDIATE = enum.auto()
        FLOATSCALAR = enum.auto()

    class OperandWidth(enum.Enum):
        """RISC-V V operand width type option."""
        STANDARD = enum.auto()
        WIDENING = enum.auto()
        NARROWING = enum.auto()

    # Mnemonic suffix for each (operand type, operand width) combination.
    # Narrowing ops read a double-width src2 (see get_sew_sizes), hence the
    # leading "w" in their suffixes.
    mnemonic_suffix = {
        OperandType.VECTOR: {
            OperandWidth.STANDARD: "vv",
            OperandWidth.WIDENING: "vv",
            OperandWidth.NARROWING: "wv",
        },
        OperandType.SCALAR: {
            OperandWidth.STANDARD: "vx",
            OperandWidth.WIDENING: "vx",
            OperandWidth.NARROWING: "wx",
        },
        OperandType.IMMEDIATE: {
            OperandWidth.STANDARD: "vi",
            OperandWidth.WIDENING: "vi",
            OperandWidth.NARROWING: "wi",
        },
        OperandType.FLOATSCALAR: {
            OperandWidth.STANDARD: "vf",
            OperandWidth.WIDENING: "vf",
            OperandWidth.NARROWING: "wf",
        },
    }

    def __init__(self, op_code, sew=None):
        # sew is stored as given here; only assignments through the ``sew``
        # property are validated.
        self._op_code = op_code
        self._sew = sew
        # When True, integer C/numpy types are reported as unsigned even if
        # the op code itself is not an unsigned op.
        self.force_unsigned = False
        self.signed_np_types = {8: np.int8, 16: np.int16, 32: np.int32}
        self.unsigned_np_types = {8: np.uint8, 16: np.uint16, 32: np.uint32}
        self.float_np_types = {32: np.float32}

    @property
    def op_code(self):
        """Return the op_code.

        Raises:
            ValueError: if the op_code was never set.
        """
        if self._op_code is None:
            raise ValueError("OP CODE was not set.")
        return self._op_code

    @op_code.setter
    def op_code(self, value):
        """Set the op_code"""
        self._op_code = value

    def is_floating(self):
        """Check if a particular op_code is a floating type.

        Floating point ops are those whose name continues with "f" after the
        leading character ("vf...") or, for mask ops, with "mf" ("vmf...").
        """
        # Slicing instead of indexing avoids IndexError on very short names.
        rest = self.op_code[1:]
        if rest.startswith('m'):
            return rest.startswith('mf')
        return rest.startswith('f')

    @property
    def sew(self):
        """Return the selected element width.

        Raises:
            ValueError: if the SEW was never set.
        """
        if self._sew is None:
            raise ValueError("SEW was not set.")
        return self._sew

    @sew.setter
    def sew(self, value):
        """Set the selected element width.

        Raises:
            ValueError: if value is not one of 8, 16, 32, or if the op code is
                floating point and value is not 32.
        """
        if self.is_floating() and value != 32:
            raise ValueError("Invalid SEW")
        if value not in (8, 16, 32):
            raise ValueError("Invalid SEW")
        self._sew = value

    def is_widening(self):
        """Check if a particular op_code is a widening type ("vw...").

        Raises:
            ValueError: for floating point op codes (not supported here).
        """
        if self.is_floating():
            raise ValueError(
                "Widening is not supported with this template for floating "
                "point values")
        return self.op_code[1] == 'w'

    def is_narrowing(self):
        """Check if a particular op_code is a narrowing type ("vn...").

        Raises:
            ValueError: for floating point op codes (not supported here).
        """
        if self.is_floating():
            raise ValueError(
                "Narrowing is not supported with this template for floating "
                "point values")
        return self.op_code[1] == 'n'

    def is_destination_mask_register(self):
        """Check if a particular op_code has a mask output (comparison ops)."""
        if self.is_floating():
            comparison_ops = ('vmfeq', 'vmfle', 'vmflt', 'vmfne', 'vmfgt',
                              'vmfge')
        else:
            # Signed and unsigned integer compares both write a mask.
            comparison_ops = ('vmseq', 'vmsne', 'vmslt', 'vmsltu', 'vmsle',
                              'vmsleu', 'vmsgt', 'vmsgtu')
        return self.op_code in comparison_ops

    def is_unsigned(self):
        """Check if a particular op_code is an unsigned type (ends in "u").

        Raises:
            ValueError: for floating point op codes.
        """
        if self.is_floating():
            raise ValueError("Invalid unsigned type op code")
        return self.op_code[-1] == 'u'

    def get_sews(self):
        """Given an op_code return a list of valid element widths."""
        if self.is_floating():
            return [32]
        # Widening/narrowing ops need room for a 2*SEW operand.
        if self.is_widening() or self.is_narrowing():
            return [8, 16]
        return [8, 16, 32]

    def get_sew_sizes(self):
        """Return ``(dest_sew, src2_sew, src1_sew, imm_sew)`` in bits.

        Narrowing ops read a 2*SEW src2; widening ops write a 2*SEW dest.
        imm is not used for floating point op codes.
        """
        dest_sew = src2_sew = src1_sew = imm_sew = self.sew
        if not self.is_floating():
            if self.is_narrowing():
                src2_sew = self.sew * 2
            elif self.is_widening():
                dest_sew = self.sew * 2
        return dest_sew, src2_sew, src1_sew, imm_sew

    def get_var_types(self):
        """Return C types for an op_code and element width.

        Returns:
            A namedtuple ``(dest_type, src2_type, src1_type, imm_type)`` of C
            type names. imm_type won't be used for floating point values.
        """
        VarTypes = collections.namedtuple(
            "VarTypes", ('dest_type', 'src2_type', 'src1_type', 'imm_type'))
        if self.is_floating():
            # dest_type is defined as int32_t instead of float
            # to use "set_bit_in_dest_mask"
            # (in softrvv/include/softrvv_internal.h), which has bitwise
            # operations in it.
            if self.is_destination_mask_register():
                return VarTypes("int32_t", "float", "float", "float")
            return VarTypes("float", "float", "float", "float")

        unsigned = self.is_unsigned() or self.force_unsigned
        sign_type = "u" if unsigned else ""
        dest_sew, src2_sew, src1_sew, imm_sew = self.get_sew_sizes()
        type_fmt = "%sint%d_t"
        return VarTypes(type_fmt % (sign_type, dest_sew),
                        type_fmt % (sign_type, src2_sew),
                        type_fmt % (sign_type, src1_sew),
                        type_fmt % (sign_type, imm_sew))

    def get_mnemonic(self, operand_type):
        """Generate the correct mnemonic given a opcode and operand type.

        For floating point op codes a SCALAR operand is mapped to the "f"
        (FLOATSCALAR) suffix.
        """
        if self.is_floating():
            operand_width = self.OperandWidth.STANDARD
            if operand_type == self.OperandType.SCALAR:
                operand_type = self.OperandType.FLOATSCALAR
        elif self.is_narrowing():
            operand_width = self.OperandWidth.NARROWING
        elif self.is_widening():
            operand_width = self.OperandWidth.WIDENING
        else:
            operand_width = self.OperandWidth.STANDARD

        op_suffix = self.mnemonic_suffix[operand_type][operand_width]
        return f"{self.op_code}.{op_suffix}"

    def get_lmuls(self):
        """Given an op_code return an iterable of valid lmuls."""
        if not self.is_floating():
            if self.is_widening() or self.is_narrowing():
                return [1, 2, 4]
        return [1, 2, 4, 8]

    @staticmethod
    def get_sew_from_dtype(dtype):
        """Extract the selected element width from a data type name.

        e.g. "int8" -> 8, "uint16" -> 16, "float32" -> 32.
        """
        match = re.match(r"[a-z]+(?P<sew>[\d]+)", dtype)
        return int(match['sew'])

    def get_softrvv_template_data_type(self):
        """Return the C type argument(s) for the softrvv reference template.

        Widening/narrowing integer ops get "dest_type, src2_type"; all other
        ops get the src1 type.
        """
        var_types = self.get_var_types()
        if not self.is_floating():
            if self.is_narrowing() or self.is_widening():
                return f"{var_types.dest_type}, {var_types.src2_type}"
        return var_types.src1_type

    def get_ref_opcode(self):
        """Return the name of the reference code in the softrvv library.

        Unsigned integer op codes drop their trailing "u" (e.g. "vmsltu" ->
        "vmslt").
        """
        if self.is_floating():
            return self.op_code
        return self.op_code[:-1] if self.is_unsigned() else self.op_code

    def get_imms(self):
        """Return a list of valid immediate values for a op code.

        Raises:
            ValueError: for floating point op codes.
        """
        if self.is_floating():
            raise ValueError("imm is not available.")
        if self.op_code in ['vsll', 'vsrl', 'vsra', 'vnsrl', 'vnsra']:
            # Left and right shift immediates must be [0,31]
            return np.linspace(0, 31, 8, dtype=np.int32)
        # Immediate values must be [-16, 15]
        return np.linspace(-16, 15, 7, dtype=np.int32)

    def _np_int_types(self):
        """Return the width -> numpy integer type map for this helper."""
        # NOTE(review): unlike get_var_types, is_unsigned() is not consulted
        # here; numpy types only become unsigned via force_unsigned.
        if self.force_unsigned:
            return self.unsigned_np_types
        return self.signed_np_types

    def get_np_dest_type(self):
        """Return numpy type for destination (2*SEW for widening ops)."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        width = self.sew * 2 if self.is_widening() else self.sew
        return self._np_int_types()[width]

    def get_np_src1_type(self):
        """Return numpy type for src1."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        return self._np_int_types()[self.sew]

    def get_np_src2_type(self):
        """Return numpy type for src2 (2*SEW for narrowing ops)."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        width = self.sew * 2 if self.is_narrowing() else self.sew
        return self._np_int_types()[width]

    def get_test_inputs(self, n=5, allow_zero=True):  # pylint: disable=invalid-name
        """Return random test inputs as ``(src2_data, src1_data, rs1)``.

        Args:
            n: number of elements in each vector operand.
            allow_zero: if False, zeros in src2_data and a zero rs1 are
                replaced by one.

        Returns:
            src2_data and src1_data are numpy arrays of length n with the
            src2/src1 numpy types; rs1 is a scalar of the src1 numpy type.
        """
        src1_np_type = self.get_np_src1_type()
        src2_np_type = self.get_np_src2_type()
        if self.is_floating():
            # Floating point inputs are drawn from the type's finite range.
            type_info = np.finfo(src1_np_type)
            src1_data = np.random.uniform(type_info.min, type_info.max,
                                          n).astype(src1_np_type)
            rs1 = src1_np_type(np.random.uniform(type_info.min,
                                                 type_info.max))
            type_info = np.finfo(src2_np_type)
            src2_data = np.random.uniform(type_info.min, type_info.max,
                                          n).astype(src2_np_type)
        else:
            # pylint: disable=redefined-variable-type
            # randint's upper bound is exclusive, so the type max is never
            # drawn.
            type_info = np.iinfo(src1_np_type)
            src1_data = np.random.randint(type_info.min, type_info.max,
                                          n).astype(src1_np_type)
            rs1 = src1_np_type(np.random.randint(type_info.min,
                                                 type_info.max))
            type_info = np.iinfo(src2_np_type)
            src2_data = np.random.randint(type_info.min, type_info.max,
                                          n).astype(src2_np_type)

        if not allow_zero:
            # NOTE(review): only src2_data and rs1 are made non-zero; confirm
            # whether src1_data (vs1) should be too for .vv ops.
            src2_data[src2_data == 0] = 1
            if rs1 == 0:
                # Keep rs1's numpy type instead of a plain Python number.
                rs1 = src1_np_type(1)
        return src2_data, src1_data, rs1

    def pack_dest_mask(self, values):
        """Pack values into a single destination register.

        Each value becomes one bit (non-zero -> 1), least significant bit
        first, as returned by ``np.packbits(..., bitorder='little')``.
        """
        if self.is_floating():
            # np.packbits only accepts integer/boolean input, and the numpy
            # dest type of floating ops is float32, so reduce to truth values.
            bits = np.asarray(values, dtype=bool)
        else:
            bits = self.get_np_dest_type()(values)
        return np.packbits(bits, bitorder='little')


def cast_to_unsigned(arr):
    """Cast a signed integer array to the unsigned type of the same width.

    Values wrap modulo 2**bits, e.g. int8 -1 becomes uint8 255. Any signed
    integer dtype is accepted, including platform aliases such as
    ``np.longlong`` whose scalar type is not ``np.int64``. This should not be
    called with floating point values.

    Args:
        arr: a numpy array (or array-like) of signed integers.

    Returns:
        A new numpy array with the unsigned dtype of the same item size.

    Raises:
        TypeError: if arr does not have a signed integer dtype.
    """
    arr = np.asarray(arr)
    # Check the dtype kind rather than exact scalar types so that every
    # signed integer alias of a given width is handled.
    if arr.dtype.kind != 'i':
        raise TypeError(
            f"cast_to_unsigned expects a signed integer array, got {arr.dtype}")
    return arr.astype(np.dtype(f"u{arr.dtype.itemsize}"))


def to_carr_str(arr):
    """Join the elements of arr into a single comma separated string.

    Every element is rendered with ``format(element)`` and the pieces are
    separated by ", ".
    """
    # format(x) is exactly the conversion an f"{x}" expression performs.
    rendered = map(format, arr)
    return ", ".join(rendered)
