# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
import pathlib
import dataclasses
import json
import urllib.parse
import markdown_strings as md
import math

from common import benchmark_definition, benchmark_thresholds
from common.benchmark_thresholds import (
    BENCHMARK_THRESHOLDS,
    COMPILATION_TIME_THRESHOLDS,
    TOTAL_ARTIFACT_SIZE_THRESHOLDS,
    TOTAL_DISPATCH_SIZE_THRESHOLDS,
    BenchmarkThreshold,
    ThresholdUnit,
)

# Signature of a function that extracts the (current, base) values of a metric
# from a metrics object. The base value is None when there is no base to
# compare against.
GetMetricFunc = Callable[[Any], Tuple[int, Optional[int]]]

# URL prefix of a series page on the performance dashboard. The URL-quoted
# series id is appended to it to build a link.
PERFBOARD_SERIES_PREFIX = "https://perf.iree.dev/serie?IREE?"
# Column headers of the benchmark latency tables.
BENCHMARK_RESULTS_HEADERS = [
    "Benchmark Name",
    "Average Latency (ms)",
    "Median Latency (ms)",
    "Latency Standard Deviation (ms)",
]
# Since we don't have a structural way to store metric data yet, each metric is
# assigned a fixed id generated from uuid.uuid4() to identify its series.
# The *_SERIES_SUFFIX values are appended (in brackets) to a target name to
# build the series name of the corresponding metric.
COMPILATION_TIME_METRIC_ID = "e54cd682-c079-4c42-b4ad-d92c4bedea13"
COMPILATION_TIME_SERIES_SUFFIX = "compilation:module:compilation-time"
TOTAL_DISPATCH_SIZE_METRIC_ID = "9e15f7e6-383c-47ec-bd38-ecba55a5f10a"
TOTAL_DISPATCH_SIZE_SERIES_SUFFIX = (
    "compilation:module:component-size:total-dispatch-size"
)
TOTAL_ARTIFACT_SIZE_METRIC_ID = "2c8a9198-c01c-45b9-a7da-69c82cf749f7"
TOTAL_ARTIFACT_SIZE_SERIES_SUFFIX = "compilation:module:total-artifact-size"
STREAM_IR_DISPATCH_COUNT_METRIC_ID = "7b72cd9e-43ed-4078-b6d3-20b810f9e4ad"
STREAM_IR_DISPATCH_COUNT_SERIES_SUFFIX = "compilation:ir:stream-dispatch-count"


@dataclass
class AggregateBenchmarkLatency:
    """An object for describing aggregate latency numbers for a benchmark.

    str() of an instance is the benchmark name, which is what the threshold
    regexes are matched against when categorizing results.
    """

    # Display name of the benchmark (also the result of str(self)).
    name: str
    # Info object of the benchmark run these numbers belong to.
    benchmark_info: benchmark_definition.BenchmarkInfo
    # Aggregate latency numbers. They are divided by 1e6 before being shown in
    # the millisecond columns and compared with the "ns" threshold unit, so
    # they are presumably in nanoseconds.
    mean_time: int
    median_time: int
    stddev_time: int
    # The average latency time for the base commit to compare against.
    # None when there is no base result.
    base_mean_time: Optional[int] = None

    def __str__(self) -> str:
        return self.name


@dataclass(frozen=True)
class CompilationMetrics:
    """Compilation statistics of one target plus optional base values.

    The `base_*` fields hold the reference values to compare the current
    values against; each one stays None when no base value is available.
    Instances are frozen, so base values are set by creating a modified copy
    (see dataclasses.replace). str() of an instance is the target name.
    """

    # Display name of the compilation target (also the result of str(self)).
    name: str
    # Info object of the compilation these statistics belong to.
    compilation_info: benchmark_definition.CompilationInfo
    # Current metric values.
    compilation_time_ms: int
    total_dispatch_component_bytes: int
    total_artifact_bytes: int
    stream_ir_dispatch_count: int
    # Base (reference) metric values; None when there is nothing to compare to.
    base_compilation_time_ms: Optional[int] = None
    base_total_artifact_bytes: Optional[int] = None
    base_total_dispatch_component_bytes: Optional[int] = None
    base_stream_ir_dispatch_count: Optional[int] = None

    def __str__(self) -> str:
        return self.name


