blob: 6564b0db7d1311ad639d28b119cc7b7da9bc3b7e [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import http.client
import requests
import unittest
from unittest import mock
from typing import Any
import post_benchmark_comment
class GithubClientTest(unittest.TestCase):
    """Tests for `post_benchmark_comment.GithubClient`.

    The client is always constructed around an autospec'd
    `post_benchmark_comment.APIRequester`, so these tests only check which
    endpoints and payloads the client hands to the requester and how it
    interprets the mocked responses.
    """

    def setUp(self):
        # A single mocked response is returned by every requester verb; each
        # test configures its status code and JSON body as needed.
        self._mock_response = mock.create_autospec(requests.Response)
        self._mock_requester = mock.create_autospec(
            post_benchmark_comment.APIRequester)
        self._mock_requester.get.return_value = self._mock_response
        self._mock_requester.post.return_value = self._mock_response
        self._mock_requester.patch.return_value = self._mock_response

    @staticmethod
    def _make_response(status_code: int, json_body: Any):
        """Returns an autospec'd `requests.Response` with canned status/JSON."""
        response = mock.create_autospec(requests.Response)
        response.status_code = status_code
        response.json.return_value = json_body
        return response

    @staticmethod
    def _gist_payload(filename: str, content: str) -> dict:
        """Returns the payload expected for a public gist holding one file."""
        return {"public": True, "files": {filename: {"content": content}}}

    @staticmethod
    def _comments_endpoint(pr_number: int) -> str:
        """Returns the issue-comments endpoint expected for `pr_number`."""
        return (f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}"
                f"/issues/{pr_number}/comments")

    @staticmethod
    def _list_comments_payload(per_page: int, page: int) -> dict:
        """Returns the query payload expected when listing one comment page."""
        return {
            "per_page": per_page,
            "page": page,
            "sort": "updated",
            "direction": "desc"
        }

    def test_post_to_gist(self):
        """A non-truncated upload returns the gist's `html_url`."""
        gist_url = "https://example.com/123455/1234.md"
        self._mock_response.status_code = http.client.CREATED
        self._mock_response.json.return_value = {
            "html_url": gist_url,
            "truncated": False
        }
        client = post_benchmark_comment.GithubClient(self._mock_requester)

        url = client.post_to_gist(filename="1234.md", content="xyz")

        self.assertEqual(url, gist_url)
        self._mock_requester.post.assert_called_once_with(
            endpoint=post_benchmark_comment.GITHUB_GIST_API,
            payload=self._gist_payload("1234.md", "xyz"))

    def test_post_to_gist_truncated(self):
        """A response flagged as truncated raises `RuntimeError`."""
        self._mock_response.status_code = http.client.CREATED
        self._mock_response.json.return_value = {
            "html_url": "example.com/123455/1234.md",
            "truncated": True
        }
        client = post_benchmark_comment.GithubClient(self._mock_requester)

        with self.assertRaises(RuntimeError):
            client.post_to_gist(filename="1234.md", content="xyz")

        # The truncation flag only exists in the response, so the upload
        # request should still have been issued exactly once.
        self._mock_requester.post.assert_called_once_with(
            endpoint=post_benchmark_comment.GITHUB_GIST_API,
            payload=self._gist_payload("1234.md", "xyz"))

    def test_get_previous_comment_on_pr(self):
        """Finds the bot's matching comment on a later page."""
        # Page 1 holds a bot comment of another type and a matching comment
        # from a non-bot user; only page 2 has the bot's matching comment.
        first_mock_response = self._make_response(http.client.OK, [{
            "id": 1,
            "user": {"login": "bot"},
            "body": "comment id: abcd"
        }, {
            "id": 2,
            "user": {"login": "user"},
            "body": "comment id: 1234"
        }])
        second_mock_response = self._make_response(http.client.OK, [{
            "id": 3,
            "user": {"login": "bot"},
            "body": "comment id: 1234"
        }])
        mock_requester = mock.create_autospec(
            post_benchmark_comment.APIRequester)
        mock_requester.get.side_effect = [
            first_mock_response, second_mock_response
        ]
        client = post_benchmark_comment.GithubClient(mock_requester)

        comment_id = client.get_previous_comment_on_pr(pr_number=23,
                                                       comment_bot_user="bot",
                                                       comment_type_id="1234",
                                                       query_comment_per_page=2,
                                                       max_pages_to_search=10)

        self.assertEqual(comment_id, 3)
        self.assertEqual(mock_requester.get.call_count, 2)
        endpoint_url = self._comments_endpoint(23)
        for page in (1, 2):
            mock_requester.get.assert_any_call(
                endpoint=endpoint_url,
                payload=self._list_comments_payload(per_page=2, page=page))

    def test_get_previous_comment_on_pr_not_found(self):
        """Expects None once all `max_pages_to_search` pages are exhausted."""
        # Every page only contains a bot comment of a different type.
        mock_response = self._make_response(http.client.OK, [{
            "id": 1,
            "user": {"login": "bot"},
            "body": "comment id: 5678"
        }])
        mock_requester = mock.create_autospec(
            post_benchmark_comment.APIRequester)
        mock_requester.get.side_effect = [mock_response] * 10
        client = post_benchmark_comment.GithubClient(mock_requester)

        comment_id = client.get_previous_comment_on_pr(pr_number=23,
                                                       comment_bot_user="bot",
                                                       comment_type_id="1234",
                                                       query_comment_per_page=1,
                                                       max_pages_to_search=10)

        self.assertIsNone(comment_id)
        self.assertEqual(mock_requester.get.call_count, 10)
        endpoint_url = self._comments_endpoint(23)
        # Spot-check the first and last pages of the search.
        for page in (1, 10):
            mock_requester.get.assert_any_call(
                endpoint=endpoint_url,
                payload=self._list_comments_payload(per_page=1, page=page))

    def test_update_comment_on_pr(self):
        """Updating a comment PATCHes its body by comment id."""
        self._mock_response.status_code = http.client.OK
        client = post_benchmark_comment.GithubClient(self._mock_requester)

        client.update_comment_on_pr(comment_id=123, content="xyz")

        self._mock_requester.patch.assert_called_once_with(
            endpoint=(f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}"
                      "/issues/comments/123"),
            payload={"body": "xyz"})

    def test_create_comment_on_pr(self):
        """Creating a comment POSTs its body to the PR's comments endpoint."""
        self._mock_response.status_code = http.client.CREATED
        client = post_benchmark_comment.GithubClient(self._mock_requester)

        client.create_comment_on_pr(pr_number=1234, content="xyz")

        self._mock_requester.post.assert_called_once_with(
            endpoint=self._comments_endpoint(1234), payload={"body": "xyz"})

    def test_get_pull_request_head_commit(self):
        """The PR's head commit SHA is read from `head.sha`."""
        self._mock_response.status_code = http.client.OK
        self._mock_response.json.return_value = {"head": {"sha": "sha123"}}
        client = post_benchmark_comment.GithubClient(self._mock_requester)

        commit_sha = client.get_pull_request_head_commit(pr_number=123)

        self.assertEqual(commit_sha, "sha123")
        self._mock_requester.get.assert_called_once_with(
            endpoint=f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/pulls/123")
if __name__ == "__main__":
    # Run the tests in this module when it is executed as a script.
    unittest.main()