Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion src/ghstack/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import logging
import os
import re
import shutil
import subprocess
from pathlib import Path
from typing import NamedTuple, Optional
from typing import NamedTuple, Optional, Tuple

import requests

Expand All @@ -15,6 +17,71 @@
DEFAULT_GHSTACKRC_PATH = Path.home() / ".ghstackrc"
GHSTACKRC_PATH_VAR = "GHSTACKRC_PATH"


def is_gh_cli_available() -> bool:
"""Check if the GitHub CLI (gh) is available in PATH."""
return shutil.which("gh") is not None


def get_gh_cli_credentials(
github_url: str = "github.com",
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Extract credentials from the GitHub CLI if available and authenticated.

Args:
github_url: The GitHub host to get credentials for.

Returns:
A tuple of (token, username, url) or (None, None, None) if unavailable.
"""
if not is_gh_cli_available():
return None, None, None

try:
# Check if gh is authenticated for this host
auth_status = subprocess.run(
["gh", "auth", "status", "-h", github_url],
capture_output=True,
text=True,
)
if auth_status.returncode != 0:
logging.debug(f"gh CLI not authenticated for {github_url}")
return None, None, None

# Get the token
token_result = subprocess.run(
["gh", "auth", "token", "-h", github_url],
capture_output=True,
text=True,
)
if token_result.returncode != 0:
logging.debug("Failed to get token from gh CLI")
return None, None, None
token = token_result.stdout.strip()
if not token:
return None, None, None

# Get the username using gh api
username_result = subprocess.run(
["gh", "api", "user", "-q", ".login", "--hostname", github_url],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is annoying, it implies an extra network roundtrip at startup every time

capture_output=True,
text=True,
)
username = None
if username_result.returncode == 0:
username = username_result.stdout.strip()

logging.debug(
f"Successfully retrieved credentials from gh CLI for {github_url}"
)
return token, username, github_url

except Exception as e:
logging.debug(f"Error getting credentials from gh CLI: {e}")
return None, None, None


Config = NamedTuple(
"Config",
[
Expand Down Expand Up @@ -97,6 +164,7 @@ def read_config(
# Environment variable overrides config file
# This envvar is legacy from ghexport days
github_oauth = os.getenv("OAUTH_TOKEN")
gh_cli_username = None # Track username from gh CLI
if github_oauth is not None:
logging.warning(
"Deprecated OAUTH_TOKEN environment variable used to populate github_oauth--"
Expand All @@ -105,6 +173,17 @@ def read_config(
)
if github_oauth is None and config.has_option("ghstack", "github_oauth"):
github_oauth = config.get("ghstack", "github_oauth")

# Try GitHub CLI if available and no token found yet
if github_oauth is None and request_github_token:
gh_token, gh_username, _ = get_gh_cli_credentials(github_url)
if gh_token is not None:
print(f"Using GitHub credentials from gh CLI for {github_url}")
github_oauth = gh_token
gh_cli_username = gh_username
# Don't save gh CLI credentials to config - they may change/expire

# Fall back to device flow if still no token
if github_oauth is None and request_github_token:
print("Generating GitHub access token...")
CLIENT_ID = "89cc88ca50efbe86907a"
Expand Down Expand Up @@ -150,6 +229,11 @@ def read_config(
github_username = None
if config.has_option("ghstack", "github_username"):
github_username = config.get("ghstack", "github_username")
# Use username from gh CLI if we got it
if github_username is None and gh_cli_username is not None:
github_username = gh_cli_username
# Don't save gh CLI username to config - it comes from gh CLI
# Fall back to API lookup if we have a token but no username yet
if github_username is None and github_oauth is not None:
request_url: str
if github_url == "github.com":
Expand Down
14 changes: 13 additions & 1 deletion src/ghstack/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,32 @@ def get_github_repo_info(
)


def _normalize_remote_url(remote_url: str) -> str:
"""Convert SSH remote URL to HTTPS format, strip .git suffix."""
# git@github.com:owner/repo.git -> https://github.com/owner/repo
m = re.match(r"^git@([^:]+):/?(.+?)(?:\.git)?$", remote_url)
if m:
return f"https://{m.group(1)}/{m.group(2)}"
return re.sub(r"\.git$", "", remote_url)


def parse_pull_request(
pull_request: str,
*,
sh: Optional[ghstack.shell.Shell] = None,
remote_name: Optional[str] = None,
) -> GitHubPullRequestParams:
pull_request = pull_request.lstrip("#")
m = RE_PR_URL.match(pull_request)
if not m:
# We can reconstruct the URL if just a PR number is passed
if sh is not None and remote_name is not None:
remote_url = sh.git("remote", "get-url", remote_name)
# Do not pass the shell to avoid infinite loop
try:
return parse_pull_request(remote_url + "/pull/" + pull_request)
return parse_pull_request(
_normalize_remote_url(remote_url) + "/pull/" + pull_request
)
except RuntimeError:
# Fall back on original error message
pass
Expand Down