# Type of the metrics object handled by a MetricsToTableMapper.
T = TypeVar("T")


class MetricsToTableMapper(ABC, Generic[T]):
    """Abstract class to help map benchmark metrics to table.

    It contains a set of methods to help table generator get the required
    information for a metric. For example, extract the current and base metric
    value, the metric thresholds, the table header of the metrics, ...
    """

    @abstractmethod
    def update_base_value(self, obj: T, base_value: Any) -> T:
        """Sets the base value and returns the updated metric object."""
        raise NotImplementedError()

    @abstractmethod
    def get_current_and_base_value(self, obj: T) -> Tuple[int, Optional[int]]:
        """Returns the current and base (can be None) value."""
        raise NotImplementedError()

    def get_series_id(self, benchmark_id: str) -> str:
        """Returns the dashboard series id.

        The id is built as "<benchmark_id>-<metric id>".
        """
        return f"{benchmark_id}-{self.get_metric_id()}"

    @abstractmethod
    def get_metric_id(self) -> str:
        """Returns the id of the metric, used as the suffix of the series id."""
        raise NotImplementedError()

    @abstractmethod
    def get_series_name(self, name: str) -> str:
        """Returns the dashboard series name."""
        raise NotImplementedError()

    @abstractmethod
    def get_unit(self) -> str:
        """Returns the unit of the metric value."""
        raise NotImplementedError()

    @abstractmethod
    def get_table_header(self) -> str:
        """Returns the header of the table."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_metric_thresholds() -> Sequence[BenchmarkThreshold]:
        """Returns the threshold settings used to categorize this metric."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_table_title() -> str:
        """Returns the title shown above the tables of this metric."""
        raise NotImplementedError()


class CompilationTimeToTable(MetricsToTableMapper[CompilationMetrics]):
    """Maps the compilation time of CompilationMetrics to a table column."""

    def update_base_value(
        self, compile_metrics: CompilationMetrics, base_value: Any
    ) -> CompilationMetrics:
        """Returns a copy of the metrics with the base compilation time set."""
        replaced_fields = {"base_compilation_time_ms": base_value}
        return dataclasses.replace(compile_metrics, **replaced_fields)

    def get_current_and_base_value(
        self, compile_metrics: CompilationMetrics
    ) -> Tuple[int, Optional[int]]:
        """Returns the current and base (can be None) compilation time."""
        current = compile_metrics.compilation_time_ms
        base = compile_metrics.base_compilation_time_ms
        return current, base

    def get_metric_id(self) -> str:
        """Returns the fixed id of the compilation time metric."""
        return COMPILATION_TIME_METRIC_ID

    def get_series_name(self, name: str) -> str:
        """Returns the series name: the given name plus the bracketed suffix."""
        suffix = COMPILATION_TIME_SERIES_SUFFIX
        return f"{name} [{suffix}]"

    def get_unit(self) -> str:
        """Returns the unit of the compilation time."""
        return "ms"

    def get_table_header(self) -> str:
        """Returns the column header, including the unit."""
        unit = self.get_unit()
        return f"Compilation Time ({unit})"

    @staticmethod
    def get_metric_thresholds() -> Sequence[BenchmarkThreshold]:
        """Returns the threshold settings for compilation time."""
        return COMPILATION_TIME_THRESHOLDS

    @staticmethod
    def get_table_title() -> str:
        """Returns the title of the compilation time tables."""
        return "Compilation Times"


