Skip to content

Commit 5e7e701

Browse files
committed
feat: add alternative download methods to resolver API
Extend the resolver API with alternative download URLs. Resolvers can now return download links to alternative locations or retrieval methods. The `PyPIProvider` now accepts an `override_download_url` parameter. The value overrides the default PyPI download link. The string can contain a `{version}` format variable. The GitHub and GitLab tag providers can return git clone URLs for `https` and `ssh` transport. The URLs use pip's VCS syntax like `git+https://host/repo.git@tag`. The new enum `RetrieveMethod` has a `from_url()` class method that parses a URL and splits it into method, URL, and git ref. Signed-off-by: Christian Heimes <cheimes@redhat.com>
1 parent 81d7024 commit 5e7e701

File tree

2 files changed

+210
-8
lines changed

2 files changed

+210
-8
lines changed

src/fromager/resolver.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import datetime
9+
import enum
910
import functools
1011
import logging
1112
import os
@@ -14,7 +15,7 @@
1415
from collections.abc import Iterable
1516
from operator import attrgetter
1617
from platform import python_version
17-
from urllib.parse import quote, unquote, urljoin, urlparse
18+
from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit
1819

1920
import pypi_simple
2021
import resolvelib
@@ -176,11 +177,42 @@ def resolve_from_provider(
176177
raise ValueError(f"Unable to resolve {req}")
177178

178179

180+
class RetrieveMethod(enum.StrEnum):
181+
tarball = "tarball"
182+
git_https = "git+https"
183+
git_ssh = "git+ssh"
184+
185+
@classmethod
186+
def from_url(cls, download_url) -> tuple[RetrieveMethod, str, str | None]:
187+
"""Parse a download URL into method, url, reference"""
188+
scheme, netloc, path, query, fragment = urlsplit(
189+
download_url, allow_fragments=False
190+
)
191+
match scheme:
192+
case "https":
193+
return RetrieveMethod.tarball, download_url, None
194+
case "git+https":
195+
method = RetrieveMethod.git_https
196+
case "git+ssh":
197+
method = RetrieveMethod.git_ssh
198+
case _:
199+
raise ValueError(f"unsupported download URL {download_url!r}")
200+
# remove git+
201+
scheme = scheme[4:]
202+
# split off @ revision
203+
if "@" not in path:
204+
raise ValueError(f"git download url {download_url!r} is missing '@ref'")
205+
path, ref = path.rsplit("@", 1)
206+
return method, urlunsplit((scheme, netloc, path, query, fragment)), ref
207+
208+
179209
def get_project_from_pypi(
180210
project: str,
181211
extras: typing.Iterable[str],
182212
sdist_server_url: str,
183213
ignore_platform: bool = False,
214+
*,
215+
override_download_url: str | None = None,
184216
) -> Candidates:
185217
"""Return candidates created from the project name and extras."""
186218
found_candidates: set[str] = set()
@@ -341,14 +373,19 @@ def get_project_from_pypi(
341373
ignored_candidates.add(dp.filename)
342374
continue
343375

376+
if override_download_url is None:
377+
url = dp.url
378+
else:
379+
url = override_download_url.format(version=version)
380+
344381
upload_time = dp.upload_time
345382
if upload_time is not None:
346383
upload_time = upload_time.astimezone(datetime.UTC)
347384

348385
c = Candidate(
349386
name=name,
350387
version=version,
351-
url=dp.url,
388+
url=url,
352389
extras=tuple(sorted(extras)),
353390
is_sdist=is_sdist,
354391
build_tag=build_tag,
@@ -569,6 +606,7 @@ def __init__(
569606
ignore_platform: bool = False,
570607
*,
571608
use_resolver_cache: bool = True,
609+
override_download_url: str | None = None,
572610
):
573611
super().__init__(
574612
constraints=constraints,
@@ -579,6 +617,7 @@ def __init__(
579617
self.include_wheels = include_wheels
580618
self.sdist_server_url = sdist_server_url
581619
self.ignore_platform = ignore_platform
620+
self.override_download_url = override_download_url
582621

583622
@property
584623
def cache_key(self) -> str:
@@ -591,9 +630,10 @@ def cache_key(self) -> str:
591630
def find_candidates(self, identifier: str) -> Candidates:
592631
return get_project_from_pypi(
593632
identifier,
594-
set(),
595-
self.sdist_server_url,
596-
self.ignore_platform,
633+
extras=set(),
634+
sdist_server_url=self.sdist_server_url,
635+
ignore_platform=self.ignore_platform,
636+
override_download_url=self.override_download_url,
597637
)
598638

599639
def validate_candidate(
@@ -764,6 +804,7 @@ def __init__(
764804
*,
765805
req_type: RequirementType | None = None,
766806
use_resolver_cache: bool = True,
807+
retrieve_method: RetrieveMethod = RetrieveMethod.tarball,
767808
):
768809
super().__init__(
769810
constraints=constraints,
@@ -774,6 +815,7 @@ def __init__(
774815
)
775816
self.organization = organization
776817
self.repo = repo
818+
self.retrieve_method = retrieve_method
777819

778820
@property
779821
def cache_key(self) -> str:
@@ -808,7 +850,14 @@ def _find_tags(
808850
logger.debug(f"{identifier}: match function ignores {tagname}")
809851
continue
810852
assert isinstance(version, Version)
811-
url = entry["tarball_url"]
853+
854+
match self.retrieve_method:
855+
case RetrieveMethod.tarball:
856+
url = entry["tarball_url"]
857+
case RetrieveMethod.git_https:
858+
url = f"git+https://{self.host}/{self.organization}/{self.repo}.git@{tagname}"
859+
case RetrieveMethod.git_ssh:
860+
url = f"git+ssh://git@{self.host}/{self.organization}/{self.repo}.git@{tagname}"
812861

813862
# Github tag API endpoint does not include commit date information.
814863
# It would be too expensive to query every commit API endpoint.
@@ -837,6 +886,7 @@ def __init__(
837886
*,
838887
req_type: RequirementType | None = None,
839888
use_resolver_cache: bool = True,
889+
retrieve_method: RetrieveMethod = RetrieveMethod.tarball,
840890
) -> None:
841891
super().__init__(
842892
constraints=constraints,
@@ -846,6 +896,7 @@ def __init__(
846896
matcher=matcher,
847897
)
848898
self.server_url = server_url.rstrip("/")
899+
self.server_hostname = urlparse(server_url).hostname
849900
self.project_path = project_path.lstrip("/")
850901
# URL-encode the project path as required by GitLab API.
851902
# The safe="" parameter tells quote() to encode ALL characters,
@@ -856,6 +907,7 @@ def __init__(
856907
self.api_url = (
857908
f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags"
858909
)
910+
self.retrieve_method = retrieve_method
859911

860912
@property
861913
def cache_key(self) -> str:
@@ -884,8 +936,14 @@ def _find_tags(
884936
continue
885937
assert isinstance(version, Version)
886938

887-
archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
888-
url = urljoin(self.server_url, archive_path)
939+
match self.retrieve_method:
940+
case RetrieveMethod.tarball:
941+
archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
942+
url = urljoin(self.server_url, archive_path)
943+
case RetrieveMethod.git_https:
944+
url = f"git+https://{self.server_hostname}/{self.project_path}.git@{tagname}"
945+
case RetrieveMethod.git_ssh:
946+
url = f"git+ssh://git@{self.server_hostname}/{self.project_path}.git@{tagname}"
889947

890948
# get tag creation time, fall back to commit creation time
891949
created_at_str: str | None = entry.get("created_at")

tests/test_resolver.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,26 @@ def test_provider_constraint_match() -> None:
370370
assert str(candidate.version) == "1.2.2"
371371

372372

373+
def test_provider_override_download_url() -> None:
    """PyPIProvider substitutes the resolved version into the override URL.

    The simple index for ``hydra-core`` is mocked; the resolved candidate's
    URL must be the ``override_download_url`` template with ``{version}``
    filled in instead of the link from the index.
    """
    url_template = "https://server.test/hydr_core-{version}.tar.gz"

    with requests_mock.Mocker() as mocker:
        mocker.get(
            "https://pypi.org/simple/hydra-core/",
            text=_hydra_core_simple_response,
        )

        pypi_provider = resolver.PyPIProvider(override_download_url=url_template)
        base_reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
        resolution = resolvelib.Resolver(pypi_provider, base_reporter).resolve(
            [Requirement("hydra-core")]
        )

        assert "hydra-core" in resolution.mapping
        resolved = resolution.mapping["hydra-core"]
        assert resolved.url == "https://server.test/hydr_core-1.3.2.tar.gz"
391+
392+
373393
_ignore_platform_simple_response = """
374394
<!DOCTYPE html>
375395
<html>
@@ -715,6 +735,51 @@ def test_resolve_github() -> None:
715735
)
716736

717737

738+
@pytest.mark.parametrize(
    ["retrieve_method", "expected_url"],
    [
        (
            resolver.RetrieveMethod.tarball,
            "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
        ),
        (
            resolver.RetrieveMethod.git_https,
            "git+https://github.com:443/python-wheel-build/fromager.git@0.9.0",
        ),
        (
            resolver.RetrieveMethod.git_ssh,
            "git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0",
        ),
    ],
)
def test_resolve_github_retrieve_method(
    retrieve_method: resolver.RetrieveMethod, expected_url: str
) -> None:
    """GitHubTagProvider emits the URL form selected by ``retrieve_method``.

    The GitHub repo and tags endpoints are mocked; ``fromager`` is resolved
    and the candidate URL is compared with the expected tarball or git URL.
    """
    api_base = "https://api.github.com:443/repos/python-wheel-build/fromager"

    with requests_mock.Mocker() as mocker:
        mocker.get(api_base, text=_github_fromager_repo_response)
        mocker.get(f"{api_base}/tags", text=_github_fromager_tag_response)

        github_provider = resolver.GitHubTagProvider(
            organization="python-wheel-build",
            repo="fromager",
            retrieve_method=retrieve_method,
        )
        base_reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
        resolution = resolvelib.Resolver(github_provider, base_reporter).resolve(
            [Requirement("fromager")]
        )

        assert "fromager" in resolution.mapping
        resolved = resolution.mapping["fromager"]
        assert resolved.url == expected_url
781+
782+
718783
def test_github_constraint_mismatch() -> None:
719784
constraint = constraints.Constraints()
720785
constraint.add_constraint("fromager>=1.0")
@@ -922,6 +987,49 @@ def test_resolve_gitlab() -> None:
922987
)
923988

924989

990+
@pytest.mark.parametrize(
    ["retrieve_method", "expected_url"],
    [
        (
            resolver.RetrieveMethod.tarball,
            "https://gitlab.com/mirrors/github/decile-team/submodlib/-/archive/v0.0.3/submodlib-v0.0.3.tar.gz",
        ),
        (
            resolver.RetrieveMethod.git_https,
            "git+https://gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3",
        ),
        (
            resolver.RetrieveMethod.git_ssh,
            "git+ssh://git@gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3",
        ),
    ],
)
def test_resolve_gitlab_retrieve_method(
    retrieve_method: resolver.RetrieveMethod, expected_url: str
) -> None:
    """GitLabTagProvider emits the URL form selected by ``retrieve_method``.

    The GitLab tags endpoint is mocked; ``submodlib`` is resolved and both
    the candidate version and its URL are checked.
    """
    project = "mirrors/github/decile-team/submodlib"

    with requests_mock.Mocker() as mocker:
        mocker.get(
            "https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags",
            text=_gitlab_submodlib_repo_response,
        )

        gitlab_provider = resolver.GitLabTagProvider(
            project_path=project,
            server_url="https://gitlab.com",
            retrieve_method=retrieve_method,
        )
        base_reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
        resolution = resolvelib.Resolver(gitlab_provider, base_reporter).resolve(
            [Requirement("submodlib")]
        )

        assert "submodlib" in resolution.mapping
        resolved = resolution.mapping["submodlib"]
        assert str(resolved.version) == "0.0.3"
        assert resolved.url == expected_url
1031+
1032+
9251033
def test_gitlab_constraint_mismatch() -> None:
9261034
constraint = constraints.Constraints()
9271035
constraint.add_constraint("submodlib>=1.0")
@@ -1042,3 +1150,39 @@ def test_pep592_support_constraint_mismatch() -> None:
10421150
def test_extract_filename_from_url(url, filename) -> None:
10431151
result = resolver.extract_filename_from_url(url)
10441152
assert result == filename
1153+
1154+
1155+
@pytest.mark.parametrize(
1156+
["download_url", "expected_method", "expected_url", "expected_ref"],
1157+
[
1158+
(
1159+
"https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
1160+
resolver.RetrieveMethod.tarball,
1161+
"https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
1162+
None,
1163+
),
1164+
(
1165+
"git+https://github.com:443/python-wheel-build/fromager.git@0.9.0",
1166+
resolver.RetrieveMethod.git_https,
1167+
"https://github.com:443/python-wheel-build/fromager.git",
1168+
"0.9.0",
1169+
),
1170+
(
1171+
"git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0",
1172+
resolver.RetrieveMethod.git_ssh,
1173+
"ssh://git@github.com:443/python-wheel-build/fromager.git",
1174+
"0.9.0",
1175+
),
1176+
],
1177+
)
1178+
def test_retrieve_method_from_url(
1179+
download_url: str,
1180+
expected_method: resolver.RetrieveMethod,
1181+
expected_url: str,
1182+
expected_ref: str | None,
1183+
) -> None:
1184+
assert resolver.RetrieveMethod.from_url(download_url) == (
1185+
expected_method,
1186+
expected_url,
1187+
expected_ref,
1188+
)

0 commit comments

Comments
 (0)