#!/usr/bin/env python3
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

import argparse
import math
import sys
from typing import List, Optional

# Width, in bits, of an OTBN wide data register (WDR).
WDR_SIZE = 256
# Width, in bits, of the bn.mulqacc multiplier; bignums are split into limbs
# of this size.
LIMB_SIZE = 64
# A WDR must hold a whole number of limbs for the limb indexing to work.
assert WDR_SIZE % LIMB_SIZE == 0
# Number of limbs that fit in a single WDR.
LIMBS_PER_WDR = WDR_SIZE // LIMB_SIZE


class Line:
    '''One bn.mulqacc instruction of a generated multiplication block.

    Attributes are plain and mutable so that the generator can attach
    modifiers (e.g. `so`, `wo`, `z`) and a writeback destination after the
    line has been created.
    '''
    def __init__(self, modifiers: List[str], dst: Optional[str], lhs: str,
                 rhs: str, shift: int):
        # Modifier suffixes, emitted in list order (each preceded by '.').
        self.modifiers = modifiers
        # Writeback destination, or None if the line has no writeback.
        self.dst = dst
        # Left- and right-hand operand strings.
        self.lhs = lhs
        self.rhs = rhs
        # Shift immediate; omitted from the output when it is 0.
        self.shift = shift

    def insn_string(self) -> str:
        '''Return the mnemonic (with modifiers) plus two trailing spaces.

        The two spaces are the minimum padding between the mnemonic and the
        first operand.
        '''
        suffix = ''.join('.' + mod for mod in self.modifiers)
        return 'bn.mulqacc' + suffix + '  '

    def dst_string(self) -> str:
        '''Return the writeback destination followed by ', '.

        Returns the empty string if the line has no destination.
        '''
        if self.dst is None:
            return ''
        return self.dst + ', '

    def min_lhs_align(self) -> int:
        '''Return the smallest column at which the left operand can start.'''
        return sum(len(part)
                   for part in (self.insn_string(), self.dst_string()))

    def pretty(self, lhs_align: Optional[int] = None) -> str:
        '''Render the line as assembly text.

        If `lhs_align` is given, spaces are inserted between the mnemonic and
        the destination so that the left-hand operand starts at that column.

        Raises ValueError if `lhs_align` is smaller than the space needed for
        the mnemonic and destination.
        '''
        mnemonic = self.insn_string()
        writeback = self.dst_string()
        needed = len(mnemonic) + len(writeback)

        if lhs_align is None:
            gap = ''
        elif lhs_align < needed:
            raise ValueError(
                f'Cannot align left-hand operand to the index {lhs_align}; '
                f'need {needed} characters for '
                f'{mnemonic + writeback}')
        else:
            gap = ' ' * (lhs_align - needed)

        operands = f'{self.lhs}, {self.rhs}'
        if self.shift != 0:
            operands += f', {self.shift}'
        return f'{mnemonic}{gap}{writeback}{operands}'


def make_operand(regs: List[str], limb_index: int) -> str:
    '''Return the `<reg>.<n>` operand string for one limb of a bignum.

    `regs` holds the bignum's registers and `limb_index` counts limbs from
    the start of that list, with LIMBS_PER_WDR limbs per register.
    '''
    reg_pos, limb_in_reg = divmod(limb_index, LIMBS_PER_WDR)
    return f'{regs[reg_pos]}.{limb_in_reg}'


def make_half_word_writeback(regs: List[str], limb_index: int) -> str:
    '''Return the half-word destination (`<reg>.L` or `<reg>.U`) that
    contains the given limb of the result.'''
    reg_pos, limb_in_reg = divmod(limb_index, LIMBS_PER_WDR)
    # Limbs in the first half of the register belong to the lower half-word.
    half = 'L' if limb_in_reg < LIMBS_PER_WDR / 2 else 'U'
    return f'{regs[reg_pos]}.{half}'


def mask(n: int) -> int:
    '''Return the integer whose low `n` bits are all set (2^n - 1).'''
    one_past_top = 1 << n
    return one_past_top - 1