class TotalDispatchSizeToTable(MetricsToTableMapper[CompilationMetrics]):
    """Maps the total dispatch size of CompilationMetrics to a table column."""

    def update_base_value(
        self, compile_metrics: CompilationMetrics, base_value: Any
    ) -> CompilationMetrics:
        """Returns a copy of the metrics with the base dispatch size set."""
        replaced_fields = {"base_total_dispatch_component_bytes": base_value}
        return dataclasses.replace(compile_metrics, **replaced_fields)

    def get_current_and_base_value(
        self, compile_metrics: CompilationMetrics
    ) -> Tuple[int, Optional[int]]:
        """Returns the current and base (can be None) total dispatch size."""
        current = compile_metrics.total_dispatch_component_bytes
        base = compile_metrics.base_total_dispatch_component_bytes
        return current, base

    def get_metric_id(self) -> str:
        """Returns the fixed id of the total dispatch size metric."""
        return TOTAL_DISPATCH_SIZE_METRIC_ID

    def get_series_name(self, name: str) -> str:
        """Returns the series name: the given name plus the bracketed suffix."""
        suffix = TOTAL_DISPATCH_SIZE_SERIES_SUFFIX
        return f"{name} [{suffix}]"

    def get_unit(self) -> str:
        """Returns the unit of the total dispatch size."""
        return "bytes"

    def get_table_header(self) -> str:
        """Returns the column header, including the unit."""
        unit = self.get_unit()
        return f"Total Dispatch Size ({unit})"

    @staticmethod
    def get_metric_thresholds() -> Sequence[BenchmarkThreshold]:
        """Returns the threshold settings for the total dispatch size."""
        return TOTAL_DISPATCH_SIZE_THRESHOLDS

    @staticmethod
    def get_table_title() -> str:
        """Returns the title of the total dispatch size tables."""
        return "Total Dispatch Sizes"


class TotalArtifactSizeToTable(MetricsToTableMapper[CompilationMetrics]):
    """Maps the total artifact size of CompilationMetrics to a table column."""

    def update_base_value(
        self, compile_metrics: CompilationMetrics, base_value: Any
    ) -> CompilationMetrics:
        """Returns a copy of the metrics with the base artifact size set."""
        replaced_fields = {"base_total_artifact_bytes": base_value}
        return dataclasses.replace(compile_metrics, **replaced_fields)

    def get_current_and_base_value(
        self, compile_metrics: CompilationMetrics
    ) -> Tuple[int, Optional[int]]:
        """Returns the current and base (can be None) total artifact size."""
        current = compile_metrics.total_artifact_bytes
        base = compile_metrics.base_total_artifact_bytes
        return current, base

    def get_metric_id(self) -> str:
        """Returns the fixed id of the total artifact size metric."""
        return TOTAL_ARTIFACT_SIZE_METRIC_ID

    def get_series_name(self, name: str) -> str:
        """Returns the series name: the given name plus the bracketed suffix."""
        suffix = TOTAL_ARTIFACT_SIZE_SERIES_SUFFIX
        return f"{name} [{suffix}]"

    def get_unit(self) -> str:
        """Returns the unit of the total artifact size."""
        return "bytes"

    def get_table_header(self) -> str:
        """Returns the column header, including the unit."""
        unit = self.get_unit()
        return f"Total Artifact Size ({unit})"

    @staticmethod
    def get_metric_thresholds() -> Sequence[BenchmarkThreshold]:
        """Returns the threshold settings for the total artifact size."""
        return TOTAL_ARTIFACT_SIZE_THRESHOLDS

    @staticmethod
    def get_table_title() -> str:
        """Returns the title of the total artifact size tables."""
        return "Total Artifact Sizes"


class StreamIRDispatchCountToTable(MetricsToTableMapper[CompilationMetrics]):
    """Helper to map CompilationMetrics to Stream IR Dispatch Count column."""

    def update_base_value(
        self, compile_metrics: CompilationMetrics, base_value: Any
    ) -> CompilationMetrics:
        """Returns a copy of the metrics with the base dispatch count set."""
        return dataclasses.replace(
            compile_metrics, base_stream_ir_dispatch_count=base_value
        )

    def get_current_and_base_value(
        self, compile_metrics: CompilationMetrics
    ) -> Tuple[int, Optional[int]]:
        """Returns the current and base (can be None) stream dispatch count."""
        return (
            compile_metrics.stream_ir_dispatch_count,
            compile_metrics.base_stream_ir_dispatch_count,
        )

    def get_metric_id(self) -> str:
        """Returns the fixed id of the stream IR dispatch count metric."""
        return STREAM_IR_DISPATCH_COUNT_METRIC_ID

    def get_series_name(self, name: str) -> str:
        """Returns the series name: the given name plus the bracketed suffix."""
        return f"{name} [{STREAM_IR_DISPATCH_COUNT_SERIES_SUFFIX}]"

    def get_unit(self) -> str:
        """Returns the unit of the dispatch count."""
        return "number"

    def get_table_header(self) -> str:
        """Returns the column header.

        Unlike the other mappers, the header does not embed get_unit(); it
        spells out what is being counted instead.
        """
        # Plain literal: there is nothing to interpolate, so no f-string.
        return "Stream IR Dispatch Count (# of cmd.dispatch ops)"

    @staticmethod
    def get_metric_thresholds() -> Sequence[BenchmarkThreshold]:
        """Returns the threshold settings for the stream IR dispatch count."""
        return benchmark_thresholds.STREAM_IR_DISPATCH_COUNT_THRESHOLDS

    @staticmethod
    def get_table_title() -> str:
        """Returns the title of the stream IR dispatch count tables."""
        return "Stream IR Dispatch Count (# of cmd.dispatch ops)"


