Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions src/ghstack/github.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import asyncio
from abc import ABCMeta, abstractmethod
from typing import Any, Sequence

Expand Down Expand Up @@ -64,29 +65,59 @@ def get(self, path: str, **kwargs: Any) -> Any:

Returns: parsed JSON response
"""
return self.rest("get", path, **kwargs)
return self._run_async(self.aget(path, **kwargs))

def post(self, path: str, **kwargs: Any) -> Any:
"""
Send a POST request to endpoint 'path'.

Returns: parsed JSON response
"""
return self.rest("post", path, **kwargs)
return self._run_async(self.apost(path, **kwargs))

def patch(self, path: str, **kwargs: Any) -> Any:
"""
Send a PATCH request to endpoint 'path'.

Returns: parsed JSON response
"""
return self.rest("patch", path, **kwargs)
return self._run_async(self.apatch(path, **kwargs))

@abstractmethod
def rest(self, method: str, path: str, **kwargs: Any) -> Any:
"""
Send a 'method' request to endpoint 'path'.

Args:
method: 'GET', 'POST', etc.
path: relative URL path to access on endpoint
**kwargs: dictionary of JSON payload to send

Returns: parsed JSON response
"""
return self._run_async(self.arest(method, path, **kwargs))

@staticmethod
def _run_async(coro: Any) -> Any:
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()

async def aget(self, path: str, **kwargs: Any) -> Any:
return await self.arest("get", path, **kwargs)

async def apost(self, path: str, **kwargs: Any) -> Any:
return await self.arest("post", path, **kwargs)

async def apatch(self, path: str, **kwargs: Any) -> Any:
return await self.arest("patch", path, **kwargs)

@abstractmethod
async def arest(self, method: str, path: str, **kwargs: Any) -> Any:
"""
Send an async 'method' request to endpoint 'path'.

Args:
method: 'GET', 'POST', etc.
path: relative URL path to access on endpoint
Expand Down
49 changes: 33 additions & 16 deletions src/ghstack/github_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,19 @@ def pullRequests(self, info: GraphQLResolveInfo) -> "PullRequestConnection":
# TODO: This should take which repository the ref is in
# This only works if you have upstream_sh
def _make_ref(self, state: GitHubState, refName: str) -> "Ref":
return ghstack.github.GitHubEndpoint._run_async(
self._make_ref_async(state, refName)
)

async def _make_ref_async(self, state: GitHubState, refName: str) -> "Ref":
# TODO: Probably should preserve object identity here when
# you call this with refName/oid that are the same
assert state.upstream_sh
gitObject = GitObject(
id=state.next_id(),
# TODO: this upstream_sh hardcode wrong, but ok for now
# because we only have one repo
oid=GitObjectID(state.upstream_sh.git("rev-parse", refName)),
oid=GitObjectID(await state.upstream_sh.agit("rev-parse", refName)),
_repository=self.id,
)
ref = Ref(
Expand Down Expand Up @@ -366,7 +371,7 @@ def push_hook(self, refNames: Sequence[str]) -> None:
def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:
self.state.notify_merged(pr_resolved)

def _create_pull(
async def _create_pull_async(
self, owner: str, name: str, input: CreatePullRequestInput
) -> CreatePullRequestPayload:
state = self.state
Expand All @@ -378,8 +383,8 @@ def _create_pull(
# TODO: When we support forks, this needs rewriting to stop
# hard coded the repo we opened the pull request on
if state.upstream_sh:
baseRef = repo._make_ref(state, input["base"])
headRef = repo._make_ref(state, input["head"])
baseRef = await repo._make_ref_async(state, input["base"])
headRef = await repo._make_ref_async(state, input["head"])
pr = PullRequest(
id=id,
_repository=repo.id,
Expand All @@ -403,7 +408,7 @@ def _create_pull(

# NB: This technically does have a payload, but we don't
# use it so I didn't bother constructing it.
def _update_pull(
async def _update_pull_async(
self, owner: str, name: str, number: GitHubNumber, input: UpdatePullRequestInput
) -> None:
state = self.state
Expand All @@ -415,11 +420,11 @@ def _update_pull(
pr.title = input["title"]
if "base" in input and input["base"] is not None:
pr.baseRefName = input["base"]
pr.baseRef = repo._make_ref(state, pr.baseRefName)
pr.baseRef = await repo._make_ref_async(state, pr.baseRefName)
if "body" in input and input["body"] is not None:
pr.body = input["body"]

def _create_issue_comment(
async def _create_issue_comment_async(
self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput
) -> CreateIssueCommentPayload:
state = self.state
Expand All @@ -439,7 +444,7 @@ def _create_issue_comment(
"id": comment_id,
}

def _update_issue_comment(
async def _update_issue_comment_async(
self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput
) -> None:
state = self.state
Expand All @@ -450,14 +455,19 @@ def _update_issue_comment(

# NB: This may have a payload, but we don't
# use it so I didn't bother constructing it.
def _set_default_branch(
async def _set_default_branch_async(
self, owner: str, name: str, input: SetDefaultBranchInput
) -> None:
state = self.state
repo = state.repository(owner, name)
repo.defaultBranchRef = repo._make_ref(state, input["default_branch"])
repo.defaultBranchRef = await repo._make_ref_async(
state, input["default_branch"]
)

def rest(self, method: str, path: str, **kwargs: Any) -> Any:
async def arest(self, method: str, path: str, **kwargs: Any) -> Any:
return await self._arest_impl(method, path, **kwargs)

async def _arest_impl(self, method: str, path: str, **kwargs: Any) -> Any:
if method == "get":
m = re.match(r"^repos/([^/]+)/([^/]+)/branches/([^/]+)/protection", path)
if m:
Expand All @@ -472,6 +482,12 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
"state": "closed" if pr.closed else "open",
"title": pr.title,
"body": pr.body,
"head": {
"ref": pr.headRefName,
},
"base": {
"ref": pr.baseRefName,
},
}
if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$", path):
state = self.state
Expand All @@ -484,11 +500,11 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:

elif method == "post":
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path):
return self._create_pull(
return await self._create_pull_async(
m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)
)
if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments", path):
return self._create_issue_comment(
return await self._create_issue_comment_async(
m.group(1),
m.group(2),
GitHubNumber(int(m.group(3))),
Expand Down Expand Up @@ -516,23 +532,24 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path):
owner, name, number = m.groups()
if number is not None:
return self._update_pull(
return await self._update_pull_async(
owner,
name,
GitHubNumber(int(number)),
cast(UpdatePullRequestInput, kwargs),
)
elif "default_branch" in kwargs:
return self._set_default_branch(
return await self._set_default_branch_async(
owner, name, cast(SetDefaultBranchInput, kwargs)
)
if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$", path):
return self._update_issue_comment(
return await self._update_issue_comment_async(
m.group(1),
m.group(2),
int(m.group(3)),
cast(UpdateIssueCommentInput, kwargs),
)

raise NotImplementedError(
"FakeGitHubEndpoint REST {} {} not implemented".format(method.upper(), path)
)
Loading
Loading