""" vec_test_helpers

This module provides reusable helpers (a template-parameter class and small utility functions) used by the templates.
"""
import collections
import enum
import re
import numpy as np

class VecTemplateHelper:
    """Helper class providing op-code-dependent parameters for templates.

    Given a RISC-V V op code (e.g. ``"vsll"``, ``"vmseq"``) and an optional
    selected element width (SEW), this class derives the C type names, numpy
    types, mnemonics, valid SEW/LMUL values, immediates and random test inputs
    that the templates use.
    """
    class OperandType(enum.Enum):
        """RISC-V V operand type options."""
        VECTOR = enum.auto()
        SCALAR = enum.auto()
        IMMEDIATE = enum.auto()

    class OperandWidth(enum.Enum):
        """RISC-V V operand width type option."""
        STANDARD = enum.auto()
        WIDENING = enum.auto()
        NARROWING = enum.auto()

    # Mnemonic suffix, indexed as mnemonic_suffix[operand_type][operand_width].
    # Narrowing ops read a double-width vs2 operand, hence the "w*" suffixes.
    mnemonic_suffix = {
        OperandType.VECTOR : {
            OperandWidth.STANDARD: "vv",
            OperandWidth.WIDENING: "vv",
            OperandWidth.NARROWING: "wv",
        },
        OperandType.SCALAR : {
            OperandWidth.STANDARD: "vx",
            OperandWidth.WIDENING: "vx",
            OperandWidth.NARROWING: "wx",
        },
        OperandType.IMMEDIATE : {
            OperandWidth.STANDARD: "vi",
            OperandWidth.WIDENING: "vi",
            OperandWidth.NARROWING: "wi",
        },
    }

    def __init__(self, op_code, sew=None):
        """Create a helper for ``op_code``.

        Args:
            op_code: Op-code name without the operand suffix, e.g. ``"vsll"``.
            sew: Optional selected element width; may be set later via ``sew``.
        """
        self._op_code = op_code
        self._sew = sew
        # When True, both the C types (get_var_types) and the numpy types
        # (get_np_*_type) are unsigned.
        self.force_unsigned = False
        self.signed_np_types = {
            8:np.int8,
            16:np.int16,
            32:np.int32}
        self.unsigned_np_types = {
            8:np.uint8,
            16:np.uint16,
            32:np.uint32}

    @property
    def op_code(self):
        """Return the op_code.

        Raises:
            ValueError: if no op_code was set.
        """
        if self._op_code is None:
            raise ValueError("op_code was not set.")
        return self._op_code

    @op_code.setter
    def op_code(self, value):
        """Set the op_code."""
        self._op_code = value

    @property
    def sew(self):
        """Return the selected element width.

        Raises:
            ValueError: if no SEW was set.
        """
        if self._sew is None:
            raise ValueError("SEW was not set.")
        return self._sew

    @sew.setter
    def sew(self, value):
        """Set the selected element width (must be 8, 16 or 32)."""
        if value not in (8, 16, 32):
            raise ValueError("Invalid SEW")
        self._sew = value

    def is_widening(self):
        """Check if the op_code is a widening op (second letter 'w')."""
        return self.op_code[1] == 'w'

    def is_narrowing(self):
        """Check if the op_code is a narrowing op (second letter 'n')."""
        return self.op_code[1] == 'n'

    def is_destination_mask_register(self):
        """Check if the op_code is an integer compare that writes a mask."""
        int_comparison_ops = ('vmseq', 'vmsne', 'vmsltu', 'vmslt', 'vmsleu',
                              'vmsle', 'vmsgtu', 'vmsgt')
        return self.op_code in int_comparison_ops

    def is_unsigned(self):
        """Check if the op_code is an unsigned op (trailing 'u')."""
        return self.op_code[-1] == 'u'

    def get_sews(self):
        """Return the valid element widths for the op_code.

        Widening/narrowing ops exclude 32 because one operand is 2*SEW wide.
        """
        if self.is_widening() or self.is_narrowing():
            return [8, 16]
        return [8, 16, 32]

    def get_sew_sizes(self):
        """Return the bit widths ``(dest, src2, src1, imm)`` for the current SEW.

        Narrowing ops have a 2*SEW src2; widening ops have a 2*SEW destination.
        """
        dest_sew = self.sew
        src2_sew = self.sew
        src1_sew = self.sew
        imm_sew = self.sew
        if self.is_narrowing():
            src2_sew = self.sew * 2
        elif self.is_widening():
            dest_sew = self.sew * 2
        return dest_sew, src2_sew, src1_sew, imm_sew

    def get_var_types(self):
        """Return the C type names for the op_code and element width.

        Returns:
            A ``VarTypes`` namedtuple ``(dest_type, src2_type, src1_type,
            imm_type)`` of strings such as ``"uint16_t"``. Types are unsigned
            when the op_code is unsigned or ``force_unsigned`` is set.
        """
        VarTypes = collections.namedtuple(
            "VarTypes",
            ('dest_type', 'src2_type', 'src1_type', 'imm_type'))
        type_fmt = "%sint%d_t"
        sign_type = "u" if self.is_unsigned() or self.force_unsigned else ""
        dest_sew, src2_sew, src1_sew, imm_sew = self.get_sew_sizes()
        dest_type = type_fmt % (sign_type, dest_sew)
        src1_type = type_fmt % (sign_type, src1_sew)
        src2_type = type_fmt % (sign_type, src2_sew)
        imm_type = type_fmt % (sign_type, imm_sew)
        return VarTypes(dest_type, src2_type, src1_type, imm_type)

    def get_mnemonic(self, operand_type):
        """Return the full mnemonic, e.g. ``"vadd.vv"``, for ``operand_type``.

        Args:
            operand_type: An ``OperandType`` member.
        """
        operand_width = self.OperandWidth.STANDARD
        if self.is_narrowing():
            operand_width = self.OperandWidth.NARROWING
        elif self.is_widening():
            operand_width = self.OperandWidth.WIDENING
        op_suffix = self.mnemonic_suffix[operand_type][operand_width]
        return "%s.%s" % (self.op_code, op_suffix)

    def get_lmuls(self):
        """Return the valid LMUL values for the op_code."""
        if self.is_widening() or self.is_narrowing():
            return [1, 2, 4]
        return [1, 2, 4, 8]

    @staticmethod
    def get_sew_from_dtype(dtype):
        """Extract the element width from a type name such as ``"uint16_t"``."""
        match = re.match(r"[a-z]+(?P<sew>[\d]+)", dtype)
        return int(match['sew'])

    def get_softrvv_template_data_type(self):
        """Return the template argument list for the softrvv reference call.

        Widening/narrowing ops take ``"<dest_type>, <src2_type>"``; all other
        ops take the single src1 type.
        """
        var_types = self.get_var_types()
        if self.is_narrowing() or self.is_widening():
            return "%s, %s" % (var_types.dest_type, var_types.src2_type)
        return var_types.src1_type

    def get_ref_opcode(self):
        """Return the name of the reference code in the softrvv library.

        Unsigned op codes drop their trailing 'u' (e.g. ``vmsltu`` -> ``vmslt``).
        """
        return self.op_code[:-1] if self.is_unsigned() else self.op_code

    def get_imms(self):
        """Return an array of valid immediate values for the op_code."""
        if self.op_code in ['vsll', 'vsrl', 'vsra', 'vnsrl', 'vnsra']:
            # Left and right shift immediates must be [0,31]
            return np.linspace(0, 31, 8, dtype=np.int32)
        # Immediate values must be [-16, 15]
        return np.linspace(-16, 15, 7, dtype=np.int32)

    def _np_types(self):
        """Return the SEW -> numpy type table selected by ``force_unsigned``."""
        # NOTE(review): numpy types ignore is_unsigned(); callers presumably
        # convert with cast_to_unsigned when needed — confirm with templates.
        if self.force_unsigned:
            return self.unsigned_np_types
        return self.signed_np_types

    def get_np_dest_type(self):
        """Return numpy type for destination (2*SEW for widening ops)."""
        dest_sew = self.get_sew_sizes()[0]
        return self._np_types()[dest_sew]

    def get_np_src1_type(self):
        """Return numpy type for src1."""
        src1_sew = self.get_sew_sizes()[2]
        return self._np_types()[src1_sew]

    def get_np_src2_type(self):
        """Return numpy type for src2 (2*SEW for narrowing ops)."""
        src2_sew = self.get_sew_sizes()[1]
        return self._np_types()[src2_sew]

    @staticmethod
    def _random_values(np_type, size):
        """Return ``size`` uniform random values spanning np_type's full range."""
        info = np.iinfo(np_type)
        # randint's upper bound is exclusive, so add 1 to make info.max
        # reachable. Draw as int64 so the uint32 range also fits on platforms
        # whose default integer is 32 bits.
        return np.random.randint(
            info.min, info.max + 1, size, dtype=np.int64).astype(np_type)

    def get_test_inputs(self, n=5, allow_zero=True): # pylint: disable=invalid-name
        """Return random test inputs ``(src2_data, src1_data, rs1)``.

        Args:
            n: Number of elements in each vector operand.
            allow_zero: If False, zeros in both vector operands and in rs1
                are replaced by 1.

        Returns:
            src2_data and src1_data are numpy arrays of the src2/src1 numpy
            types; rs1 is a scalar of the src1 numpy type.
        """
        src1_np_type = self.get_np_src1_type()
        src2_np_type = self.get_np_src2_type()
        src1_data = self._random_values(src1_np_type, n)
        rs1 = self._random_values(src1_np_type, 1)[0]
        src2_data = self._random_values(src2_np_type, n)
        if not allow_zero:
            src2_data[src2_data == 0] = 1
            # src1/rs1 supply the vs1/scalar operand, so they must also be
            # non-zero. Keep rs1's numpy type when replacing it.
            src1_data[src1_data == 0] = 1
            if rs1 == 0:
                rs1 = src1_np_type(1)
        return src2_data, src1_data, rs1

    def pack_dest_mask(self, values):
        """Pack per-element truth values into mask bytes (LSB first)."""
        dest_type = self.get_np_dest_type()
        return np.packbits(dest_type(values), bitorder='little')

def cast_to_unsigned(arr):
    """Reinterpret a signed integer array as the same-width unsigned type.

    Negative values wrap modulo 2**bits (e.g. int8 -1 -> uint8 255).

    Args:
        arr: numpy array with a signed integer dtype.

    Returns:
        A new array with the unsigned dtype of the same item size.

    Raises:
        TypeError: if ``arr`` does not have a signed integer dtype.
    """
    # Use the dtype kind/itemsize rather than a table of scalar types so that
    # equivalent-but-distinct types (e.g. np.longlong vs np.int64) also work.
    if arr.dtype.kind != 'i':
        raise TypeError(
            "cast_to_unsigned expects a signed integer array, got %s"
            % arr.dtype)
    return arr.astype(np.dtype("u%d" % arr.dtype.itemsize))

def to_carr_str(arr):
    """Join the items of ``arr`` into one string separated by ``", "``.

    Each item is rendered with ``"%s"`` formatting, e.g. ``[1, 2]`` becomes
    ``"1, 2"``; an empty iterable yields ``""``.
    """
    separator = ", "
    pieces = ["%s" % element for element in arr]
    return separator.join(pieces)