# Mappers of all compilation metrics shown in the report. The order of this
# list is the order in which the per-metric tables are generated and the order
# of the columns in the "All Compilation Metrics" table.
COMPILATION_METRICS_TO_TABLE_MAPPERS: List[MetricsToTableMapper[CompilationMetrics]] = [
    CompilationTimeToTable(),
    TotalDispatchSizeToTable(),
    TotalArtifactSizeToTable(),
    StreamIRDispatchCountToTable(),
]


def aggregate_all_benchmarks(
    benchmark_files: Sequence[pathlib.Path], expected_pr_commit: Optional[str] = None
) -> Dict[str, AggregateBenchmarkLatency]:
    """Aggregates all benchmarks in the given files.

    Args:
    - benchmark_files: A list of JSON files, each can be decoded as a
      BenchmarkResults.
    - expected_pr_commit: An optional Git commit SHA to match against.

    Returns:
    - A dict mapping the run config id of each benchmark (used as the series
      id) to its AggregateBenchmarkLatency numbers.

    Raises:
    - ValueError: if a file's commit differs from expected_pr_commit, or if a
      benchmark name or a benchmark id appears more than once.
    """

    aggregate_results: Dict[str, AggregateBenchmarkLatency] = {}
    benchmark_names = set()
    for benchmark_file in benchmark_files:
        file_results = benchmark_definition.BenchmarkResults.from_json_str(
            benchmark_file.read_text()
        )

        if (expected_pr_commit is not None) and (
            file_results.commit != expected_pr_commit
        ):
            raise ValueError("Inconsistent pull request commit")

        # Iterate over the runs directly; the index itself is never needed.
        for benchmark_run in file_results.benchmarks:
            series_name = str(benchmark_run.info)
            # Make sure each benchmark has a unique name.
            if series_name in benchmark_names:
                raise ValueError(f"Duplicated benchmark name: {series_name}")
            benchmark_names.add(series_name)

            # Ids must be unique as well, since they are the keys of the result.
            series_id = benchmark_run.info.run_config_id
            if series_id in aggregate_results:
                raise ValueError(f"Duplicated benchmark id: {series_id}")

            real_time = benchmark_run.metrics.real_time
            aggregate_results[series_id] = AggregateBenchmarkLatency(
                name=series_name,
                benchmark_info=benchmark_run.info,
                mean_time=real_time.mean,
                median_time=real_time.median,
                stddev_time=real_time.stddev,
            )

    return aggregate_results


def collect_all_compilation_metrics(
    compile_stats_files: Sequence[pathlib.Path],
    expected_pr_commit: Optional[str] = None,
) -> Dict[str, CompilationMetrics]:
    """Gathers the compilation metrics stored in the given files.

    Args:
      compile_stats_files: JSON files, each decodable as a CompilationResults.
      expected_pr_commit: If given, the Git commit SHA every file must carry.

    Returns:
      A dict mapping the gen config id of each compilation target to its
      CompilationMetrics.

    Raises:
      ValueError: if a file's commit differs from expected_pr_commit, or if a
        target name or a target id appears more than once.
    """
    metrics_by_target_id = {}
    seen_target_names = set()
    for stats_path in compile_stats_files:
        with stats_path.open("r") as stats_stream:
            results = benchmark_definition.CompilationResults.from_json_object(
                json.load(stats_stream)
            )

        commit_mismatch = (
            expected_pr_commit is not None and results.commit != expected_pr_commit
        )
        if commit_mismatch:
            raise ValueError("Inconsistent pull request commit")

        for stats in results.compilation_statistics:
            sizes = stats.module_component_sizes
            dispatch_count = stats.ir_stats.stream_dispatch_count
            info = stats.compilation_info

            # Both the display name and the id of a target must be unique.
            target_name = str(info)
            if target_name in seen_target_names:
                raise ValueError(f"Duplicated target name: {target_name}")
            seen_target_names.add(target_name)

            target_id = info.gen_config_id
            if target_id in metrics_by_target_id:
                raise ValueError(f"Duplicated target id: {target_id}")

            metrics_by_target_id[target_id] = CompilationMetrics(
                name=target_name,
                compilation_info=info,
                compilation_time_ms=stats.compilation_time_ms,
                total_dispatch_component_bytes=sizes.total_dispatch_component_bytes,
                total_artifact_bytes=sizes.file_bytes,
                stream_ir_dispatch_count=dispatch_count,
            )

    return metrics_by_target_id


