Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion policyengine_uk_data/tests/test_release_manifest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import hashlib
from io import BytesIO
from importlib import metadata
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from huggingface_hub import CommitOperationAdd
from huggingface_hub.errors import EntryNotFoundError

from policyengine_uk_data.utils.data_upload import upload_files_to_hf
from policyengine_uk_data.utils.data_upload import (
_get_model_package_version,
load_release_manifest_from_hf,
upload_files_to_hf,
)
from policyengine_uk_data.utils.release_manifest import (
RELEASE_MANIFEST_SCHEMA_VERSION,
build_release_manifest,
Expand Down Expand Up @@ -81,6 +88,89 @@ def test_build_release_manifest_tracks_uk_release_artifacts(tmp_path):
assert manifest["artifacts"]["local_authority_weights"]["kind"] == "weights"


def test_build_release_manifest_refreshes_compatible_model_packages_for_draft_retry(
    tmp_path,
):
    """Rebuilding over an existing manifest replaces the model package pin.

    The existing manifest for data version 1.40.4 pins ``policyengine-uk``
    at ``==1.0.0``. Building again with ``model_package_version="9.99.9"``
    must yield a manifest whose compatible model packages list contains
    only the ``==9.99.9`` specifier.
    """
    # Manifest left behind by an earlier attempt, carrying the stale pin.
    stale_manifest = {
        "schema_version": RELEASE_MANIFEST_SCHEMA_VERSION,
        "data_package": {
            "name": "policyengine-uk-data",
            "version": "1.40.4",
        },
        "compatible_model_packages": [
            {
                "name": "policyengine-uk",
                "specifier": "==1.0.0",
            }
        ],
        "default_datasets": {},
        "created_at": "2026-04-10T12:00:00Z",
        "artifacts": {},
    }
    artifact_path = _write_file(
        tmp_path / "enhanced_frs_2023_24.h5",
        b"enhanced-frs",
    )

    rebuilt = build_release_manifest(
        files_with_repo_paths=[(artifact_path, "enhanced_frs_2023_24.h5")],
        version="1.40.4",
        repo_id="policyengine/policyengine-uk-data-private",
        model_package_version="9.99.9",
        existing_manifest=stale_manifest,
    )

    expected_packages = [{"name": "policyengine-uk", "specifier": "==9.99.9"}]
    assert rebuilt["compatible_model_packages"] == expected_packages


def test_load_release_manifest_from_hf_raises_non_missing_download_errors():
    """A download error other than a missing entry must reach the caller.

    ``hf_hub_download`` is patched to raise ``RuntimeError("boom")``; the
    loader is expected to let that exception propagate unchanged.
    """
    download_failure = RuntimeError("boom")
    with (
        patch(
            "policyengine_uk_data.utils.data_upload.hf_hub_download",
            side_effect=download_failure,
        ),
        pytest.raises(RuntimeError, match="boom"),
    ):
        load_release_manifest_from_hf(version="1.40.4")


def test_load_release_manifest_from_hf_continues_on_missing_entry(tmp_path):
    """A missing entry on one download attempt does not abort the lookup.

    The patched ``hf_hub_download`` raises ``EntryNotFoundError`` on its
    first call and returns the path of a local manifest file on its second;
    the loader must return the parsed content of that file.
    """
    local_manifest = tmp_path / "release_manifest.json"
    local_manifest.write_text('{"data_package": {"version": "1.40.4"}}')

    # Consumed one item per call: first a miss, then a successful download.
    download_outcomes = [
        EntryNotFoundError("missing"),
        str(local_manifest),
    ]
    download_target = "policyengine_uk_data.utils.data_upload.hf_hub_download"

    with patch(download_target, side_effect=download_outcomes):
        loaded = load_release_manifest_from_hf(version="1.40.4")

    assert loaded["data_package"]["version"] == "1.40.4"


def test_get_model_package_version_prefers_imported_checkout(tmp_path):
    """The version comes from the checkout's ``pyproject.toml``.

    A fake ``policyengine_uk`` package is laid out under ``tmp_path`` next
    to a ``pyproject.toml`` declaring version 2.78.0. ``find_spec`` is
    patched to point at that package and ``metadata.version`` is patched to
    raise ``PackageNotFoundError``, so a passing assertion shows the version
    was read from the checkout rather than from installed metadata.
    """
    checkout_package = tmp_path / "policyengine_uk"
    checkout_package.mkdir()
    init_file = checkout_package / "__init__.py"
    init_file.write_text("")
    (tmp_path / "pyproject.toml").write_text(
        '[project]\nname = "policyengine-uk"\nversion = "2.78.0"\n'
    )
    spec_stub = MagicMock(origin=str(init_file))

    find_spec_patch = patch(
        "policyengine_uk_data.utils.data_upload.find_spec",
        return_value=spec_stub,
    )
    metadata_patch = patch(
        "policyengine_uk_data.utils.data_upload.metadata.version",
        side_effect=metadata.PackageNotFoundError,
    )
    with find_spec_patch, metadata_patch:
        resolved_version = _get_model_package_version()

    assert resolved_version == "2.78.0"


def test_upload_files_to_hf_adds_uk_release_manifest_operations(tmp_path):
dataset_path = _write_file(
tmp_path / "enhanced_frs_2023_24.h5",
Expand Down
20 changes: 18 additions & 2 deletions policyengine_uk_data/utils/data_upload.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from io import BytesIO
from typing import Dict, List, Optional, Tuple
from huggingface_hub import HfApi, CommitOperationAdd, hf_hub_download
from huggingface_hub.errors import RevisionNotFoundError
from huggingface_hub.errors import EntryNotFoundError, RevisionNotFoundError
from google.cloud import storage
from pathlib import Path
from importlib import metadata
from importlib.util import find_spec
import google.auth
import json
import logging
import os
import tomllib

from policyengine_uk_data.utils.release_manifest import (
build_release_manifest,
Expand All @@ -21,6 +23,20 @@
def _get_model_package_version(
package_name: str = "policyengine-uk",
) -> Optional[str]:
module_name = package_name.replace("-", "_")
spec = find_spec(module_name)
module_origin = getattr(spec, "origin", None) if spec is not None else None
if module_origin is not None:
package_root = Path(module_origin).resolve().parent
for parent in [package_root, *package_root.parents]:
pyproject_path = parent / "pyproject.toml"
if not pyproject_path.exists():
continue
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
project = pyproject.get("project", {})
if project.get("name") == package_name and project.get("version"):
return project["version"]
try:
return metadata.version(package_name)
except metadata.PackageNotFoundError:
Expand Down Expand Up @@ -87,7 +103,7 @@ def load_release_manifest_from_hf(
)
except RevisionNotFoundError:
raise
except Exception:
except EntryNotFoundError:
continue

with open(manifest_path) as f:
Expand Down