#!/usr/bin/env python3
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs a test ELF on the Spike, Qemu or Renode simulator."""
import argparse
import os
import re
import sys
import tempfile

import io
import pexpect


parser = argparse.ArgumentParser(
    description="Run a springbok test on a simulator.")

# Positional arguments: which simulator to use and the ELF to load into it.
parser.add_argument("simulator",
                    help="Select a simulator",
                    choices=["renode", "qemu", "spike"])
parser.add_argument("elf",
                    help="Elf to execute on a simulator")
# Per-simulator binary paths; main() requires the one matching `simulator`.
parser.add_argument("--renode-path",
                    help="Path to renode simulator")
# Only consumed by the Renode and Spike runners.
parser.add_argument("--trace-output",
                    help="Path to trace output file")
parser.add_argument("--qemu-path",
                    help="Path to qemu simulator")
parser.add_argument("--spike-path",
                    help="Path to spike simulator")
# Forwarded to pexpect's expect() calls, i.e. measured in seconds.
parser.add_argument("--timeout", type=int,
                    help="Timeout for test", default=1000)
# "--quick-test" is accepted as an alias so the flag matches the hyphenated
# style of the other options; the attribute stays `args.quick_test` because
# argparse derives `dest` from the first long option.
parser.add_argument("--quick_test", "--quick-test",
                    help="allow quickest test time", action="store_true")

# Parsed at import time: the simulator classes read `args` directly.
args = parser.parse_args()

class Simulation: # pylint: disable=too-few-public-methods
    """ Base class for simulation.

    Spawns ``simulator_cmd`` under pexpect, waits until one of
    ``termination_strings`` appears, then shuts the simulator down and
    returns everything it printed.
    """
    def __init__(self, simulator_cmd):
        # Full command line handed to pexpect.spawn().
        self.simulator_cmd = simulator_cmd
        # Receives the simulator's entire console output (pexpect logfile).
        self.buffer = io.StringIO()
        self.child = None
        # Seeing any of these means the test has finished, successfully or
        # not. pexpect treats each entry as a regular expression.
        self.termination_strings = [
            "main returned",
            "Exception occurred",
            "ReadByte from non existing peripheral",
            "File does not exist"
        ]
        # stats collected command for renode
        self.renode_end_command = None

    def _read_log(self):
        """ Return everything the simulator has printed so far. """
        self.buffer.seek(0)
        return self.buffer.read()

    def _abort(self, reason):
        """ Build a failure message from the log and stop the simulator.

        Args:
          reason: Leading text of the message.

        Returns:
          ``reason`` followed by the cleaned-up execution log.
        """
        message = reason + cleanup_message(self._read_log())
        try:
            # Do not leave a hung simulator running after a timeout.
            self.child.close(force=True)
        except pexpect.exceptions.ExceptionPexpect:
            pass  # Best effort: the original failure is what gets reported.
        return message

    def run(self, timeout=1000):
        """ Run the simulation command and quit the simulation.

        Args:
          timeout: Seconds pexpect waits for each expected pattern.

        Returns:
          The simulator's complete console output.

        Raises:
          pexpect.exceptions.EOF: the simulator exited before printing any
            of ``termination_strings``.
          pexpect.exceptions.TIMEOUT: none of ``termination_strings`` was
            printed within ``timeout`` seconds.
        """
        self.child = pexpect.spawn(self.simulator_cmd, encoding="utf-8")
        self.child.logfile = self.buffer
        try:
            self.child.expect(self.termination_strings, timeout=timeout)
        except pexpect.exceptions.EOF as run_hit_eof:
            message = self._abort(
                "Runner reach EOF with the execution log: \n\n")
            raise pexpect.exceptions.EOF(message) from run_hit_eof
        except pexpect.exceptions.TIMEOUT as run_hit_timeout:
            message = self._abort(
                "Runner times out with the execution log: \n\n")
            # Re-raise as TIMEOUT (not EOF) so the two failures stay distinct.
            raise pexpect.exceptions.TIMEOUT(message) from run_hit_timeout
        if self.renode_end_command:
            self.child.send(self.renode_end_command)
            # Wait for the literal "(springbok)" prompt; expect() would parse
            # the parentheses as a regex group instead.
            self.child.expect_exact("(springbok)", timeout=timeout)
        self.child.send("\nq\n")
        self.child.expect(pexpect.EOF, timeout=timeout)
        self.child.close()
        return self._read_log()

class QemuSimulation(Simulation): # pylint: disable=too-few-public-methods
    """ Qemu simulation.

    Starts QEMU's springbok machine (``-M springbok``) without graphics and
    loads the ELF through QEMU's generic ``loader`` device.
    """
    def __init__(self, path, elf):
        # Values substituted into the command template below.
        self.sim_params = dict(sim=path, elf=elf)
        self.qemu_simulator_cmd = (
            "%(sim)s -M springbok -nographic -d springbok "
            "-device loader,file=%(elf)s")
        command_line = self.qemu_simulator_cmd % self.sim_params
        super().__init__(command_line)


class RenodeSimulation(Simulation): # pylint: disable=too-few-public-methods
    """ Renode Simulation.

    Generates a Renode ``.resc`` script that loads the ELF (plus the optional
    quick-test and tracing settings) and runs Renode in console mode on it.
    """
    def __init__(self, path, elf):
        # Get the ROOTDIR path if it exists; the script uses it with
        # `path set` before including the springbok platform config.
        self.rootdir = os.environ.get("ROOTDIR", default=None)
        if self.rootdir is None:
            parser.error("ROOTDIR environment variable not set.")
        renode_script = """
$bin=@%(elf)s
path set @%(rootdir)s
include @sim/config/springbok.resc"""

        if args.quick_test:
            renode_script += """