def _make_series_link(name: str, series_id: str) -> str:
    """Builds a markdown link to the dashboard page of a series.

    Args:
      name: the text to show on the link.
      series_id: the dashboard series id; it is URL-quoted before being
        appended to the dashboard prefix.
    """
    quoted_id = urllib.parse.quote(series_id, safe="()[]@,")
    return md.link(name, PERFBOARD_SERIES_PREFIX + quoted_id)


def _add_header_and_get_markdown_table(
    headers: Sequence[str], rows: Sequence[Tuple], size_cut: Optional[int] = None
) -> str:
    """Renders rows as a markdown table topped by the given headers.

    Args:
      headers: list of table headers.
      rows: list of rows. Each row is a tuple with the same length as headers.
      size_cut: If not None, only the first N rows are shown, and a note is
        appended when rows were left out.
    """
    row_count = len(rows)
    abbreviated = size_cut is not None
    shown_rows = rows[0:size_cut] if abbreviated else rows

    # The table is built column by column, each column starting with its header.
    columns = [[header] for header in headers]
    for shown_row in shown_rows:
        for column, cell in zip(columns, shown_row):
            column.append(cell)

    pieces = [md.table(columns)]
    if abbreviated and size_cut < row_count:
        pieces.append(
            md.italics(f"[Top {size_cut} out of {row_count} results showed]")
        )
    return "\n\n".join(pieces)


# Type of the metrics objects being categorized and sorted below.
T = TypeVar("T")


def _categorize_on_single_metric(
    metrics_map: Dict[str, T],
    metric_func: GetMetricFunc,
    thresholds: Sequence[BenchmarkThreshold],
    metric_unit: str,
) -> Tuple[Dict[str, T], Dict[str, T], Dict[str, T], Dict[str, T]]:
    """Categorize the metrics object into regressed, improved, similar, and the
    raw group (the group with no base to compare to).

    The first threshold whose regex matches str(metrics object) decides how
    large a change must be to leave the "similar" group. A percentage
    threshold is compared against the relative change, any other threshold
    against the absolute change.

    Args:
      metrics_map: map of (series_id, metrics object).
      metric_func: the function returns current and base value of the metric.
      thresholds: list of threshold settings to match for categorizing.
      metric_unit: unit of the metric value; a non-percentage threshold must
        use the same unit.
    Returns:
      A tuple of (regressed, improved, similar, raw) groups.
    Raises:
      ValueError: if no threshold matches a metrics object, or if the matched
        threshold is neither a percentage nor in the metric's unit.
    """

    regressed_map = {}
    improved_map = {}
    similar_map = {}
    raw_map = {}
    for series_id, metrics_obj in metrics_map.items():
        current, base = metric_func(metrics_obj)
        if base is None:
            raw_map[series_id] = metrics_obj
            continue

        series_name = str(metrics_obj)
        # The first matching threshold wins.
        similar_threshold = next(
            (
                threshold
                for threshold in thresholds
                if threshold.regex.match(series_name)
            ),
            None,
        )
        if similar_threshold is None:
            raise ValueError(f"No matched threshold setting for: {series_name}")

        if similar_threshold.unit == ThresholdUnit.PERCENTAGE:
            if base != 0:
                ratio = abs(current - base) / base * 100
            else:
                # A relative change from zero is undefined: treat "still zero"
                # as no change and anything else as an unbounded change, so
                # it never counts as similar.
                ratio = 0 if current == base else math.inf
        elif similar_threshold.unit.value == metric_unit:
            ratio = abs(current - base)
        else:
            raise ValueError(
                f"Mismatch between metric unit '{metric_unit}' and threshold unit '{similar_threshold.unit.value}'"
            )

        if ratio <= similar_threshold.threshold:
            similar_map[series_id] = metrics_obj
        elif current > base:
            regressed_map[series_id] = metrics_obj
        else:
            improved_map[series_id] = metrics_obj

    return (regressed_map, improved_map, similar_map, raw_map)


