blob: a6e880b4b02f9571f2ae600eb8560cdfa33fa41e [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Posts benchmark results to gist and comments on pull requests.
Requires the environment variables:
- GITHUB_TOKEN: token from GitHub action that has write access on issues. See
https://docs.github.com/en/actions/security-guides/automatic-token-authentication#permissions-for-the-github_token
- COMMENT_BOT_USER: user name that posts the comment. Note this can be different
from the user who creates the gist.
- GIST_BOT_TOKEN: token that has write access to gist. Gist will be posted as
the owner of the token. See
https://docs.github.com/en/rest/overview/permissions-required-for-fine-grained-personal-access-tokens#gists
"""
import sys
import pathlib
# Add build_tools python dir to the search path.
sys.path.insert(0, str(pathlib.Path(__file__).parent.with_name("python")))
import argparse
import http.client
import json
import os
import requests
from typing import Any, Optional
from reporting import benchmark_comment
# Base URL of the GitHub REST API for the iree-org/iree repository; the issue
# comment and pull request endpoints used in this file are built on top of it.
GITHUB_IREE_API_PREFIX = "https://api.github.com/repos/iree-org/iree"
# Endpoint that new gists are posted to.
GITHUB_GIST_API = "https://api.github.com/gists"
# Value sent in the "X-GitHub-Api-Version" header of every request.
GITHUB_API_VERSION = "2022-11-28"
class APIRequester(object):
    """REST API client that injects proper GitHub authentication headers.

    All requests share one `requests` session and the same header set: the
    GitHub JSON media type, a token-based `Authorization` header and the pinned
    `X-GitHub-Api-Version`.
    """

    def __init__(self, github_token: str):
        """Creates a requester that authenticates with `github_token`."""
        self._api_headers = {
            "Accept": "application/vnd.github+json",
            "Authorization": f"token {github_token}",
            "X-GitHub-Api-Version": GITHUB_API_VERSION,
        }
        self._session = requests.session()

    @staticmethod
    def _encode_body(payload: Any) -> str:
        """Serializes `payload` to JSON; a missing payload becomes `{}`."""
        return json.dumps({} if payload is None else payload)

    def get(self, endpoint: str, payload: Any = None) -> requests.Response:
        """Sends a GET request with `payload` encoded as query parameters.

        GET parameters (e.g. `per_page`, `page`) must travel in the query string;
        putting them in a JSON request body means the server does not see them.
        """
        return self._session.get(
            endpoint,
            params={} if payload is None else payload,
            headers=self._api_headers,
        )

    def post(self, endpoint: str, payload: Any = None) -> requests.Response:
        """Sends a POST request with `payload` as the JSON request body."""
        return self._session.post(
            endpoint, data=self._encode_body(payload), headers=self._api_headers
        )

    def patch(self, endpoint: str, payload: Any = None) -> requests.Response:
        """Sends a PATCH request with `payload` as the JSON request body."""
        return self._session.patch(
            endpoint, data=self._encode_body(payload), headers=self._api_headers
        )
class GithubClient(object):
    """Helper to call GitHub REST APIs.

    Every call goes through the injected requester, which must provide `get`,
    `post` and `patch` methods accepting `endpoint` and `payload` keyword
    arguments and returning an object with `status_code`, `text` and `json()`.
    """

    def __init__(self, requester: "APIRequester"):
        self._requester = requester

    @staticmethod
    def _check_status(response: Any, expected_status: int, failure: str) -> None:
        """Raises RuntimeError starting with `failure` on an unexpected status."""
        if response.status_code != expected_status:
            raise RuntimeError(
                f"{failure}; error code: {response.status_code} - {response.text}"
            )

    def post_to_gist(self, filename: str, content: str, verbose: bool = False) -> str:
        """Posts the given content to a new GitHub Gist and returns the URL to it.

        Raises:
          RuntimeError: if the gist is not created or its content was truncated.
        """
        response = self._requester.post(
            endpoint=GITHUB_GIST_API,
            payload={"public": True, "files": {filename: {"content": content}}},
        )
        self._check_status(response, http.client.CREATED, "Failed to create on gist")

        gist = response.json()
        if verbose:
            print(f"Gist posting response: {gist}")

        if gist["truncated"]:
            raise RuntimeError("Content is too large and was truncated")

        return gist["html_url"]

    def get_previous_comment_on_pr(
        self,
        pr_number: int,
        comment_bot_user: str,
        comment_type_id: str,
        query_comment_per_page: int = 100,
        max_pages_to_search: int = 10,
        verbose: bool = False,
    ) -> Optional[int]:
        """Gets the previous comment's id from GitHub.

        Searches at most `max_pages_to_search` pages of `query_comment_per_page`
        comments for one posted by `comment_bot_user` whose body contains
        `comment_type_id`.

        Returns:
          The id of the first matching comment, or None if there is no match.

        Raises:
          RuntimeError: if a page of comments can't be fetched.
        """
        for page in range(1, max_pages_to_search + 1):
            response = self._requester.get(
                endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments",
                payload={
                    "per_page": query_comment_per_page,
                    "page": page,
                    "sort": "updated",
                    "direction": "desc",
                },
            )
            self._check_status(
                response, http.client.OK, "Failed to get PR comments from GitHub"
            )

            comments = response.json()
            if verbose:
                print(f"Previous comment query response on page {page}: {comments}")

            # Comments are scanned in the order the API returned them; the first
            # match wins.
            for comment in comments:
                # The author or the body can be null/absent; such comments can
                # never match, so skip them instead of crashing.
                author = comment.get("user") or {}
                body = comment.get("body") or ""
                if author.get("login") == comment_bot_user and comment_type_id in body:
                    return comment["id"]

            # A short page means there are no further comments to fetch.
            if len(comments) < query_comment_per_page:
                break

        return None

    def update_comment_on_pr(self, comment_id: int, content: str):
        """Updates the content of the given comment id.

        Raises:
          RuntimeError: if GitHub does not answer with 200 OK.
        """
        response = self._requester.patch(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/comments/{comment_id}",
            payload={"body": content},
        )
        self._check_status(response, http.client.OK, "Failed to comment on GitHub")

    def create_comment_on_pr(self, pr_number: int, content: str):
        """Posts the given content as comments to the current pull request.

        Raises:
          RuntimeError: if GitHub does not answer with 201 Created.
        """
        response = self._requester.post(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments",
            payload={"body": content},
        )
        self._check_status(
            response, http.client.CREATED, "Failed to comment on GitHub"
        )

    def get_pull_request_head_commit(self, pr_number: int) -> str:
        """Get pull request head commit SHA.

        Raises:
          RuntimeError: if the pull request can't be fetched.
        """
        response = self._requester.get(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/pulls/{pr_number}"
        )
        self._check_status(
            response,
            http.client.OK,
            f"Failed to fetch the pull request: {pr_number}",
        )
        return response.json()["head"]["sha"]
def _parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("comment_json", type=pathlib.Path)
parser.add_argument("--verbose", action="store_true")
verification_parser = parser.add_mutually_exclusive_group(required=True)
verification_parser.add_argument("--github_event_json", type=pathlib.Path)
return parser.parse_args()
def _require_env(name: str) -> str:
    """Returns the value of the environment variable `name`.

    Raises:
      ValueError: if the variable is unset or empty.
    """
    value = os.environ.get(name)
    # An empty value is no more usable than a missing one; reject both here.
    if not value:
        raise ValueError(f"{name} must be set.")
    return value


def main(args: argparse.Namespace):
    """Posts the benchmark results to a gist and comments on the pull request.

    Reads the comment data from `args.comment_json`, optionally verifies it
    against the workflow run described by `args.github_event_json`, uploads the
    full report as a gist and then updates the bot's previous comment on the
    pull request, or creates a new one if none is found.

    Args:
      args: parsed command-line arguments (`comment_json`, `github_event_json`
        and `verbose` are read).

    Raises:
      ValueError: if a required environment variable is unset or empty, or if
        the workflow run SHA doesn't match the pull request's head SHA.
      RuntimeError: propagated from GithubClient when a GitHub API call fails.
    """
    github_token = _require_env("GITHUB_TOKEN")
    comment_bot_user = _require_env("COMMENT_BOT_USER")
    gist_bot_token = _require_env("GIST_BOT_TOKEN")

    comment_data = benchmark_comment.CommentData(
        **json.loads(args.comment_json.read_text())
    )
    # Sanitize the pr number to make sure it is an integer.
    pr_number = int(comment_data.unverified_pr_number)

    pr_client = GithubClient(requester=APIRequester(github_token=github_token))
    if args.github_event_json is None:
        github_event = None
    else:
        github_event = json.loads(args.github_event_json.read_text())
        workflow_run_sha = github_event["workflow_run"]["head_sha"]
        pr_head_sha = pr_client.get_pull_request_head_commit(pr_number=pr_number)
        # We can't get the trusted PR number of a workflow run from the GitHub
        # API. So we take the untrusted PR number from the presubmit workflow and
        # verify that the PR's current head SHA matches the commit SHA in the
        # workflow run. This assumes that, to generate malicious comment data,
        # an attacker must modify the code and therefore has a new commit SHA.
        # So if the PR head commit matches the workflow run with the attacker's
        # commit, either the PR was created by the attacker or someone else's PR
        # contains the malicious commit. In both cases posting the malicious
        # comment is acceptable.
        #
        # Note that a collision on a target SHA1 is possible, but GitHub has
        # some protections
        # (https://github.blog/2017-03-20-sha-1-collision-detection-on-github-com/).
        # The assumption also only holds if files in GCS can't be overwritten
        # (so the comment data can't be modified without changing the code).
        # The check will also fail if the PR author pushes a new commit after
        # the workflow is triggered. But pushing a new commit cancels the
        # current CI run, including the benchmarking, so it is unlikely to fail
        # for that reason.
        if workflow_run_sha != pr_head_sha:
            raise ValueError(
                f"Workflow run SHA: {workflow_run_sha} does not match "
                f"the head SHA: {pr_head_sha} of the pull request: {pr_number}."
            )

    # The gist is posted with its own token, so it is owned by that token's user.
    gist_client = GithubClient(requester=APIRequester(github_token=gist_bot_token))
    gist_url = gist_client.post_to_gist(
        filename=f"iree-full-benchmark-results-{pr_number}.md",
        content=comment_data.full_md,
        verbose=args.verbose,
    )

    previous_comment_id = pr_client.get_previous_comment_on_pr(
        pr_number=pr_number,
        comment_bot_user=comment_bot_user,
        comment_type_id=comment_data.type_id,
        verbose=args.verbose,
    )

    # Fill the link to the full report into the abbreviated comment.
    abbr_md = comment_data.abbr_md.replace(
        benchmark_comment.GIST_LINK_PLACEHORDER, gist_url
    )
    if github_event is not None:
        abbr_md += (
            f'\n\n[Source Workflow Run]({github_event["workflow_run"]["html_url"]})'
        )

    # Update the existing comment if there is one, otherwise post a new one.
    if previous_comment_id is not None:
        pr_client.update_comment_on_pr(comment_id=previous_comment_id, content=abbr_md)
    else:
        pr_client.create_comment_on_pr(pr_number=pr_number, content=abbr_md)
# Script entry point: parse the command line, then run the posting flow.
if __name__ == "__main__":
    parsed_args = _parse_arguments()
    main(parsed_args)