Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ do that. (There's also a more fundamental reason why this
won't work: since each commit is a separate PR, you have to
resolve conflicts in *each* PR, not just for the entire stack.)

**What if the repository default branch changed?** ghstack caches
repository metadata in `.git/ghstack-repo-info.json` for the local
checkout. If ghstack is still using an old default branch name,
delete that file and rerun ghstack; it will query GitHub again.

**How do I start a new feature?** Just checkout main on a new
branch, and start working on a fresh branch.

Expand Down
33 changes: 31 additions & 2 deletions src/ghstack/github_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#!/usr/bin/env python3

import json
import logging
import os
import re
from typing import Optional

Expand Down Expand Up @@ -60,6 +63,11 @@ def get_github_repo_name_with_owner(
)


def _repo_info_cache_path(sh: ghstack.shell.Shell) -> str:
git_dir = sh.abspath(sh.git("rev-parse", "--git-dir"))
return os.path.join(git_dir, "ghstack-repo-info.json")


def get_github_repo_info(
*,
github: ghstack.github.GitHubEndpoint,
Expand All @@ -78,7 +86,20 @@ def get_github_repo_info(
else:
name_with_owner = {"owner": repo_owner, "name": repo_name}

# TODO: Cache this guy
cache_path = _repo_info_cache_path(sh)
try:
with open(cache_path) as f:
cached = json.load(f)
if (
cached.get("name_with_owner") == name_with_owner
and cached.get("id")
and cached.get("default_branch")
):
logging.debug("Using cached repo info from %s", cache_path)
return cached
except (OSError, json.JSONDecodeError, KeyError):
pass

try:
repo = github.graphql(
"""
Expand Down Expand Up @@ -117,13 +138,21 @@ def get_github_repo_info(
# Re-raise the original error if it's not the repository access issue
raise

return {
result: GitHubRepoInfo = {
"name_with_owner": name_with_owner,
"id": repo["id"],
"is_fork": repo["isFork"],
"default_branch": repo["defaultBranchRef"]["name"],
}

try:
with open(cache_path, "w") as f:
json.dump(result, f)
except OSError:
pass

return result


RE_PR_URL = re.compile(
r"^https://(?P<github_url>[^/]+)/(?P<owner>[^/]+)/(?P<name>[^/]+)/pull/(?P<number>[0-9]+)/?$"
Expand Down
1 change: 1 addition & 0 deletions src/ghstack/test_prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self, direct: bool) -> None:
local_dir = tempfile.mkdtemp()
self.sh = ghstack.shell.Shell(cwd=local_dir, testing=True)
self.sh.git("clone", upstream_dir, ".")
self.sh.git("fetch", "origin", "+refs/heads/*:refs/remotes/origin/*")
self.direct = direct

def cleanup(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions test/land/default_branch_change.py.test
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from ghstack.test_prelude import *

init_test()
Expand All @@ -15,6 +16,13 @@ get_github().patch(
name="pytorch",
default_branch="release",
)
# invalidate repo info cache since default branch changed
cache_path = os.path.join(
get_sh().abspath(get_sh().git("rev-parse", "--git-dir")),
"ghstack-repo-info.json",
)
if os.path.exists(cache_path):
os.remove(cache_path)

assert_github_state(
"""\
Expand Down Expand Up @@ -63,6 +71,8 @@ get_github().patch(
name="pytorch",
default_branch="main",
)
if os.path.exists(cache_path):
os.remove(cache_path)

assert_github_state(
"""\
Expand Down
Loading