def _limb_max(limb_idx: int, nlimbs: int, nbits: int) -> int:
    '''Return the largest value limb `limb_idx` of an `nbits`-bit number can
    take (the top limb may be narrower than LIMB_SIZE bits).'''
    top_limb_bits = nbits % LIMB_SIZE
    if limb_idx == nlimbs - 1 and top_limb_bits != 0:
        return mask(top_limb_bits)
    return mask(LIMB_SIZE)


def make_block(lhs_bits: int, lhs_regs: List[str], rhs_bits: int,
               rhs_regs: List[str], result_regs: List[str]) -> List[Line]:
    '''Build the bn.mulqacc block for the bignum product.

    Requires that:
      - The `lhs_regs` list has enough registers to store `lhs_bits` bits
      - The `rhs_regs` list has enough registers to store `rhs_bits` bits
      - The `result_regs` list has enough registers to store the result

    Returns the block as a list of `Line` objects, one per instruction.

    Raises ValueError if either operand has fewer than 1 bit or if the
    accumulator could overflow, and RuntimeError if an internal consistency
    check fails.
    '''
    # With no bits there are no partial products and nothing to emit.
    if lhs_bits <= 0 or rhs_bits <= 0:
        raise ValueError('Both operands must have at least 1 bit.')

    half_word_bits = WDR_SIZE // 2

    # Determine the number of limbs for bn.mulqacc.
    lhs_nlimbs = math.ceil(lhs_bits / LIMB_SIZE)
    rhs_nlimbs = math.ceil(rhs_bits / LIMB_SIZE)

    # Determine the number of "terms" (partial products with the same weight)
    # in the result. For example, if the LHS and RHS each have 2 limbs, the
    # result (a0b0 + (a0b1 + a1b0) << 64 + a1b1 << 128) has 3 terms.
    result_nterms = lhs_nlimbs + rhs_nlimbs - 1

    # Determine the number of full-size words in the result.
    result_nwords = math.ceil((lhs_bits + rhs_bits) / WDR_SIZE)

    # Create the main body of the block, with half-word writebacks at every
    # half word of the result. `acc_max` tracks the largest value the
    # accumulator can hold at each point.
    lines: List[Line] = []
    acc_max = 0
    is_final_word = False
    for i in range(result_nterms):
        is_final_word = (i // LIMBS_PER_WDR) == result_nwords - 1

        if is_final_word:
            # The final word is written back whole, so nothing is shifted out
            # of the accumulator while it is being built: every term must be
            # placed at its own limb offset within the word.
            shift_bits = (i % LIMBS_PER_WDR) * LIMB_SIZE
        else:
            # Elsewhere a half word is shifted out after every second term,
            # so terms alternate between the two limbs of the low half word.
            shift_bits = (i % 2) * LIMB_SIZE

        # Gather the partial products (lhs limb, rhs limb) with weight i.
        partial_products = [(j, i - j) for j in range(i + 1)
                            if j < lhs_nlimbs and (i - j) < rhs_nlimbs]
        # Sanity check.
        if not partial_products:
            raise RuntimeError(f'No partial products for limb {i}!')

        for lhs_idx, rhs_idx in partial_products:
            # Update the accumulator bound and append the multiplication.
            product_max = (_limb_max(lhs_idx, lhs_nlimbs, lhs_bits) *
                           _limb_max(rhs_idx, rhs_nlimbs, rhs_bits))
            acc_max += product_max << shift_bits
            lines.append(
                Line([], None, make_operand(lhs_regs, lhs_idx),
                     make_operand(rhs_regs, rhs_idx), shift_bits))

        if acc_max > mask(WDR_SIZE):
            raise ValueError(
                f'Accumulator overflow while computing limb {i} of result!')

        # Every second limb (except for the final word), write back a half-word
        # of the result as part of the last instruction.
        if not is_final_word and i % 2 == 1:
            lines[-1].modifiers.append('so')
            lines[-1].dst = make_half_word_writeback(result_regs, i)
            # The writeback shifts the accumulator right by a half word, and
            # so does its upper bound.
            acc_max >>= half_word_bits

    # It's possible that the last result term crosses a word boundary, so we
    # never will have hit the "final word" condition. If so, we need to add an
    # instruction that just writes back the accumulator to the final word
    # without modifying it. We assume here that w31 = 0.
    if not is_final_word:
        lines.append(Line([], None, 'w31.0', 'w31.0', 0))

    # Write back the final word of the result by adding a `.wo` modifier to the
    # last instruction.
    lines[-1].modifiers.append('wo')
    lines[-1].dst = result_regs[result_nwords - 1]

    # Ensure that the first instruction zeroes the accumulator.
    lines[0].modifiers.append('z')

    return lines


def format_block(lines: List[Line]) -> List[str]:
    '''Render every line with its left-hand operand in a common column.

    The column is the widest minimum alignment required by any line.
    '''
    column = max(entry.min_lhs_align() for entry in lines)
    rendered = []
    for entry in lines:
        rendered.append(entry.pretty(column))
    return rendered


def is_wdr(reg_name: str) -> bool:
    '''Return True if `reg_name` names a wide data register (w0 .. w31).

    A valid name is the letter 'w' followed by one or more ASCII decimal
    digits whose value lies in the range 0..31.
    '''
    if not reg_name.startswith('w'):
        return False
    reg_num = reg_name[1:]
    # Only plain ASCII digits are accepted. str.isnumeric() would also admit
    # characters such as superscripts or fractions, which int() cannot parse,
    # and digits from other scripts, which are not valid register names.
    if not reg_num or any(not ('0' <= ch <= '9') for ch in reg_num):
        return False
    return 0 <= int(reg_num) <= 31


def check_regs(regs: List[str], bits: int) -> None:
    '''Validate a register list for a bignum of the given size.

    Raises ValueError if any name is not a wide data register, or if the
    registers together hold fewer than `bits` bits.
    '''
    # Report the first invalid name, if any.
    bad_name = next((name for name in regs if not is_wdr(name)), None)
    if bad_name is not None:
        raise ValueError(f'{bad_name} is not a valid wide data register.')

    capacity = len(regs) * WDR_SIZE
    if bits > capacity:
        raise ValueError(
            f'Register list {regs} does not have enough space to store a '
            f'{bits}-bit number.')


def main() -> int:
    '''Parse the command line, validate it and print the bn.mulqacc block.

    Returns 0 on success. Raises ValueError if an operand size is not
    positive, a register name is invalid, a register list is too small, or a
    result register also appears in one of the operands.
    '''
    parser = argparse.ArgumentParser(description=(
        'Print out a bn.mulqacc block for the given bignum product.'))
    parser.add_argument('--lhs-bits',
                        type=int,
                        required=True,
                        help='Number of bits in the left-hand operand.')
    parser.add_argument(
        '--lhs-regs',
        nargs='+',
        type=str,
        required=True,
        help='Registers for left-hand operand (little-endian order).')
    parser.add_argument('--rhs-bits',
                        type=int,
                        required=True,
                        help='Number of bits in the right-hand operand.')
    parser.add_argument(
        '--rhs-regs',
        nargs='+',
        type=str,
        required=True,
        help='Registers for the right-hand operand (little-endian order).')
    parser.add_argument(
        '--result-regs',
        nargs='+',
        type=str,
        required=True,
        help=('Registers in which to store the result. Must not overlap with '
              'operand registers.'))
    args = parser.parse_args()

    # Check that both operands have at least 1 bit. This comes first because
    # the register checks below use the sizes, and zero or negative sizes
    # would otherwise pass them and fail later while generating the block.
    if args.lhs_bits <= 0 or args.rhs_bits <= 0:
        raise ValueError('Both operands must have at least 1 bit.')

    check_regs(args.lhs_regs, args.lhs_bits)
    check_regs(args.rhs_regs, args.rhs_bits)
    check_regs(args.result_regs, args.lhs_bits + args.rhs_bits)

    # Check that result registers don't overlap with operand registers.
    operand_regs = set(args.lhs_regs) | set(args.rhs_regs)
    for reg in args.result_regs:
        if reg in operand_regs:
            raise ValueError(
                f'Result register {reg} also appears in one of the operands.')

    block = make_block(args.lhs_bits, args.lhs_regs, args.rhs_bits,
                       args.rhs_regs, args.result_regs)
    for line in format_block(block):
        print(line)

    return 0


# Run as a script: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