sysbus.cpu2 PerformanceInMips 2000
emulation SetGlobalQuantum "1" """

        trace_file = ""
        if args.trace_output:
            trace_file = os.path.realpath(args.trace_output)
            renode_script += """
sysbus.cpu2 EnableExecutionTracing @%(trace_file)s PCAndOpcode """

        renode_script += """
sysbus.cpu2 EnableExternalWindowMmu false
sysbus.vec_controlblock WriteDoubleWord 0xc 0
start"""

        self.script_params = {
            "elf": os.path.realpath(elf),
            "rootdir": self.rootdir,
            "trace_file": trace_file
        }
        self.renode_script = renode_script % self.script_params
        self.renode_args = [
            f"{path}",
            "--disable-xwt",
            "--console",
            "--plain",
        ]
        self.renode_simulator_cmd = " ".join(self.renode_args)
        super().__init__(self.renode_simulator_cmd)
        # Stats command sent once the test finishes (see Simulation.run); set
        # after super().__init__(), which resets it to None.
        self.renode_end_command = "\nsysbus.cpu2 ExecutedInstructions\n"

    def run(self, timeout=120):
        """ Write the Renode script to a temp file and run Renode on it.

        Args:
          timeout: Seconds pexpect waits for each expected pattern.

        Returns:
          The simulator's complete console output.
        """
        file_desc, script_path = tempfile.mkstemp(suffix=".resc")
        base_cmd = self.simulator_cmd
        try:
            with os.fdopen(file_desc, "w") as tmp:
                tmp.write(self.renode_script)
            # Add the script only for this run; a permanent append would make
            # a second run() pass the previous, already-deleted script too.
            self.simulator_cmd = f"{base_cmd} {script_path}"
            test_output = super().run(timeout=timeout)
        finally:
            self.simulator_cmd = base_cmd
            os.remove(script_path)
        return test_output

class SpikeSimulation(Simulation): # pylint: disable=too-few-public-methods
    """ Spike Simulation.

    Runs Spike with memory mapped at 0x34000000 and execution starting there.
    When --trace-output is given, Spike's log is written to that file.
    """
    def __init__(self, path, elf):
        tracing = bool(args.trace_output)
        trace_file = os.path.realpath(args.trace_output) if tracing else ""
        self.sim_params = dict(path=path, elf=elf, trace_file=trace_file)

        # Assemble the %-style template first, then substitute sim_params.
        cmd_template = "%(path)s -m0x34000000:0x1000000 --pc=0x34000000 "
        if tracing:
            cmd_template += " -l --log=%(trace_file)s "
        cmd_template += " %(elf)s"

        self.spike_simulator_cmd = cmd_template
        super().__init__(cmd_template % self.sim_params)

# Simulator name (as accepted on the command line) -> runner class.
Simulators = dict(
    qemu=QemuSimulation,
    renode=RenodeSimulation,
    spike=SpikeSimulation,
)

# Simulator name -> binary path from --<name>-path (None when not given).
simulators_paths = {
    sim_name: getattr(args, f"{sim_name}_path")
    for sim_name in ("renode", "qemu", "spike")
}

def cleanup_message(message: str) -> str:
    """ Clean up the console output generated by Mono.

    Strips ANSI escape sequences (terminal control codes) and rewrites every
    hex instruction count printed just before a ``(springbok)`` prompt as a
    decimal "Renode total instruction count" line.

    Args:
      message: Raw simulator output.

    Returns:
      The cleaned-up output; text without escapes or counts is unchanged.
    """
    ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
    output = ansi_escape.sub("", message)
    op_count_out = re.compile(
        r"(?P<op_count>0x[0-9A-Fa-f]+)\r\r\r\n\(springbok\)")

    def _to_decimal(match) -> str:
        # Convert each match on its own; a fixed replacement string would
        # stamp the first count onto every later match.
        op_count = int(match.group("op_count"), 16)
        return f"Renode total instruction count: {op_count}\n"

    return op_count_out.sub(_to_decimal, output)

def main():
    """ Run a test and check for Pass or Fail.

    Prints the cleaned-up simulator output. Exits with status 1 if the output
    contains a known failure string or no recognizable return code, or if the
    reported return code cannot be represented as an exit status. Otherwise
    exits with the code the test's ``main`` returned.
    """
    simulator_path = simulators_paths[args.simulator]
    if simulator_path is None:
        parser.error(
            f"Must provide path to simulator {args.simulator}, "
            f"use argument --{args.simulator}-path")

    simulator_class = Simulators[args.simulator]
    simulator = simulator_class(simulator_path, args.elf)
    output = simulator.run(timeout=args.timeout)
    output = cleanup_message(output)
    print(output)
    failure_strings = [
        "FAILED",
        "Exception occurred",
        "ReadByte from non existing peripheral",
        "File does not exist"
    ]
    if any(x in output for x in failure_strings):
        sys.exit(1)
    # Grab the return code from the output string with regex
    # Syntax: "main returned: ", <code> (<hex_code>)
    return_string = re.compile(
        r"\"main returned:\s\",(?P<ret_code>\s[0-9]+\s*)")
    code = return_string.search(output)
    if code is None:
        # No recognizable return code: fail explicitly instead of crashing
        # with an AttributeError on None.
        print("Could not find the test's return code in the output.",
              file=sys.stderr)
        sys.exit(1)
    ret_code = int(code.group("ret_code"))
    # POSIX keeps only the low 8 bits of an exit status, so e.g. 256 would
    # become 0 and a failing test would look like a pass.
    sys.exit(ret_code if ret_code <= 255 else 1)


# Script entry point. Command-line arguments have already been parsed into
# `args` at import time.
if __name__ == "__main__":
    main()
