| """ vec_test_helpers |
| |
| This module is for reusable helper functions used by the templates. |
| """ |
| import collections |
| import enum |
| import re |
| import numpy as np |
| |
class VecTemplateHelper:
    """Given op and sew provide template with necessary parameters.

    Helper class for providing params for use in templates: mnemonics, C
    type names, numpy types, valid SEW/LMUL values and random test inputs,
    all derived from an op code string (e.g. "vadd", "vnsrl", "vmfeq") and a
    selected element width (SEW) in bits.
    """
    class OperandType(enum.Enum):
        """RISC-V V operand type options."""
        VECTOR = enum.auto()
        SCALAR = enum.auto()
        IMMEDIATE = enum.auto()
        FLOATSCALAR = enum.auto()

    class OperandWidth(enum.Enum):
        """RISC-V V operand width type option."""
        STANDARD = enum.auto()
        WIDENING = enum.auto()
        NARROWING = enum.auto()

    # Mnemonic suffix indexed by [operand type][operand width]. Narrowing ops
    # read a double-width src2 (see get_sew_sizes), hence the "w" prefix.
    mnemonic_suffix = {
        OperandType.VECTOR : {
            OperandWidth.STANDARD: "vv",
            OperandWidth.WIDENING: "vv",
            OperandWidth.NARROWING: "wv",
        },
        OperandType.SCALAR : {
            OperandWidth.STANDARD: "vx",
            OperandWidth.WIDENING: "vx",
            OperandWidth.NARROWING: "wx",
        },
        OperandType.IMMEDIATE : {
            OperandWidth.STANDARD: "vi",
            OperandWidth.WIDENING: "vi",
            OperandWidth.NARROWING: "wi",
        },
        OperandType.FLOATSCALAR : {
            OperandWidth.STANDARD: "vf",
            OperandWidth.WIDENING: "vf",
            OperandWidth.NARROWING: "wf",
        },
    }

    def __init__(self, op_code, sew=None):
        """Store the op code and SEW and build the numpy type tables.

        Args:
          op_code: op code name, e.g. "vadd" or "vmfeq".
          sew: selected element width in bits, or None to set it later.
            Note this bypasses the validation done by the `sew` setter.
        """
        self._op_code = op_code
        self._sew = sew
        # When True, integer C/numpy types are unsigned regardless of op code.
        self.force_unsigned = False
        self.signed_np_types = {
            8:np.int8,
            16:np.int16,
            32:np.int32}
        self.unsigned_np_types = {
            8:np.uint8,
            16:np.uint16,
            32:np.uint32}
        self.float_np_types = {
            32:np.float32}

    @property
    def op_code(self):
        """Return the op_code.

        Raises:
          ValueError: if the op_code was never set.
        """
        if self._op_code is None:
            raise ValueError("OP CODE was not set.")
        return self._op_code

    @op_code.setter
    def op_code(self, value):
        """Set the op_code"""
        self._op_code = value

    def is_floating(self):
        """Check if a particular op_code is a floating type.

        Floating op codes start with "vf", or "vmf" for mask-producing ops.
        """
        if self.op_code[1] == 'm':
            return self.op_code[2] == 'f'
        return self.op_code[1] == 'f'

    @property
    def sew(self):
        """Return the selected element width.

        Raises:
          ValueError: if SEW was never set.
        """
        if self._sew is None:
            raise ValueError("SEW was not set.")
        return self._sew

    @sew.setter
    def sew(self, value):
        """Set the selected element width.

        Raises:
          ValueError: if value is not one of 8, 16, 32, or is not 32 for a
            floating point op code.
        """
        if self.is_floating() and value != 32:
            raise ValueError("Invalid SEW")
        if value not in (8, 16, 32):
            raise ValueError("Invalid SEW")
        self._sew = value

    def is_widening(self):
        """Check if a particular op_code is a widening type ("vw..." prefix).

        Raises:
          ValueError: for floating point op codes.
        """
        if self.is_floating():
            raise ValueError("Widening is not supported with this template for floating point values")
        return self.op_code[1] == 'w'

    def is_narrowing(self):
        """Check if a particular op_code is a narrowing type ("vn..." prefix).

        Raises:
          ValueError: for floating point op codes.
        """
        if self.is_floating():
            raise ValueError("Narrowing is not supported with this template for floating point values")
        return self.op_code[1] == 'n'

    def is_destination_mask_register(self):
        """Check if a particular op_code has a mask output (a compare op)."""
        if self.is_floating():
            comparison_ops = ('vmfeq', 'vmfle', 'vmflt', 'vmfne', 'vmfgt', 'vmfge')
        else:
            comparison_ops = ('vmseq', 'vmsne', 'vmsltu', 'vmsleu', 'vmsle', 'vmsgtu', 'vmsgt')
        return self.op_code in comparison_ops

    def is_unsigned(self):
        """Check if a particular op_code is a unsigned type (trailing "u").

        Raises:
          ValueError: for floating point op codes.
        """
        if self.is_floating():
            raise ValueError("Invalid unsigned type op code")
        return self.op_code[-1] == 'u'

    def get_sews(self):
        """Given an op_code return a list of valid element widths."""
        if self.is_floating():
            return [32]
        # Widening/narrowing ops need a 2*SEW operand, capped at 32 bits.
        if self.is_widening() or self.is_narrowing():
            return [8, 16]
        return [8, 16, 32]

    def get_sew_sizes(self):
        """Return size of types.

        Returns:
          (dest_sew, src2_sew, src1_sew, imm_sew) in bits. Narrowing ops get
          a double-width src2, widening ops a double-width destination.
          imm is not used for floating point op codes.
        """
        dest_sew = self.sew
        src2_sew = self.sew
        src1_sew = self.sew
        imm_sew = self.sew
        if not self.is_floating() and self.is_narrowing():
            src2_sew = self.sew * 2
        elif not self.is_floating() and self.is_widening():
            dest_sew = self.sew * 2
        return dest_sew, src2_sew, src1_sew, imm_sew

    def get_var_types(self):
        """Return C types for an op_code and element width.

        Returns:
          VarTypes namedtuple (dest_type, src2_type, src1_type, imm_type).
          imm_type won't be used for floating point values.
        """
        VarTypes = collections.namedtuple(
            "VarTypes",
            ('dest_type', 'src2_type', 'src1_type', 'imm_type'))
        if self.is_floating():
            # dest_type is defined as int32_t instead of float to use
            # "set_bit_in_dest_mask" (in softrvv/include/softrvv_internal.h),
            # which has bitwise operations in it.
            if self.is_destination_mask_register():
                var_types = VarTypes("int32_t", "float", "float", "float")
            else:
                var_types = VarTypes("float", "float", "float", "float")
        else:
            type_fmt = "%sint%d_t"
            sign_type = "u" if self.is_unsigned() or self.force_unsigned else ""
            dest_sew, src2_sew, src1_sew, imm_sew = self.get_sew_sizes()
            dest_type = type_fmt % (sign_type, dest_sew)
            src1_type = type_fmt % (sign_type, src1_sew)
            src2_type = type_fmt % (sign_type, src2_sew)
            imm_type = type_fmt % (sign_type, imm_sew)
            var_types = VarTypes(dest_type, src2_type, src1_type, imm_type)
        return var_types

    def get_mnemonic(self, operand_type):
        """Generate the correct mnemonic given a opcode and operand type.

        Args:
          operand_type: an OperandType. SCALAR is mapped to FLOATSCALAR for
            floating point op codes.

        Returns:
          e.g. "vadd.vv", "vnsrl.wx", "vfadd.vf".
        """
        operand_width = self.OperandWidth.STANDARD
        if self.is_floating():
            if operand_type == self.OperandType.SCALAR:
                operand_type = self.OperandType.FLOATSCALAR
        else:
            if self.is_narrowing():
                operand_width = self.OperandWidth.NARROWING
            elif self.is_widening():
                operand_width = self.OperandWidth.WIDENING
        op_suffix = self.mnemonic_suffix[operand_type][operand_width]
        return "%s.%s" % (self.op_code, op_suffix)

    def get_lmuls(self):
        """Given an op_code return an iterable of valid lmuls."""
        if not self.is_floating():
            if self.is_widening() or self.is_narrowing():
                return [1, 2, 4]
        return [1, 2, 4, 8]

    @staticmethod
    def get_sew_from_dtype(dtype):
        """Extract the selected element width from a data type name.

        Args:
          dtype: a name such as "int8", "uint16" or "float32".

        Raises:
          ValueError: if dtype does not look like letters followed by digits.
        """
        match = re.match(r"[a-z]+(?P<sew>[\d]+)", dtype)
        if match is None:
            raise ValueError("Cannot extract SEW from dtype %r" % (dtype,))
        return int(match['sew'])

    def get_softrvv_template_data_type(self):
        """Return the template argument type(s) for the softrvv reference.

        Widening/narrowing ops take "dest_type, src2_type"; other ops take
        the single src1 type.
        """
        var_types = self.get_var_types()
        if not self.is_floating():
            if self.is_narrowing() or self.is_widening():
                return "%s, %s" % (var_types.dest_type, var_types.src2_type)
        return var_types.src1_type

    def get_ref_opcode(self):
        """Return the name of the reference code in the softrvv library.

        Integer unsigned op codes drop their trailing "u".
        """
        if self.is_floating():
            return self.op_code
        return self.op_code[:-1] if self.is_unsigned() else self.op_code

    def get_imms(self):
        """Return a list of valid immediate values for a op code.

        Raises:
          ValueError: for floating point op codes.
        """
        if self.is_floating():
            raise ValueError("imm is not available.")
        if self.op_code in ['vsll', 'vsrl', 'vsra', 'vnsrl', 'vnsra']:
            # Left and right shift immediates must be [0,31]
            return np.linspace(0, 31, 8, dtype=np.int32)
        # Immediate values must be [-16, 15]
        return np.linspace(-16, 15, 7, dtype=np.int32)

    def get_np_dest_type(self):
        """Return numpy type for destination (2*SEW when widening)."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        if self.force_unsigned:
            types = self.unsigned_np_types
        else:
            types = self.signed_np_types
        if self.is_widening():
            return types[self.sew * 2]
        return types[self.sew]

    def get_np_src1_type(self):
        """Return numpy type for src1."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        if self.force_unsigned:
            types = self.unsigned_np_types
        else:
            types = self.signed_np_types
        return types[self.sew]

    def get_np_src2_type(self):
        """Return numpy type for src2 (2*SEW when narrowing)."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        if self.force_unsigned:
            types = self.unsigned_np_types
        else:
            types = self.signed_np_types
        if self.is_narrowing():
            return types[self.sew * 2]
        return types[self.sew]

    def get_test_inputs(self, n=5, allow_zero=True): # pylint: disable=invalid-name
        """Return random test inputs spanning each operand type's full range.

        Args:
          n: number of elements in each generated source vector.
          allow_zero: if False, zeros in src2_data and a zero rs1 are
            replaced by one.

        Returns:
          (src2_data, src1_data, rs1): arrays of length n with the src2 and
          src1 numpy types, and a scalar of the src1 numpy type.
        """
        src1_np_type = self.get_np_src1_type()
        src2_np_type = self.get_np_src2_type()
        # Return test inputs for floating points.
        if self.is_floating():
            type_info = np.finfo(src1_np_type)
            src1_data = np.random.uniform(
                type_info.min, type_info.max, n).astype(src1_np_type)
            rs1 = src1_np_type(np.random.uniform(
                type_info.min, type_info.max))
            type_info = np.finfo(src2_np_type)
            src2_data = np.random.uniform(
                type_info.min, type_info.max, n).astype(src2_np_type)
            if not allow_zero:
                src2_data[src2_data==0] = 1
                # Keep rs1 typed as the src1 numpy type.
                rs1 = src1_np_type(1.0) if rs1 == 0.0 else rs1
            return src2_data, src1_data, rs1

        # Return test inputs for integers. randint's upper bound is
        # exclusive, so pass max + 1 (with an explicit dtype so the bound is
        # representable) to make each type's maximum value reachable.
        type_info = np.iinfo(src1_np_type)
        src1_data = np.random.randint(
            type_info.min, type_info.max + 1, n, dtype=src1_np_type)
        rs1 = src1_np_type(np.random.randint(
            type_info.min, type_info.max + 1, dtype=src1_np_type))
        type_info = np.iinfo(src2_np_type)
        src2_data = np.random.randint(
            type_info.min, type_info.max + 1, n, dtype=src2_np_type)
        if not allow_zero:
            src2_data[src2_data==0] = 1
            # Keep rs1 typed as the src1 numpy type.
            rs1 = src1_np_type(1) if rs1 == 0 else rs1
        return src2_data, src1_data, rs1

    def pack_dest_mask(self, values):
        """Pack values into a single destination register.

        Each element becomes one bit (non-zero -> 1), least significant bit
        first, returned as a uint8 array.
        """
        dest_type = self.get_np_dest_type()
        if self.is_floating():
            # np.packbits rejects float arrays; float compare ops write an
            # int32_t mask destination (see get_var_types), so pack as int32.
            dest_type = np.int32
        return np.packbits(dest_type(values), bitorder='little')
| |
def cast_to_unsigned(arr):
    """Cast a signed array to an unsigned array of the same element width.

    The cast wraps negative values (e.g. int8 -1 becomes uint8 255).
    This should not be called with floating point values.

    Args:
      arr: numpy array with a signed integer dtype.

    Returns:
      A new array with the unsigned integer dtype of the same itemsize.

    Raises:
      TypeError: if arr's dtype is not a signed integer type.
    """
    # Match on dtype kind/itemsize rather than exact scalar classes so that
    # aliases such as np.longlong or np.intc are accepted too.
    if arr.dtype.kind != "i":
        raise TypeError(
            "cast_to_unsigned expects a signed integer array, got %s" % arr.dtype)
    return arr.astype(np.dtype("u%d" % arr.dtype.itemsize))
| |
def to_carr_str(arr):
    """Join the "%s"-formatted elements of arr with ", " separators."""
    formatted = ["%s" % element for element in arr]
    return ", ".join(formatted)