blob: 06e3658d8f059d1f74590c50cee7ee06d01ef142 [file] [log] [blame]
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" vec_test_helpers
This module is for reusable helper functions used by the templates.
"""
import collections
import enum
import re
import numpy as np
class VecTemplateHelper:
    """Given op and sew provide template with necessary parameters.

    Helper class for providing params for use in templates. An instance wraps
    an RVV op code (e.g. "vadd", "vwaddu", "vmfeq") and an optional selected
    element width (SEW). It derives what the code templates need from them:
    C operand types, numpy operand types, mnemonics, valid SEWs/LMULs,
    immediates and random test inputs.
    """

    class OperandType(enum.Enum):
        """RISC-V V operand type options."""
        VECTOR = enum.auto()
        SCALAR = enum.auto()
        IMMEDIATE = enum.auto()
        FLOATSCALAR = enum.auto()

    class OperandWidth(enum.Enum):
        """RISC-V V operand width type option."""
        STANDARD = enum.auto()
        WIDENING = enum.auto()
        NARROWING = enum.auto()

    # Mnemonic suffix, keyed first by operand type and then by operand width.
    # Narrowing ops use the "w*" suffixes (double-width vs2 operand).
    mnemonic_suffix = {
        OperandType.VECTOR: {
            OperandWidth.STANDARD: "vv",
            OperandWidth.WIDENING: "vv",
            OperandWidth.NARROWING: "wv",
        },
        OperandType.SCALAR: {
            OperandWidth.STANDARD: "vx",
            OperandWidth.WIDENING: "vx",
            OperandWidth.NARROWING: "wx",
        },
        OperandType.IMMEDIATE: {
            OperandWidth.STANDARD: "vi",
            OperandWidth.WIDENING: "vi",
            OperandWidth.NARROWING: "wi",
        },
        OperandType.FLOATSCALAR: {
            OperandWidth.STANDARD: "vf",
            OperandWidth.WIDENING: "vf",
            OperandWidth.NARROWING: "wf",
        },
    }

    def __init__(self, op_code, sew=None):
        # sew is stored as given; validation only happens via the sew setter.
        self._op_code = op_code
        self._sew = sew
        # When True, integer ops use unsigned C/numpy types even if the op
        # code itself is not an unsigned ("...u") variant.
        self.force_unsigned = False
        self.signed_np_types = {
            8: np.int8,
            16: np.int16,
            32: np.int32}
        self.unsigned_np_types = {
            8: np.uint8,
            16: np.uint16,
            32: np.uint32}
        self.float_np_types = {
            32: np.float32}

    @property
    def op_code(self):
        """Return the op_code.

        Raises:
          ValueError: if no op code was set.
        """
        if self._op_code is None:
            raise ValueError("OP CODE was not set.")
        return self._op_code

    @op_code.setter
    def op_code(self, value):
        """Set the op_code."""
        self._op_code = value

    def is_floating(self):
        """Check if a particular op_code is a floating type.

        Mask-producing compare ops start with "vm", so their "f" sits one
        character later ("vmfeq" vs "vfadd").
        """
        if self.op_code[1] == 'm':
            return self.op_code[2] == 'f'
        return self.op_code[1] == 'f'

    @property
    def sew(self):
        """Return the selected element width.

        Raises:
          ValueError: if no SEW was set.
        """
        if self._sew is None:
            raise ValueError("SEW was not set.")
        return self._sew

    @sew.setter
    def sew(self, value):
        """Set the selected element width.

        Raises:
          ValueError: if value is not one of 8, 16, 32, or is not 32 for a
            floating point op code.
        """
        if self.is_floating() and value != 32:
            raise ValueError("Invalid SEW")
        if value not in (8, 16, 32):
            raise ValueError("Invalid SEW")
        self._sew = value

    def is_widening(self):
        """Check if a particular op_code is a widening type ("vw...")."""
        if self.is_floating():
            raise ValueError("Widening is not supported with this template for floating point values")
        return self.op_code[1] == 'w'

    def is_narrowing(self):
        """Check if a particular op_code is a narrowing type ("vn...")."""
        if self.is_floating():
            raise ValueError("Narrowing is not supported with this template for floating point values")
        return self.op_code[1] == 'n'

    def is_destination_mask_register(self):
        """Check if a particular op_code has a mask output (compare ops)."""
        if self.is_floating():
            comparison_ops = ('vmfeq', 'vmfle', 'vmflt', 'vmfne', 'vmfgt',
                              'vmfge')
        else:
            # vmslt is the signed counterpart of vmsltu and also writes a mask.
            comparison_ops = ('vmseq', 'vmsne', 'vmsltu', 'vmslt', 'vmsleu',
                              'vmsle', 'vmsgtu', 'vmsgt')
        return self.op_code in comparison_ops

    def is_unsigned(self):
        """Check if a particular op_code is a unsigned type (ends in "u")."""
        if self.is_floating():
            raise ValueError("Invalid unsigned type op code")
        return self.op_code[-1] == 'u'

    def get_sews(self):
        """Given an op_code return a list of valid element widths."""
        if self.is_floating():
            return [32]
        if self.is_widening() or self.is_narrowing():
            # The double-width operand must still fit in 32 bits.
            return [8, 16]
        return [8, 16, 32]

    def get_sew_sizes(self):
        """Return element widths as (dest, src2, src1, imm).

        Narrowing ops read a double-width src2; widening ops write a
        double-width destination. imm is not used for floating point op codes.
        """
        dest_sew = self.sew
        src2_sew = self.sew
        src1_sew = self.sew
        imm_sew = self.sew
        if not self.is_floating() and self.is_narrowing():
            src2_sew = self.sew * 2
        elif not self.is_floating() and self.is_widening():
            dest_sew = self.sew * 2
        return dest_sew, src2_sew, src1_sew, imm_sew

    def get_var_types(self):
        """Return C types for an op_code and element width.

        imm_type won't be used for floating point values.

        Returns:
          A VarTypes namedtuple (dest_type, src2_type, src1_type, imm_type) of
          C type names, e.g. ("uint16_t", "uint8_t", "uint8_t", "uint8_t") for
          "vwaddu" at SEW 8.
        """
        VarTypes = collections.namedtuple(
            "VarTypes",
            ('dest_type', 'src2_type', 'src1_type', 'imm_type'))
        if self.is_floating():
            # dest_type is defined as int32_t instead of float to use
            # "set_bit_in_dest_mask" (in softrvv/include/softrvv_internal.h),
            # which has bitwise operations in it.
            if self.is_destination_mask_register():
                return VarTypes("int32_t", "float", "float", "float")
            return VarTypes("float", "float", "float", "float")
        type_fmt = "%sint%d_t"
        sign_type = "u" if self.is_unsigned() or self.force_unsigned else ""
        dest_sew, src2_sew, src1_sew, imm_sew = self.get_sew_sizes()
        return VarTypes(type_fmt % (sign_type, dest_sew),
                        type_fmt % (sign_type, src2_sew),
                        type_fmt % (sign_type, src1_sew),
                        type_fmt % (sign_type, imm_sew))

    def get_mnemonic(self, operand_type):
        """Generate the correct mnemonic given a opcode and operand type.

        For floating op codes a SCALAR operand is mapped to FLOATSCALAR (".vf").
        """
        operand_width = self.OperandWidth.STANDARD
        if self.is_floating():
            if operand_type == self.OperandType.SCALAR:
                operand_type = self.OperandType.FLOATSCALAR
        else:
            if self.is_narrowing():
                operand_width = self.OperandWidth.NARROWING
            elif self.is_widening():
                operand_width = self.OperandWidth.WIDENING
        op_suffix = self.mnemonic_suffix[operand_type][operand_width]
        return "%s.%s" % (self.op_code, op_suffix)

    def get_lmuls(self):
        """Given an op_code return an iterable of valid lmuls."""
        if not self.is_floating():
            if self.is_widening() or self.is_narrowing():
                # The double-width operand needs 2*LMUL registers (max 8).
                return [1, 2, 4]
        return [1, 2, 4, 8]

    @staticmethod
    def get_sew_from_dtype(dtype):
        """Extract the selected element width from a data type name.

        E.g. "int16" -> 16, "uint8" -> 8, "float32" -> 32.
        """
        match = re.match(r"[a-z]+(?P<sew>[\d]+)", dtype)
        return int(match['sew'])

    def get_softrvv_template_data_type(self):
        """Return the type argument(s) for the softrvv reference template.

        Widening/narrowing ops need "dest_type, src2_type"; all other ops use
        the src1 type alone.
        """
        var_types = self.get_var_types()
        if not self.is_floating():
            if self.is_narrowing() or self.is_widening():
                return "%s, %s" % (var_types.dest_type, var_types.src2_type)
        return var_types.src1_type

    def get_ref_opcode(self):
        """Return the name of the reference code in the softrvv library.

        Unsigned integer ops share the reference of their signed counterpart
        (the trailing "u" is dropped).
        """
        if self.is_floating():
            return self.op_code
        return self.op_code[:-1] if self.is_unsigned() else self.op_code

    def get_imms(self):
        """Return a list of valid immediate values for a op code."""
        if self.is_floating():
            raise ValueError("imm is not available.")
        if self.op_code in ['vsll', 'vsrl', 'vsra', 'vnsrl', 'vnsra']:
            # Left and right shift immediates must be [0,31]
            return np.linspace(0, 31, 8, dtype=np.int32)
        # Immediate values must be [-16, 15]
        return np.linspace(-16, 15, 7, dtype=np.int32)

    def _int_np_types(self):
        """Return the signed or unsigned SEW->numpy type map in effect."""
        if self.force_unsigned:
            return self.unsigned_np_types
        return self.signed_np_types

    def get_np_dest_type(self):
        """Return numpy type for destination."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        types = self._int_np_types()
        if self.is_widening():
            return types[self.sew * 2]
        return types[self.sew]

    def get_np_src1_type(self):
        """Return numpy type for src1."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        return self._int_np_types()[self.sew]

    def get_np_src2_type(self):
        """Return numpy type for src2."""
        if self.is_floating():
            return self.float_np_types[self.sew]
        types = self._int_np_types()
        if self.is_narrowing():
            return types[self.sew * 2]
        return types[self.sew]

    def get_test_inputs(self, n=5, allow_zero=True):  # pylint: disable=invalid-name
        """Return random test inputs.

        Args:
          n: number of elements in each returned vector.
          allow_zero: if False, zero elements of src1/src2 and a zero rs1 are
            replaced by one (e.g. so division ops never divide by zero).

        Returns:
          (src2_data, src1_data, rs1): two numpy arrays of length n using the
          src2/src1 numpy types, and a scalar of the src1 numpy type.
        """
        src1_np_type = self.get_np_src1_type()
        src2_np_type = self.get_np_src2_type()
        # The draw order (src1 vector, rs1, src2 vector) is kept stable so
        # seeded runs reproduce the same inputs.
        if self.is_floating():
            src1_info = np.finfo(src1_np_type)
            src1_data = np.random.uniform(src1_info.min, src1_info.max,
                                          n).astype(src1_np_type)
            rs1 = src1_np_type(np.random.uniform(src1_info.min, src1_info.max))
            src2_info = np.finfo(src2_np_type)
            src2_data = np.random.uniform(src2_info.min, src2_info.max,
                                          n).astype(src2_np_type)
        else:
            # randint's upper bound is exclusive, so type max is never drawn.
            src1_info = np.iinfo(src1_np_type)
            src1_data = np.random.randint(
                src1_info.min, src1_info.max, n).astype(src1_np_type)
            rs1 = src1_np_type(np.random.randint(src1_info.min, src1_info.max))
            src2_info = np.iinfo(src2_np_type)
            src2_data = np.random.randint(
                src2_info.min, src2_info.max, n).astype(src2_np_type)
        if not allow_zero:
            # In RVV ".vv" division vs1 is the divisor and in ".vx" it is rs1;
            # src2 is also cleared so either operand order is safe.
            src1_data[src1_data == 0] = 1
            src2_data[src2_data == 0] = 1
            if rs1 == 0:
                # Keep rs1's numpy scalar type instead of a bare Python 1.
                rs1 = src1_np_type(1)
        return src2_data, src1_data, rs1

    def pack_dest_mask(self, values):
        """Pack values into a single destination register.

        Each element of values becomes one bit (LSB first) of the result.
        Returns a numpy uint8 array as produced by np.packbits.
        """
        if self.is_floating():
            # Floating compare ops write an integer mask (int32_t in
            # get_var_types); np.packbits rejects float arrays, so the float
            # destination type cannot be used here.
            mask_type = np.int32
        else:
            mask_type = self.get_np_dest_type()
        return np.packbits(mask_type(values), bitorder='little')
def cast_to_unsigned(arr):
    """Cast a signed integer array to the unsigned type of the same width.

    This should not be called with floating point values.

    Args:
      arr: numpy array with a signed integer dtype.

    Returns:
      arr.astype(uintN), where N matches the input's element width, so values
      wrap modulo 2**N (e.g. int8 -1 -> uint8 255).

    Raises:
      TypeError: if arr's dtype is not a signed integer type.
    """
    # issubdtype also accepts platform aliases such as np.longlong, whose
    # scalar type is not identical to np.int64 on every platform.
    if not np.issubdtype(arr.dtype, np.signedinteger):
        raise TypeError(
            "cast_to_unsigned expects a signed integer array, got dtype %s"
            % arr.dtype)
    return arr.astype(np.dtype("u%d" % arr.dtype.itemsize))
def to_carr_str(arr):
    """Render the elements of arr as a comma separated C initializer list.

    Each element is formatted with "%s" and the pieces are joined by ", ",
    e.g. [1, 2, 3] -> "1, 2, 3". An empty input yields "".
    """
    pieces = ["%s" % element for element in arr]
    return ", ".join(pieces)