Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions src/ghstack/checkout.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
#!/usr/bin/env python3

import asyncio
import logging
import re
from typing import Iterable

import ghstack.github
import ghstack.github_utils
import ghstack.shell


async def _fetch_refs(
sh: ghstack.shell.Shell, *, remote_name: str, refs: Iterable[str]
) -> None:
refspecs = [
f"+refs/heads/{ref}:refs/remotes/{remote_name}/{ref}"
for ref in sorted(set(refs))
]
await sh.agit("fetch", "--prune", remote_name, *refspecs)


async def main(
pull_request: str,
github: ghstack.github.GitHubEndpoint,
Expand All @@ -19,7 +31,24 @@ async def main(
params = await ghstack.github_utils.parse_pull_request(
pull_request, sh=sh, remote_name=remote_name
)
head_ref = await github.get_head_ref(**params)
head_ref_task = asyncio.ensure_future(github.get_head_ref(**params))

if same_base:
repo_info_task = asyncio.ensure_future(
ghstack.github_utils.get_github_repo_info(
github=github,
sh=sh,
repo_owner=params["owner"],
repo_name=params["name"],
github_url=params["github_url"],
remote_name=remote_name,
)
)
head_ref, repo_info = await asyncio.gather(head_ref_task, repo_info_task)
else:
head_ref = await head_ref_task
repo_info = None

orig_ref = re.sub(r"/head$", "/orig", head_ref)
if orig_ref == head_ref:
logging.warning(
Expand All @@ -30,15 +59,7 @@ async def main(

# If --same-base is specified, check if checkout would change the merge-base
if same_base:
# Get the default branch name from the repo
repo_info = await ghstack.github_utils.get_github_repo_info(
github=github,
sh=sh,
repo_owner=params["owner"],
repo_name=params["name"],
github_url=params["github_url"],
remote_name=remote_name,
)
assert repo_info is not None
default_branch = repo_info["default_branch"]
default_branch_ref = f"{remote_name}/{default_branch}"

Expand All @@ -48,7 +69,11 @@ async def main(
current_base = None
default_branch_ref = None

await sh.agit("fetch", "--prune", remote_name)
refs_to_fetch = [orig_ref]
if same_base:
assert repo_info is not None
refs_to_fetch.append(repo_info["default_branch"])
await _fetch_refs(sh, remote_name=remote_name, refs=refs_to_fetch)

# If --same-base is specified, check what the new merge-base would be
if same_base:
Expand Down
Loading