Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/ghstack/github_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env python3

import re
Expand Down Expand Up @@ -123,8 +123,19 @@
}


# Matches GitHub PR URLs like:
# https://github.com/owner/repo/pull/123
# https://github.com/owner/repo/pull/123/
# https://github.com/owner/repo/pull/123/files
# https://github.com/owner/repo/pull/123/commits
RE_PR_URL = re.compile(
r"^https://(?P<github_url>[^/]+)/(?P<owner>[^/]+)/(?P<name>[^/]+)/pull/(?P<number>[0-9]+)/?$"
r"^https://(?P<github_url>[^/]+)/(?P<owner>[^/]+)/(?P<name>[^/]+)/pull/(?P<number>[0-9]+)(?:/.*)?$"
)

# Matches PyTorch HUD URLs like:
# https://hud.pytorch.org/pr/169404
RE_PYTORCH_HUD_URL = re.compile(
r"^https://hud\.pytorch\.org/pr/(?P<number>[0-9]+)/?$"
)

GitHubPullRequestParams = TypedDict(
Expand All @@ -144,6 +155,17 @@
sh: Optional[ghstack.shell.Shell] = None,
remote_name: Optional[str] = None,
) -> GitHubPullRequestParams:
# Check for PyTorch HUD URL first (hud.pytorch.org/pr/NUMBER)
hud_match = RE_PYTORCH_HUD_URL.match(pull_request)
if hud_match:
number = int(hud_match.group("number"))
return {
"github_url": "github.com",
"owner": "pytorch",
"name": "pytorch",
"number": number,
}

m = RE_PR_URL.match(pull_request)
if not m:
# We can reconstruct the URL if just a PR number is passed
Expand Down
91 changes: 91 additions & 0 deletions test_github_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

import unittest

import ghstack.github_utils


class TestParsePullRequest(unittest.TestCase):
def test_github_url_basic(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://github.com/pytorch/pytorch/pull/169404"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_github_url_trailing_slash(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://github.com/pytorch/pytorch/pull/169404/"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_github_url_files_suffix(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://github.com/pytorch/pytorch/pull/169404/files"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_github_url_commits_suffix(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://github.com/pytorch/pytorch/pull/169404/commits"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_github_url_commits_with_sha(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://github.com/pytorch/pytorch/pull/169404/commits/abc123def"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_pytorch_hud_url_basic(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://hud.pytorch.org/pr/169404"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_pytorch_hud_url_trailing_slash(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://hud.pytorch.org/pr/169404/"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "pytorch")
self.assertEqual(result["name"], "pytorch")
self.assertEqual(result["number"], 169404)

def test_different_owner_repo(self) -> None:
result = ghstack.github_utils.parse_pull_request(
"https://github.com/facebook/react/pull/12345"
)
self.assertEqual(result["github_url"], "github.com")
self.assertEqual(result["owner"], "facebook")
self.assertEqual(result["name"], "react")
self.assertEqual(result["number"], 12345)

def test_invalid_url_raises(self) -> None:
with self.assertRaises(RuntimeError):
ghstack.github_utils.parse_pull_request("not-a-valid-url")

def test_invalid_hud_url_raises(self) -> None:
with self.assertRaises(RuntimeError):
ghstack.github_utils.parse_pull_request("https://hud.pytorch.org/not-pr/123")


if __name__ == "__main__":
unittest.main()
Loading