def _get_fixed_point_str(value: Union[int, float], digits=3) -> str:
    if isinstance(value, int) or value.is_integer():
        return str(math.floor(value))
    return f"{{:.{digits}f}}".format(value)


def _get_compare_text(current: float, base: Optional[int]) -> str:
    """Generates the text of comparison between current and base value. Returns
    the current value if the base value is None.
    """
    # If base is None, don't need to do compare.
    if base is None:
        return f"{_get_fixed_point_str(current)}"

    ratio = abs(current - base) / base
    direction = "↑" if current > base else ("↓" if current < base else "")
    return f"{_get_fixed_point_str(current)} (vs. {_get_fixed_point_str(base)}, {ratio:.2%}{direction})"


def _sort_benchmarks_and_get_table(
    benchmarks: Dict[str, AggregateBenchmarkLatency], size_cut: Optional[int] = None
) -> str:
    """Sorts all benchmarks according to the improvement/regression ratio and
    returns a markdown table for it.

    Rows are ordered from the largest to the smallest relative change of the
    mean latency.

    Args:
      benchmarks: map of (series_id, benchmark object). All objects must
        contain the base mean time.
      size_cut: If not None, only show the top N results for each table.

    Raises:
      AssertionError: if a benchmark has no base mean time.
    """
    sorted_rows = []
    for series_id, benchmark in benchmarks.items():
        if benchmark.base_mean_time is None:
            raise AssertionError("Base can't be None for sorting.")

        # Times are divided by 1e6 to get the values of the "(ms)" columns.
        current = benchmark.mean_time / 1e6
        base = benchmark.base_mean_time / 1e6
        if base != 0:
            ratio = abs(current - base) / base
        else:
            # Avoid dividing by zero when computing the sort key.
            ratio = 0.0 if current == base else math.inf
        str_mean = _get_compare_text(current, base)
        clickable_name = _make_series_link(benchmark.name, series_id)
        sorted_rows.append(
            (
                ratio,
                (
                    clickable_name,
                    str_mean,
                    _get_fixed_point_str(benchmark.median_time / 1e6),
                    _get_fixed_point_str(benchmark.stddev_time / 1e6),
                ),
            )
        )
    # Largest change first; only the ratio takes part in the comparison.
    sorted_rows.sort(key=lambda row: row[0], reverse=True)

    return _add_header_and_get_markdown_table(
        headers=BENCHMARK_RESULTS_HEADERS,
        rows=[row[1] for row in sorted_rows],
        size_cut=size_cut,
    )


def categorize_benchmarks_into_tables(
    benchmarks: Dict[str, AggregateBenchmarkLatency], size_cut: Optional[int] = None
) -> str:
    """Renders the markdown latency tables of the given benchmarks.

    Benchmarks are grouped by how their mean latency compares to the base:
    regressed, improved, similar, and raw (no base to compare to). The similar
    group is left out when size_cut is given.

    Args:
      benchmarks: map of (series id, aggregate latency object).
      size_cut: If not None, only show the top N results for each table.
    """

    def _mean_and_base(results):
        return (results.mean_time, results.base_mean_time)

    regressed, improved, similar, raw = _categorize_on_single_metric(
        benchmarks, _mean_and_base, BENCHMARK_THRESHOLDS, "ns"
    )

    compared_sections = [
        ("Regressed Latencies 🚩", regressed),
        ("Improved Latencies 🎉", improved),
    ]
    # If we want to abbreviate, similar results won't be interesting.
    if size_cut is None:
        compared_sections.append(("Similar Latencies", similar))

    tables = []
    for title, group in compared_sections:
        if not group:
            continue
        tables.append(md.header(title, 3))
        tables.append(_sort_benchmarks_and_get_table(group, size_cut))

    if raw:
        tables.append(md.header("Raw Latencies", 3))
        # Without a base there is nothing to compare: show the plain numbers.
        raw_rows = []
        for series_id, latency in raw.items():
            raw_rows.append(
                (
                    _make_series_link(name=latency.name, series_id=series_id),
                    f"{_get_fixed_point_str(latency.mean_time / 1e6)}",
                    f"{_get_fixed_point_str(latency.median_time / 1e6)}",
                    f"{_get_fixed_point_str(latency.stddev_time / 1e6)}",
                )
            )
        tables.append(
            _add_header_and_get_markdown_table(
                BENCHMARK_RESULTS_HEADERS, raw_rows, size_cut=size_cut
            )
        )
    return "\n\n".join(tables)


