#!/usr/bin/env python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Posts benchmark results to gist and comments on pull requests.

Requires the environment variables:

- GITHUB_TOKEN: token from GitHub action that has write access on issues. See
  https://docs.github.com/en/actions/security-guides/automatic-token-authentication#permissions-for-the-github_token
- COMMENT_BOT_USER: name of the user that posts the comment. Note that this can
  be different from the user who creates the gist.
- GIST_BOT_TOKEN: token that has write access to gist. Gist will be posted as
  the owner of the token. See
  https://docs.github.com/en/rest/overview/permissions-required-for-fine-grained-personal-access-tokens#gists
"""

import sys
import pathlib

# Add build_tools python dir to the search path.
sys.path.insert(0, str(pathlib.Path(__file__).parent.with_name("python")))

import argparse
import http.client
import json
import os
import requests
from typing import Any, Optional

from reporting import benchmark_comment

# Base URL of the iree-org/iree repository REST endpoints; the issue comment
# and pull request URLs used below are built on top of it.
GITHUB_IREE_API_PREFIX = "https://api.github.com/repos/iree-org/iree"
# Endpoint that new gists are POSTed to.
GITHUB_GIST_API = "https://api.github.com/gists"
# Value sent in the "X-GitHub-Api-Version" header of every request.
GITHUB_API_VERSION = "2022-11-28"


class APIRequester(object):
    """REST API client that injects proper GitHub authentication headers.

    Every request goes through one shared `requests` session and carries the
    same "Accept", "Authorization" and "X-GitHub-Api-Version" headers.
    """

    def __init__(self, github_token: str):
        """Creates a requester that authenticates with `github_token`."""
        self._api_headers = {
            "Accept": "application/vnd.github+json",
            "Authorization": f"token {github_token}",
            "X-GitHub-Api-Version": GITHUB_API_VERSION,
        }
        self._session = requests.session()

    @staticmethod
    def _to_json_body(payload: Optional[Any]) -> str:
        """Serializes `payload` to a JSON body; no payload becomes "{}"."""
        return json.dumps({} if payload is None else payload)

    def get(self, endpoint: str, payload: Optional[Any] = None) -> requests.Response:
        """Sends a GET request to `endpoint`.

        Args:
          endpoint: full URL to query.
          payload: optional mapping of query parameters (e.g. "per_page",
            "page").

        Returns:
          The raw response; callers are responsible for checking its status.
        """
        # GitHub's REST API takes GET parameters from the query string, not
        # from a request body, so the payload must be passed as `params`.
        # Sending it as a JSON body leaves pagination parameters unapplied.
        return self._session.get(
            endpoint, params=payload, headers=self._api_headers
        )

    def post(self, endpoint: str, payload: Optional[Any] = None) -> requests.Response:
        """Sends a POST request to `endpoint` with `payload` as the JSON body.

        Returns:
          The raw response; callers are responsible for checking its status.
        """
        return self._session.post(
            endpoint, data=self._to_json_body(payload), headers=self._api_headers
        )

    def patch(self, endpoint: str, payload: Optional[Any] = None) -> requests.Response:
        """Sends a PATCH request to `endpoint` with `payload` as the JSON body.

        Returns:
          The raw response; callers are responsible for checking its status.
        """
        return self._session.patch(
            endpoint, data=self._to_json_body(payload), headers=self._api_headers
        )


class GithubClient(object):
    """Helper to call Github REST APIs.

    All calls are issued through the requester given at construction time, so
    the identity used on GitHub is whichever token that requester holds.
    """

    def __init__(self, requester: APIRequester):
        """Creates a client that sends its requests through `requester`."""
        self._requester = requester

    def post_to_gist(self, filename: str, content: str, verbose: bool = False) -> str:
        """Posts the given content to a new GitHub Gist and returns the URL to it.

        Args:
          filename: name of the single file stored in the gist.
          content: text content of that file.
          verbose: print the decoded API response when true.

        Returns:
          The "html_url" of the created (public) gist.

        Raises:
          RuntimeError: if the gist is not created or GitHub reports the
            content as truncated.
        """

        response = self._requester.post(
            endpoint=GITHUB_GIST_API,
            payload={"public": True, "files": {filename: {"content": content}}},
        )
        if response.status_code != http.client.CREATED:
            raise RuntimeError(
                f"Failed to create on gist; error code: {response.status_code} - {response.text}"
            )

        gist = response.json()
        if verbose:
            print(f"Gist posting response: {gist}")

        if gist["truncated"]:
            raise RuntimeError("Content is too large and was truncated")

        return gist["html_url"]

    def get_previous_comment_on_pr(
        self,
        pr_number: int,
        comment_bot_user: str,
        comment_type_id: str,
        query_comment_per_page: int = 100,
        max_pages_to_search: int = 10,
        verbose: bool = False,
    ) -> Optional[int]:
        """Gets the previous comment's id from GitHub.

        Args:
          pr_number: number of the pull request whose comments are searched.
          comment_bot_user: login of the user that must have written the
            comment.
          comment_type_id: marker string that must appear in the comment body.
          query_comment_per_page: page size requested from GitHub.
          max_pages_to_search: upper bound on the number of pages fetched.
          verbose: print each decoded page when true.

        Returns:
          The id of the first matching comment, or None if none was found in
          the searched pages.

        Raises:
          RuntimeError: if any page request does not return HTTP 200.
        """

        for page in range(1, max_pages_to_search + 1):
            response = self._requester.get(
                endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments",
                payload={
                    "per_page": query_comment_per_page,
                    "page": page,
                    "sort": "updated",
                    "direction": "desc",
                },
            )
            if response.status_code != http.client.OK:
                raise RuntimeError(
                    f"Failed to get PR comments from GitHub; error code: {response.status_code} - {response.text}"
                )

            comments = response.json()
            if verbose:
                print(f"Previous comment query response on page {page}: {comments}")

            # Return the first match in the order GitHub returned the comments.
            # NOTE(review): the request asks for most-recently-updated first;
            # confirm that this endpoint honours "sort"/"direction".
            for comment in comments:
                # GitHub's schema allows a null "user" and a missing "body";
                # treat such comments as non-matching instead of crashing.
                author = (comment.get("user") or {}).get("login")
                body = comment.get("body") or ""
                if author == comment_bot_user and comment_type_id in body:
                    return comment["id"]

            # A short page means there are no further pages to fetch.
            if len(comments) < query_comment_per_page:
                break

        return None

    def update_comment_on_pr(self, comment_id: int, content: str) -> None:
        """Updates the content of the given comment id.

        Raises:
          RuntimeError: if GitHub does not answer with HTTP 200.
        """

        response = self._requester.patch(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/comments/{comment_id}",
            payload={"body": content},
        )
        if response.status_code != http.client.OK:
            raise RuntimeError(
                f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}"
            )

    def create_comment_on_pr(self, pr_number: int, content: str) -> None:
        """Posts the given content as comments to the current pull request.

        Raises:
          RuntimeError: if GitHub does not answer with HTTP 201.
        """

        response = self._requester.post(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments",
            payload={"body": content},
        )
        if response.status_code != http.client.CREATED:
            raise RuntimeError(
                f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}"
            )

    def get_pull_request_head_commit(self, pr_number: int) -> str:
        """Get pull request head commit SHA.

        Returns:
          The "head.sha" field of the pull request.

        Raises:
          RuntimeError: if GitHub does not answer with HTTP 200.
        """

        response = self._requester.get(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/pulls/{pr_number}"
        )
        if response.status_code != http.client.OK:
            raise RuntimeError(
                f"Failed to fetch the pull request: {pr_number}; "
                f"error code: {response.status_code} - {response.text}"
            )

        return response.json()["head"]["sha"]


def _parse_arguments():
    """Parses the command line of this script.

    Returns:
      The parsed namespace with `comment_json`, `verbose` and
      `github_event_json`.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("comment_json", type=pathlib.Path)
    arg_parser.add_argument("--verbose", action="store_true")

    # Exactly one option from this group has to be supplied; it currently
    # holds a single option.
    verification_group = arg_parser.add_mutually_exclusive_group(required=True)
    verification_group.add_argument("--github_event_json", type=pathlib.Path)

    parsed_args = arg_parser.parse_args()
    return parsed_args


