diff --git a/revup/github/github.py b/revup/github/github.py index b82d222..b237dae 100644 --- a/revup/github/github.py +++ b/revup/github/github.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from revup.forge import ( MAX_COMMENTS_TO_QUERY, @@ -10,151 +10,10 @@ PrUpdate, ) from revup.github.endpoint import GitHubEndpoint +from revup.github.graphql import GraphqlQuery, QueryGroup from revup.types import RevupForgeException - -def _get_args_dict(args: List[Any], prefix: str) -> Dict[str, Any]: - return {f"{prefix}{n}": arg for n, arg in enumerate(args)} - - -def _get_args_declaration(args: Dict[str, Any], typ: str) -> List[str]: - return [f"${var}: {typ}" for var in args] - - -def _get_result_args(num: int, prefix: str) -> List[str]: - return [f"{prefix}{n}" for n in range(num)] - - -def _zip_and_flatten(l1: Iterable[str], l2: Iterable[str]) -> List[str]: - ret: List[str] = [] - iter1 = iter(l1) - iter2 = iter(l2) - while True: - try: - ret.append(next(iter1)) - ret.append(next(iter2)) - except StopIteration: - break - return ret - - -class Github(Forge): - def __init__( - self, - endpoint: GitHubEndpoint, - repo_info: ForgeRepoInfo, - fork_info: ForgeRepoInfo, - ): - self.endpoint = endpoint - self.repo_info = repo_info - self.fork_info = fork_info - - @property - def repo_owner(self) -> str: - return self.fork_info.owner - - @property - def repo_name(self) -> str: - return self.repo_info.name - - @property - def is_fork(self) -> bool: - return self.fork_info.owner != self.repo_info.owner - - async def close(self) -> None: - await self.endpoint.close() - - async def query_everything( - self, - head_refs: List[str], - user_ids: List[str], - labels: List[str], - teams: List[Tuple[str, str]], - ) -> Tuple[ - str, - List[Optional[PrInfo]], - Dict[str, str], - Dict[str, str], - Dict[str, str], - Dict[str, str], - Dict[str, Optional[Set[str]]], - ]: - head_refs_args = _get_args_dict(head_refs, "pr") - user_id_args = _get_args_dict(user_ids, "user") - label_args = _get_args_dict(labels, "label") - team_org_args = _get_args_dict([t[0] for t in teams], "team_org") - team_slug_args = _get_args_dict([t[1] for t in teams], "team_slug") - - prs_out = _get_result_args(len(head_refs), "pr_out") - user_id_out = _get_result_args(len(user_ids), "user_out") - label_out = _get_result_args(len(labels), "label_out") - team_out = _get_result_args(len(teams), "team_out") - - arg_str = ", ".join( - _get_args_declaration(head_refs_args, "String!") - + _get_args_declaration(user_id_args, "String!") - + _get_args_declaration(label_args, "String!") - + _get_args_declaration(team_org_args, "String!") - + _get_args_declaration(team_slug_args, "String!") - ) - - # NOTE: There are possible limitations here because we depend on PRs being - # returned in order of OPEN prs, followed by MERGED prs in the order that - # they merged. github doesn't offer these options and it is excessively - # expensive to always fetch multiple prs and order them on this side. For now - # we hope that the most relevant PR will have the most recent update time. - request_str = "".join( - len(head_refs) - * [ - "{}: pullRequests (headRefName: ${}, states: [OPEN, MERGED], first: 1, " - "orderBy: {{direction: DESC, field:UPDATED_AT}}) {{" - "...PrResult" - "}}," - ] - ) - request_str = request_str.format(*_zip_and_flatten(prs_out, head_refs_args.keys())) - - user_str = "".join( - len(user_ids) * ["{}: assignableUsers (query: ${}, first: 25) {{...UserResult}},"] - ) - user_str = user_str.format(*_zip_and_flatten(user_id_out, user_id_args.keys())) - - label_str = "".join(len(labels) * ["{}: label (name: ${}) {{...LabelResult}},"]) - label_str = label_str.format(*_zip_and_flatten(label_out, label_args.keys())) - - team_str = "" - for i in range(len(teams)): - team_str += ( - f"{team_out[i]}: organization(login: ${list(team_org_args.keys())[i]}) " - f"{{team(slug: ${list(team_slug_args.keys())[i]}) " - f"{{id, members(first: 100) {{nodes {{login}}, totalCount}}}}}}," - ) - - multi_query_str = f""" - query GetPrResults($owner: String!, $name: String!, {arg_str}) {{ - repository(name: $name, owner: $owner) {{ - id - {request_str}{user_str}{label_str} - }} - {team_str} - }}""" - if user_str: - multi_query_str += """ - fragment UserResult on UserConnection { - nodes { - login - id - } - totalCount - }""" - if label_str: - multi_query_str += """ - fragment LabelResult on Label { - id - name - }""" - if request_str: - multi_query_str += f""" +PR_FRAGMENT = f""" fragment PrResult on PullRequestConnection {{ nodes {{ id @@ -257,175 +116,325 @@ async def query_everything( totalCount }}""" - pr_result = await self.endpoint.graphql( - multi_query_str, - owner=self.repo_info.owner, - name=self.repo_info.name, - **head_refs_args, - **user_id_args, - **label_args, - **team_org_args, - **team_slug_args, - ) +USER_FRAGMENT = """ + fragment UserResult on UserConnection { + nodes { + login + id + } + totalCount + }""" - prs: List[Optional[PrInfo]] = [] - for i, branch_name in enumerate(head_refs): - this_node = pr_result["data"]["repository"][prs_out[i]] - if len(this_node["nodes"]) == 1: - this_node = this_node["nodes"][0] - pr_labels: Set[str] = set() - pr_label_ids: Set[str] = set() - reviewers: Set[str] = set() - reviewer_ids: Set[str] = set() - reviewer_teams: Set[str] = set() - reviewer_team_ids: Set[str] = set() - assignees: Set[str] = set() - assignee_ids: Set[str] = set() - for label in this_node["labels"]["nodes"]: - pr_labels.add(label["name"]) - pr_label_ids.add(label["id"]) - for revs in this_node["reviewRequests"]["nodes"]: - requested = revs["requestedReviewer"] - if not requested: - continue - elif "slug" in requested: - reviewer_teams.add( - f"{requested['organization']['login']}/{requested['slug']}" - ) - reviewer_team_ids.add(requested["id"]) - elif "login" in requested: - reviewers.add(requested["login"]) - reviewer_ids.add(requested["id"]) - for revs in this_node["latestReviews"]["nodes"]: - # Ignore self reviews and bot reviews (without a login) - if not revs["viewerDidAuthor"] and "login" in revs["author"]: - reviewers.add(revs["author"]["login"]) - reviewer_ids.add(revs["author"]["id"]) - for user in this_node["assignees"]["nodes"]: - assignees.add(user["login"]) - assignee_ids.add(user["id"]) - - # The plain headRef and baseRef fields return the latest commit id associated with - # that branch name which may be newer than the PR itself if it was merged. We want - # the ids of the commits actually last associated with the PR, which we query from - # the commit list. This can also mean they are None if the PR has 0 commits. - headRefOid = ( - this_node["headCommit"]["nodes"][0]["commit"]["oid"] - if this_node["headCommit"]["nodes"] - else None +LABEL_FRAGMENT = """ + fragment LabelResult on Label { + id + name + }""" + + +def _make_pr_group(head_refs: List[str]) -> QueryGroup: + group = QueryGroup( + prefix="pr", + scope="repo", + field_template=( + "{}: pullRequests (headRefName: {}, states: [OPEN, MERGED], first: 1, " + "orderBy: {{direction: DESC, field:UPDATED_AT}}) {{...PrResult}}," + ), + var_types=["String!"], + fragment=PR_FRAGMENT, + ) + for ref in head_refs: + group.add(ref) + return group + + +def _make_user_group(user_ids: List[str]) -> QueryGroup: + group = QueryGroup( + prefix="user", + scope="repo", + field_template="{}: assignableUsers (query: {}, first: 25) {{...UserResult}},", + var_types=["String!"], + fragment=USER_FRAGMENT, + ) + for uid in user_ids: + group.add(uid) + return group + + +def _make_label_group(labels: List[str]) -> QueryGroup: + group = QueryGroup( + prefix="label", + scope="repo", + field_template="{}: label (name: {}) {{...LabelResult}},", + var_types=["String!"], + fragment=LABEL_FRAGMENT, + ) + for label in labels: + group.add(label) + return group + + +def _make_team_group(teams: List[Tuple[str, str]]) -> QueryGroup: + group = QueryGroup( + prefix="team", + scope="top", + field_template=( + "{}: organization(login: {}) " + "{{team(slug: {}) " + "{{id, members(first: 100) {{nodes {{login}}, totalCount}}}}}}," + ), + var_types=["String!", "String!"], + fragment="", + ) + for org, slug in teams: + group.add(org, slug) + return group + + +def _parse_prs(group: QueryGroup, result: Any, head_refs: List[str]) -> List[Optional[PrInfo]]: + raw = group.extract(result) + prs: List[Optional[PrInfo]] = [] + for i, branch_name in enumerate(head_refs): + this_node = raw[i] + if len(this_node["nodes"]) == 1: + this_node = this_node["nodes"][0] + pr_labels: Set[str] = set() + pr_label_ids: Set[str] = set() + reviewers: Set[str] = set() + reviewer_ids: Set[str] = set() + reviewer_teams: Set[str] = set() + reviewer_team_ids: Set[str] = set() + assignees: Set[str] = set() + assignee_ids: Set[str] = set() + for label in this_node["labels"]["nodes"]: + pr_labels.add(label["name"]) + pr_label_ids.add(label["id"]) + for revs in this_node["reviewRequests"]["nodes"]: + requested = revs["requestedReviewer"] + if not requested: + continue + elif "slug" in requested: + reviewer_teams.add(f"{requested['organization']['login']}/{requested['slug']}") + reviewer_team_ids.add(requested["id"]) + elif "login" in requested: + reviewers.add(requested["login"]) + reviewer_ids.add(requested["id"]) + for revs in this_node["latestReviews"]["nodes"]: + if not revs["viewerDidAuthor"] and "login" in revs["author"]: + reviewers.add(revs["author"]["login"]) + reviewer_ids.add(revs["author"]["id"]) + for user in this_node["assignees"]["nodes"]: + assignees.add(user["login"]) + assignee_ids.add(user["id"]) + + headRefOid = ( + this_node["headCommit"]["nodes"][0]["commit"]["oid"] + if this_node["headCommit"]["nodes"] + else None + ) + baseRefOid = ( + this_node["baseCommit"]["nodes"][0]["commit"]["parents"]["nodes"][0]["oid"] + if this_node["baseCommit"]["nodes"] + else None + ) + + comments = [] + for c in this_node["comments"]["nodes"]: + comments.append(PrComment(c["body"], c["id"])) + + removed_reviewers: Set[str] = set() + removed_reviewer_ids: Set[str] = set() + removed_assignees: Set[str] = set() + removed_assignee_ids: Set[str] = set() + for event in this_node["timelineItems"]["nodes"]: + rr = event.get("requestedReviewer") + if rr and "login" in rr and rr["login"] not in reviewers: + removed_reviewers.add(rr["login"]) + removed_reviewer_ids.add(rr["id"]) + assignee = event.get("assignee") + if assignee and "login" in assignee and assignee["login"] not in assignees: + removed_assignees.add(assignee["login"]) + removed_assignee_ids.add(assignee["id"]) + + prs.append( + PrInfo( + id=this_node["id"], + url=this_node["url"], + baseRef=this_node["baseRefName"], + headRef=branch_name, + baseRefOid=baseRefOid, + headRefOid=headRefOid, + body=this_node["body"], + title=this_node["title"], + reviewers=reviewers, + reviewer_ids=reviewer_ids, + reviewer_teams=reviewer_teams, + reviewer_team_ids=reviewer_team_ids, + assignees=assignees, + assignee_ids=assignee_ids, + labels=pr_labels, + label_ids=pr_label_ids, + removed_reviewers=removed_reviewers, + removed_reviewer_ids=removed_reviewer_ids, + removed_assignees=removed_assignees, + removed_assignee_ids=removed_assignee_ids, + is_draft=this_node["isDraft"], + state=this_node["state"], + comments=comments, ) - baseRefOid = ( - this_node["baseCommit"]["nodes"][0]["commit"]["parents"]["nodes"][0]["oid"] - if this_node["baseCommit"]["nodes"] - else None + ) + else: + prs.append(None) + return prs + + +def _parse_users( + group: QueryGroup, result: Any, user_ids: List[str] +) -> Tuple[Dict[str, str], Dict[str, str]]: + raw = group.extract(result) + names_to_ids: Dict[str, str] = {} + names_to_logins: Dict[str, str] = {} + for i, user_id in enumerate(user_ids): + this_node = raw[i] + if len(this_node["nodes"]) == 0: + logging.warning("No matching user found for {}".format(user_id)) + else: + if this_node["totalCount"] > len(this_node["nodes"]): + logging.warning( + "Too many matching users found for {}, try being more specific".format(user_id) ) - - comments = [] - for c in this_node["comments"]["nodes"]: - comments.append(PrComment(c["body"], c["id"])) - - removed_reviewers: Set[str] = set() - removed_reviewer_ids: Set[str] = set() - removed_assignees: Set[str] = set() - removed_assignee_ids: Set[str] = set() - for event in this_node["timelineItems"]["nodes"]: - rr = event.get("requestedReviewer") - if rr and "login" in rr and rr["login"] not in reviewers: - removed_reviewers.add(rr["login"]) - removed_reviewer_ids.add(rr["id"]) - assignee = event.get("assignee") - if assignee and "login" in assignee and assignee["login"] not in assignees: - removed_assignees.add(assignee["login"]) - removed_assignee_ids.add(assignee["id"]) - - prs.append( - PrInfo( - id=this_node["id"], - url=this_node["url"], - baseRef=this_node["baseRefName"], - headRef=branch_name, - baseRefOid=baseRefOid, - headRefOid=headRefOid, - body=this_node["body"], - title=this_node["title"], - reviewers=reviewers, - reviewer_ids=reviewer_ids, - reviewer_teams=reviewer_teams, - reviewer_team_ids=reviewer_team_ids, - assignees=assignees, - assignee_ids=assignee_ids, - labels=pr_labels, - label_ids=pr_label_ids, - removed_reviewers=removed_reviewers, - removed_reviewer_ids=removed_reviewer_ids, - removed_assignees=removed_assignees, - removed_assignee_ids=removed_assignee_ids, - is_draft=this_node["isDraft"], - state=this_node["state"], - comments=comments, + shortest_name = this_node["nodes"][0]["login"] + names_to_ids[user_id] = this_node["nodes"][0]["id"] + found_match = False + for user in this_node["nodes"]: + if len(user["login"]) <= len(shortest_name) and user["login"].startswith(user_id): + shortest_name = user["login"] + names_to_ids[user_id] = user["id"] + names_to_logins[user_id] = user["login"] + found_match = True + if not found_match: + logging.warning( + "Couldn't find a prefixed match for {}, going with {} instead".format( + user_id, shortest_name ) ) + return names_to_ids, names_to_logins + + +def _parse_labels(group: QueryGroup, result: Any, labels: List[str]) -> Dict[str, str]: + raw = group.extract(result) + labels_to_ids: Dict[str, str] = {} + for i, label in enumerate(labels): + this_node = raw[i] + if this_node is not None: + labels_to_ids[label] = this_node["id"] + else: + logging.warning("Couldn't find an existing label named {}".format(label)) + return labels_to_ids + + +def _parse_teams( + group: QueryGroup, result: Any, teams: List[Tuple[str, str]] +) -> Tuple[Dict[str, str], Dict[str, Optional[Set[str]]]]: + raw = group.extract(result) + teams_to_ids: Dict[str, str] = {} + teams_to_members: Dict[str, Optional[Set[str]]] = {} + for i, (org, slug) in enumerate(teams): + team_node = raw[i] + if team_node is not None and team_node["team"] is not None: + team_ref = f"{org}/{slug}" + teams_to_ids[team_ref] = team_node["team"]["id"] + members_node = team_node["team"]["members"] + member_logins = {m["login"] for m in members_node["nodes"]} + if members_node["totalCount"] > len(members_node["nodes"]): + teams_to_members[team_ref] = None else: - prs.append(None) - - names_to_ids: Dict[str, str] = {} - names_to_logins: Dict[str, str] = {} - for i, user_id in enumerate(user_ids): - this_node = pr_result["data"]["repository"][user_id_out[i]] - if len(this_node["nodes"]) == 0: - logging.warning("No matching user found for {}".format(user_id)) - else: - if this_node["totalCount"] > len(this_node["nodes"]): - logging.warning( - "Too many matching users found for {}, try being more specific".format( - user_id - ) - ) - shortest_name = this_node["nodes"][0]["login"] - names_to_ids[user_id] = this_node["nodes"][0]["id"] - found_match = False - for user in this_node["nodes"]: - if len(user["login"]) <= len(shortest_name) and user["login"].startswith( - user_id - ): - shortest_name = user["login"] - names_to_ids[user_id] = user["id"] - names_to_logins[user_id] = user["login"] - found_match = True - if not found_match: - logging.warning( - "Couldn't find a prefixed match for {}, going with {} instead".format( - user_id, shortest_name - ) - ) + teams_to_members[team_ref] = member_logins + else: + logging.warning("Couldn't find a team matching {}/{}".format(org, slug)) + return teams_to_ids, teams_to_members - labels_to_ids: Dict[str, str] = {} - for i, label in enumerate(labels): - this_node = pr_result["data"]["repository"][label_out[i]] - if this_node is not None: - labels_to_ids[label] = this_node["id"] - else: - logging.warning("Couldn't find an existing label named {}".format(label)) - - teams_to_ids: Dict[str, str] = {} - teams_to_members: Dict[str, Optional[Set[str]]] = {} - for i, (org, slug) in enumerate(teams): - team_node = pr_result["data"][team_out[i]] - if team_node is not None and team_node["team"] is not None: - team_ref = f"{org}/{slug}" - teams_to_ids[team_ref] = team_node["team"]["id"] - members_node = team_node["team"]["members"] - member_logins = {m["login"] for m in members_node["nodes"]} - if members_node["totalCount"] > len(members_node["nodes"]): - # Team has more members than we fetched; we can't check membership reliably. - teams_to_members[team_ref] = None - else: - teams_to_members[team_ref] = member_logins - else: - logging.warning("Couldn't find a team matching {}/{}".format(org, slug)) + +class Github(Forge): + def __init__( + self, + endpoint: GitHubEndpoint, + repo_info: ForgeRepoInfo, + fork_info: ForgeRepoInfo, + ): + self.endpoint = endpoint + self.repo_info = repo_info + self.fork_info = fork_info + + @property + def repo_owner(self) -> str: + return self.fork_info.owner + + @property + def repo_name(self) -> str: + return self.repo_info.name + + @property + def is_fork(self) -> bool: + return self.fork_info.owner != self.repo_info.owner + + async def close(self) -> None: + await self.endpoint.close() + + def _make_query_everything( + self, + head_refs: List[str], + user_ids: List[str], + labels: List[str], + teams: List[Tuple[str, str]], + ) -> Tuple[GraphqlQuery, QueryGroup, QueryGroup, QueryGroup, QueryGroup]: + q = GraphqlQuery(name="GetEverything") + q.add_fixed_var("owner", "String!", self.repo_info.owner) + q.add_fixed_var("name", "String!", self.repo_info.name) + q.fixed_repo_fields = "id\n" + + pr_group = _make_pr_group(head_refs) + user_group = _make_user_group(user_ids) + label_group = _make_label_group(labels) + team_group = _make_team_group(teams) + + q.add_group(pr_group) + q.add_group(user_group) + q.add_group(label_group) + q.add_group(team_group) + + return q, pr_group, user_group, label_group, team_group + + async def query_everything( + self, + head_refs: List[str], + user_ids: List[str], + labels: List[str], + teams: List[Tuple[str, str]], + ) -> Tuple[ + str, + List[Optional[PrInfo]], + Dict[str, str], + Dict[str, str], + Dict[str, str], + Dict[str, str], + Dict[str, Optional[Set[str]]], + ]: + q, pr_group, user_group, label_group, team_group = self._make_query_everything( + head_refs, user_ids, labels, teams + ) + + query_str, variables = q.build() + result = await self.endpoint.graphql(query_str, **variables) + + repo_id = result["data"]["repository"]["id"] + prs = _parse_prs(pr_group, result, head_refs) + names_to_ids, names_to_logins = _parse_users(user_group, result, user_ids) + labels_to_ids = _parse_labels(label_group, result, labels) + teams_to_ids, teams_to_members = _parse_teams(team_group, result, teams) return ( - pr_result["data"]["repository"]["id"], + repo_id, prs, names_to_ids, names_to_logins, @@ -451,39 +460,49 @@ async def create_pull_requests(self, repo_id: str, prs: List[PrInfo]) -> None: "title": pr.title, "draft": pr.is_draft, }) - inputs_args = _get_args_dict(inputs, "pr") - prs_out = _get_result_args(len(inputs), "pr_out") - - arg_str = ", ".join(_get_args_declaration(inputs_args, "CreatePullRequestInput!")) - request_str = "".join( - len(inputs) - * [ - """ - {}: createPullRequest(input: ${}) {{ + group = QueryGroup( + prefix="pr", + scope="mutation", + field_template=""" + {}: createPullRequest(input: {}) {{ pullRequest {{ id url }} - }},""" - ] + }},""", + var_types=["CreatePullRequestInput!"], ) - request_str = request_str.format(*_zip_and_flatten(prs_out, inputs_args.keys())) + for inp in inputs: + group.add(inp) - mutation_str = f""" - mutation ({arg_str}) {{ - {request_str} - }}""" + q = GraphqlQuery(operation="mutation") + q.add_group(group) + query_str, variables = q.build() - # Creating a pull request can fail if the branch is already merged. - pr_results = await self.endpoint.graphql(mutation_str, require_success=False, **inputs_args) + pr_results = await self.endpoint.graphql(query_str, **variables) + raw = group.extract(pr_results) for i, pr in enumerate(prs): - result = pr_results["data"][prs_out[i]]["pullRequest"] - if result is not None: - pr.id = result["id"] - pr.url = result["url"] + result_node = raw[i]["pullRequest"] + if result_node is not None: + pr.id = result_node["id"] + pr.url = result_node["url"] async def update_pull_requests(self, prs: List[PrUpdate]) -> None: + q = self._build_update_mutation(prs) + query_str, variables = q.build() + try: + await self.endpoint.graphql(query_str, **variables) + except RevupForgeException as e: + if "timeout" in e.message: + logging.warning( + "Github update request timed out! Most likely this is a false alarm and changes" + " actually succeeded. You may want to rerun this command to verify." + ) + else: + raise + + def _build_update_mutation(self, prs: List[PrUpdate]) -> GraphqlQuery: inputs = [] labels = [] reviewers = [] @@ -511,7 +530,6 @@ async def update_pull_requests(self, prs: List[PrUpdate]) -> None: "clientMutationId": "revup", "labelableId": pr.id, }) - if pr.reviewer_ids or pr.reviewer_team_ids: reviewers.append({ "userIds": list(pr.reviewer_ids), @@ -526,7 +544,6 @@ async def update_pull_requests(self, prs: List[PrUpdate]) -> None: "clientMutationId": "revup", "assignableId": pr.id, }) - if pr.is_draft is not None: if pr.is_draft: convert_to_draft.append({ @@ -538,7 +555,6 @@ async def update_pull_requests(self, prs: List[PrUpdate]) -> None: "clientMutationId": "revup", "pullRequestId": pr.id, }) - for c in pr.comments: if c.id: edit_comments.append({ @@ -553,165 +569,112 @@ async def update_pull_requests(self, prs: List[PrUpdate]) -> None: "subjectId": pr.id, }) - inputs_args = _get_args_dict(inputs, "pr") - prs_out = _get_result_args(len(inputs), "pr_out") - - labels_args = _get_args_dict(labels, "label") - labels_out = _get_result_args(len(labels), "label_out") - - reviewers_args = _get_args_dict(reviewers, "rev") - reviewers_out = _get_result_args(len(reviewers), "rev_out") - - assignees_args = _get_args_dict(assignees, "asn") - assignees_out = _get_result_args(len(assignees), "asn_out") - - to_draft_args = _get_args_dict(convert_to_draft, "to_d") - to_draft_out = _get_result_args(len(convert_to_draft), "to_d_out") - - from_draft_args = _get_args_dict(convert_from_draft, "from_d") - from_draft_out = _get_result_args(len(convert_from_draft), "from_d_out") - - comments_args = _get_args_dict(comments, "com") - comments_out = _get_result_args(len(comments), "com_out") - - edit_comments_args = _get_args_dict(edit_comments, "edit_com") - edit_comments_out = _get_result_args(len(edit_comments), "edit_com_out") - - arg_str = ", ".join( - _get_args_declaration(inputs_args, "UpdatePullRequestInput!") - + _get_args_declaration(labels_args, "AddLabelsToLabelableInput!") - + _get_args_declaration(reviewers_args, "RequestReviewsInput!") - + _get_args_declaration(assignees_args, "AddAssigneesToAssignableInput!") - + _get_args_declaration(to_draft_args, "ConvertPullRequestToDraftInput!") - + _get_args_declaration(from_draft_args, "MarkPullRequestReadyForReviewInput!") - + _get_args_declaration(comments_args, "AddCommentInput!") - + _get_args_declaration(edit_comments_args, "UpdateIssueCommentInput!") - ) - - update_str = "".join( - len(inputs) - * [ - """ - {}: updatePullRequest(input: ${}) {{ + update_group = QueryGroup( + prefix="pr", + scope="mutation", + field_template=""" + {}: updatePullRequest(input: {}) {{ clientMutationId - }},""" - ] + }},""", + var_types=["UpdatePullRequestInput!"], ) - update_str = update_str.format(*_zip_and_flatten(prs_out, inputs_args.keys())) - - request_reviewers_str = "".join( - len(reviewers_args) - * [ - """ - {}: requestReviews(input: ${}) {{ + for inp in inputs: + update_group.add(inp) + + label_group = QueryGroup( + prefix="label", + scope="mutation", + field_template=""" + {}: addLabelsToLabelable(input: {}) {{ clientMutationId - }},""" - ] - ) - request_reviewers_str = request_reviewers_str.format( - *_zip_and_flatten(reviewers_out, reviewers_args.keys()) + }},""", + var_types=["AddLabelsToLabelableInput!"], ) - assignees_str = "".join( - len(assignees_args) - * [ - """ - {}: addAssigneesToAssignable(input: ${}) {{ + for inp in labels: + label_group.add(inp) + + reviewer_group = QueryGroup( + prefix="rev", + scope="mutation", + field_template=""" + {}: requestReviews(input: {}) {{ clientMutationId - }},""" - ] + }},""", + var_types=["RequestReviewsInput!"], ) - assignees_str = assignees_str.format( - *_zip_and_flatten(assignees_out, assignees_args.keys()) - ) - - add_labels_str = "".join( - len(labels_args) - * [ - """ - {}: addLabelsToLabelable(input: ${}) {{ + for inp in reviewers: + reviewer_group.add(inp) + + assignee_group = QueryGroup( + prefix="asn", + scope="mutation", + field_template=""" + {}: addAssigneesToAssignable(input: {}) {{ clientMutationId - }},""" - ] + }},""", + var_types=["AddAssigneesToAssignableInput!"], ) - add_labels_str = add_labels_str.format(*_zip_and_flatten(labels_out, labels_args.keys())) - - to_draft_str = "".join( - len(convert_to_draft) - * [ - """ - {}: convertPullRequestToDraft(input: ${}) {{ + for inp in assignees: + assignee_group.add(inp) + + to_draft_group = QueryGroup( + prefix="to_d", + scope="mutation", + field_template=""" + {}: convertPullRequestToDraft(input: {}) {{ clientMutationId - }},""" - ] + }},""", + var_types=["ConvertPullRequestToDraftInput!"], ) - to_draft_str = to_draft_str.format(*_zip_and_flatten(to_draft_out, to_draft_args.keys())) - - from_draft_str = "".join( - len(convert_from_draft) - * [ - """ - {}: markPullRequestReadyForReview(input: ${}) {{ + for inp in convert_to_draft: + to_draft_group.add(inp) + + from_draft_group = QueryGroup( + prefix="from_d", + scope="mutation", + field_template=""" + {}: markPullRequestReadyForReview(input: {}) {{ clientMutationId - }},""" - ] - ) - from_draft_str = from_draft_str.format( - *_zip_and_flatten(from_draft_out, from_draft_args.keys()) + }},""", + var_types=["MarkPullRequestReadyForReviewInput!"], ) - - add_comments_str = "".join( - len(comments_args) - * [ - """ - {}: addComment(input: ${}) {{ + for inp in convert_from_draft: + from_draft_group.add(inp) + + comment_group = QueryGroup( + prefix="com", + scope="mutation", + field_template=""" + {}: addComment(input: {}) {{ clientMutationId - }},""" - ] + }},""", + var_types=["AddCommentInput!"], ) - add_comments_str = add_comments_str.format( - *_zip_and_flatten(comments_out, comments_args.keys()) - ) - - edit_comments_str = "".join( - len(edit_comments_args) - * [ - """ - {}: updateIssueComment(input: ${}) {{ + for inp in comments: + comment_group.add(inp) + + edit_comment_group = QueryGroup( + prefix="edit_com", + scope="mutation", + field_template=""" + {}: updateIssueComment(input: {}) {{ clientMutationId - }},""" - ] - ) - edit_comments_str = edit_comments_str.format( - *_zip_and_flatten(edit_comments_out, edit_comments_args.keys()) + }},""", + var_types=["UpdateIssueCommentInput!"], ) - - # Add comment mutations first to ensure comments are at the top of the PR - mutation_str = f""" - mutation ({arg_str}) {{ - {add_comments_str}{update_str}{request_reviewers_str}{assignees_str}{add_labels_str}\ -{to_draft_str}{from_draft_str}{edit_comments_str} - }}""" - - try: - await self.endpoint.graphql( - mutation_str, - **comments_args, - **inputs_args, - **reviewers_args, - **assignees_args, - **labels_args, - **to_draft_args, - **from_draft_args, - **edit_comments_args, - ) - except RevupForgeException as e: - if "timeout" in e.message: - logging.warning( - "Github update request timed out! Most likely this is a false alarm and changes" - " actually succeeded. You may want to rerun this command to verify." - ) - else: - raise e + for inp in edit_comments: + edit_comment_group.add(inp) + + q = GraphqlQuery(operation="mutation") + q.add_group(comment_group) + q.add_group(update_group) + q.add_group(reviewer_group) + q.add_group(assignee_group) + q.add_group(label_group) + q.add_group(to_draft_group) + q.add_group(from_draft_group) + q.add_group(edit_comment_group) + return q async def query_pr_by_number(self, owner: str, name: str, number: int) -> Tuple[str, str]: result = await self.endpoint.graphql( diff --git a/revup/github/graphql.py b/revup/github/graphql.py new file mode 100644 index 0000000..0271010 --- /dev/null +++ b/revup/github/graphql.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple + + +@dataclass +class QueryGroup: + """A batch of homogeneous aliased fields in a GraphQL operation. + + All items share the same field_template and variable types. + Supports slicing to split items across multiple queries while + preserving original alias indices (so merged results don't collide). + """ + + prefix: str + scope: str # "repo" | "top" | "mutation" + field_template: str + var_types: List[str] + fragment: str = "" + values: List[List[Any]] = field(default_factory=list) + _offset: int = 0 + + def add(self, *values: Any) -> str: + assert len(values) == len(self.var_types) + idx = len(self.values) + self.values.append(list(values)) + return self.alias(idx) + + def alias(self, idx: int) -> str: + return f"{self.prefix}_out{self._offset + idx}" + + @property + def aliases(self) -> List[str]: + return [self.alias(i) for i in range(len(self.values))] + + def var_name(self, item_idx: int, var_idx: int) -> str: + actual = self._offset + item_idx + if len(self.var_types) == 1: + return f"{self.prefix}{actual}" + return f"{self.prefix}{actual}_{var_idx}" + + def __len__(self) -> int: + return len(self.values) + + def _render_fields(self) -> str: + parts: List[str] = [] + for i in range(len(self.values)): + var_names = [f"${self.var_name(i, j)}" for j in range(len(self.var_types))] + parts.append(self.field_template.format(self.alias(i), *var_names)) + return "".join(parts) + + def _render_declarations(self) -> List[str]: + decls: List[str] = [] + for i in range(len(self.values)): + for j, vtype in enumerate(self.var_types): + decls.append(f"${self.var_name(i, j)}: {vtype}") + return decls + + def _render_variables(self) -> Dict[str, Any]: + variables: Dict[str, Any] = {} + for i, vals in enumerate(self.values): + for j, val in enumerate(vals): + variables[self.var_name(i, j)] = val + return variables + + def extract(self, result: Any) -> List[Any]: + if self.scope == "repo": + repo = result["data"]["repository"] + return [repo[self.alias(i)] for i in range(len(self.values))] + else: + data = result["data"] + return [data[self.alias(i)] for i in range(len(self.values))] + + def slice(self, start: int, end: int) -> QueryGroup: + g = QueryGroup( + prefix=self.prefix, + scope=self.scope, + field_template=self.field_template, + var_types=list(self.var_types), + fragment=self.fragment, + _offset=self._offset + start, + ) + g.values = self.values[start:end] + return g + + +class GraphqlQuery: + """Builds a GraphQL query/mutation from composable, sliceable groups.""" + + def __init__(self, operation: str = "query", name: str = ""): + self.operation = operation + self.name = name + self.fixed_vars: List[Tuple[str, str, Any]] = [] + self.fixed_repo_fields: str = "" + self.groups: List[QueryGroup] = [] + + def add_fixed_var(self, name: str, gql_type: str, value: Any) -> None: + self.fixed_vars.append((name, gql_type, value)) + + def add_group(self, group: QueryGroup) -> None: + self.groups.append(group) + + def total_items(self) -> int: + return sum(len(g) for g in self.groups) + + def build(self) -> Tuple[str, Dict[str, Any]]: + all_decls: List[str] = [] + variables: Dict[str, Any] = {} + + for name, gql_type, value in self.fixed_vars: + all_decls.append(f"${name}: {gql_type}") + variables[name] = value + + for group in self.groups: + all_decls.extend(group._render_declarations()) + variables.update(group._render_variables()) + + decl_str = ", ".join(all_decls) + name_str = f" {self.name}" if self.name else "" + + repo_fields = self.fixed_repo_fields + top_fields = "" + mutation_fields = "" + for group in self.groups: + rendered = group._render_fields() + if group.scope == "repo": + repo_fields += rendered + elif group.scope == "top": + top_fields += rendered + else: + mutation_fields += rendered + + if self.operation == "query": + body = "" + if repo_fields: + body += f""" + repository(name: $name, owner: $owner) {{ + {repo_fields} + }}""" + body += top_fields + query_str = f""" + {self.operation}{name_str} ({decl_str}) {{{body} + }}""" + else: + query_str = f""" + {self.operation}{name_str} ({decl_str}) {{ + {mutation_fields} + }}""" + + fragments = "" + seen: set = set() + for group in self.groups: + if group.fragment and group.fragment not in seen and len(group) > 0: + fragments += group.fragment + seen.add(group.fragment) + query_str += fragments + + return query_str, variables + + def split(self) -> Tuple[GraphqlQuery, GraphqlQuery]: + """Split into two queries by halving each group's items. + + Alias indices are preserved so merged results don't collide. + """ + left = GraphqlQuery(operation=self.operation, name=self.name) + right = GraphqlQuery(operation=self.operation, name=self.name) + + left.fixed_vars = list(self.fixed_vars) + right.fixed_vars = list(self.fixed_vars) + left.fixed_repo_fields = self.fixed_repo_fields + right.fixed_repo_fields = self.fixed_repo_fields + + for group in self.groups: + mid = len(group) // 2 + if mid == 0: + left.add_group(group.slice(0, len(group))) + right.add_group(group.slice(0, 0)) + else: + left.add_group(group.slice(0, mid)) + right.add_group(group.slice(mid, len(group))) + + return left, right