def _sort_metrics_objects_and_get_table(
    metrics_objs: Dict[str, T],
    mapper: MetricsToTableMapper[T],
    headers: Sequence[str],
    size_cut: Optional[int] = None,
) -> str:
    """Sorts all metrics objects according to the improvement/regression ratio and
    returns a markdown table for it.

    Rows are ordered from the largest to the smallest relative change of the
    metric selected by `mapper`.

    Args:
      metrics_objs: map of (target_id, metrics object). All objects must
        contain base value.
      mapper: MetricsToTableMapper for metrics_objs.
      headers: list of table headers.
      size_cut: If not None, only show the top N results for each table.

    Raises:
      AssertionError: if a metrics object has no base value.
    """
    sorted_rows = []
    for target_id, metrics_obj in metrics_objs.items():
        current, base = mapper.get_current_and_base_value(metrics_obj)
        if base is None:
            raise AssertionError("Base can't be None for sorting.")
        if base != 0:
            ratio = abs(current - base) / base
        else:
            # Avoid dividing by zero when computing the sort key: a change
            # away from zero is treated as unbounded.
            ratio = 0.0 if current == base else math.inf
        sorted_rows.append(
            (
                ratio,
                (
                    _make_series_link(
                        str(metrics_obj), mapper.get_series_id(target_id)
                    ),
                    _get_compare_text(current, base),
                ),
            )
        )
    # Largest change first; only the ratio takes part in the comparison.
    sorted_rows.sort(key=lambda row: row[0], reverse=True)

    return _add_header_and_get_markdown_table(
        headers=headers, rows=[row[1] for row in sorted_rows], size_cut=size_cut
    )


def categorize_compilation_metrics_into_tables(
    compile_metrics_map: Dict[str, CompilationMetrics], size_cut: Optional[int] = None
) -> str:
    """Renders the markdown tables of the given compilation metrics.

    For every compilation metric a regressed and an improved table is
    generated (when not empty). If size_cut is None, a final table listing
    all metrics of all targets is added as well.

    Args:
      compile_metrics_map: map of (target id, compilation metrics object).
      size_cut: If not None, only show the top N results for each table.
    """

    tables = []
    for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS:
        regressed, improved, _, _ = _categorize_on_single_metric(
            compile_metrics_map,
            mapper.get_current_and_base_value,
            mapper.get_metric_thresholds(),
            mapper.get_unit(),
        )

        table_title = mapper.get_table_title()
        table_header = mapper.get_table_header()
        changed_sections = (
            (f"Regressed {table_title} 🚩", regressed),
            (f"Improved {table_title} 🎉", improved),
        )
        for section_title, group in changed_sections:
            if not group:
                continue
            tables.append(md.header(section_title, 3))
            tables.append(
                _sort_metrics_objects_and_get_table(
                    metrics_objs=group,
                    mapper=mapper,
                    headers=["Benchmark Name", table_header],
                    size_cut=size_cut,
                )
            )

    # The table of all metrics is only shown when the output isn't abbreviated.
    if size_cut is not None or not compile_metrics_map:
        return "\n\n".join(tables)

    tables.append(md.header("All Compilation Metrics", 3))
    headers = ["Benchmark Name"]
    headers.extend(
        mapper.get_table_header() for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS
    )

    rows = []
    for target_id, metrics in compile_metrics_map.items():
        # One cell per metric, each linking to the series of that metric.
        cells = [metrics.name]
        for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS:
            current, base = mapper.get_current_and_base_value(metrics)
            compare_text = _get_compare_text(current, base)
            cells.append(
                _make_series_link(compare_text, mapper.get_series_id(target_id))
            )
        rows.append(tuple(cells))

    tables.append(_add_header_and_get_markdown_table(headers, rows, size_cut=size_cut))

    return "\n\n".join(tables)