def main(args: argparse.Namespace):
    """Publishes the benchmark results and comments on the pull request.

    Reads the comment data from `args.comment_json`, optionally verifies it
    against the workflow event in `args.github_event_json`, uploads the full
    report as a gist and then creates or updates the abbreviated comment on
    the pull request.

    Raises:
      ValueError: if a required environment variable is unset, or if the
        workflow run SHA does not match the pull request head SHA.
    """
    # All three settings are mandatory; fail on the first one that is missing.
    settings = {}
    for env_name in ("GITHUB_TOKEN", "COMMENT_BOT_USER", "GIST_BOT_TOKEN"):
        env_value = os.environ.get(env_name)
        if env_value is None:
            raise ValueError(f"{env_name} must be set.")
        settings[env_name] = env_value

    raw_comment = json.loads(args.comment_json.read_text())
    comment_data = benchmark_comment.CommentData(**raw_comment)
    # The PR number is untrusted input; force it to be an integer.
    pr_number = int(comment_data.unverified_pr_number)

    pr_requester = APIRequester(github_token=settings["GITHUB_TOKEN"])
    pr_client = GithubClient(requester=pr_requester)

    github_event = None
    if args.github_event_json is not None:
        github_event = json.loads(args.github_event_json.read_text())
        workflow_run_sha = github_event["workflow_run"]["head_sha"]
        pr_head_sha = pr_client.get_pull_request_head_commit(pr_number=pr_number)
        # The GitHub API offers no trusted way to map a workflow run to its PR
        # number, so the untrusted PR number from the presubmit workflow is
        # used and checked instead: the PR's current head SHA must equal the
        # commit SHA of the workflow run.
        #
        # The assumption is that producing malicious comment data requires
        # changing the code, which yields a new commit SHA. If the PR head then
        # matches the workflow run of that commit, either the attacker owns the
        # PR or someone else's PR already contains the malicious commit. In
        # both cases posting the comment is acceptable.
        #
        # Caveats:
        # - A collision on a chosen SHA1 is possible, although GitHub has some
        #   protections in place
        #   (https://github.blog/2017-03-20-sha-1-collision-detection-on-github-com/).
        # - The assumption only holds if files in GCS can't be overwritten, so
        #   that comment data can't change without a code change.
        # - The check also fails when the author pushes a new commit after the
        #   workflow was triggered. Pushing cancels the current CI run,
        #   benchmarking included, so this is unlikely to be the failure cause.
        if workflow_run_sha != pr_head_sha:
            raise ValueError(
                f"Workflow run SHA: {workflow_run_sha} does not match "
                f"the head SHA: {pr_head_sha} of the pull request: {pr_number}."
            )

    # The gist is created with its own token, so its owner may differ from the
    # account that comments on the PR.
    gist_requester = APIRequester(github_token=settings["GIST_BOT_TOKEN"])
    gist_client = GithubClient(requester=gist_requester)
    gist_url = gist_client.post_to_gist(
        filename=f"iree-full-benchmark-results-{pr_number}.md",
        content=comment_data.full_md,
        verbose=args.verbose,
    )

    previous_comment_id = pr_client.get_previous_comment_on_pr(
        pr_number=pr_number,
        comment_bot_user=settings["COMMENT_BOT_USER"],
        comment_type_id=comment_data.type_id,
        verbose=args.verbose,
    )

    # Point the abbreviated report at the freshly created gist.
    abbr_md = comment_data.abbr_md.replace(
        benchmark_comment.GIST_LINK_PLACEHORDER, gist_url
    )
    if github_event is not None:
        source_run_url = github_event["workflow_run"]["html_url"]
        abbr_md += f"\n\n[Source Workflow Run]({source_run_url})"

    # Reuse the existing bot comment when there is one; otherwise add a new one.
    if previous_comment_id is None:
        pr_client.create_comment_on_pr(pr_number=pr_number, content=abbr_md)
    else:
        pr_client.update_comment_on_pr(comment_id=previous_comment_id, content=abbr_md)


# Script entry point: parse the command line and run the posting flow. Nothing
# is executed when this file is imported as a module.
if __name__ == "__main__":
    main(_parse_arguments())
