Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/scverse_misc/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,31 @@ def override(self, **kwargs: object) -> Generator[None]:
for argname, argval in reversed(oldsettings.items()):
setattr(self, argname, argval)

def reset(self, *args: str) -> AbstractContextManager[frozenset[str]]:
"""Reset passed settings to their default values.

Can be used as a context manager to make the resets temporary.
On `__enter__`, the context manager returns the settings that have been changed.
"""
prev_values = {arg: getattr(self, arg) for arg in args if arg in self.model_fields_set}

# since we want to allow using this method imperatively,
# eagerly do the reset here instead of returning a context manager with a lazy `__enter__`.
for arg in prev_values:
default = type(self).model_fields[arg].get_default()
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_default optionally takes validated_data – are we supposed to just pass type(self).model_dump(self) here?

setattr(self, arg, default)
self.model_fields_set.remove(arg)

class Cm(AbstractContextManager[frozenset[str]]):
def __enter__(_self) -> frozenset[str]:
return frozenset(prev_values)

def __exit__(_self, *_: object) -> None:
for arg, value in prev_values.items():
setattr(self, arg, value)

return Cm()

@classmethod
def __pydantic_init_subclass__( # type: ignore[override]
subcls: type[Self], *, exported_object_name: str, docstring_style: Literal["google", "numpy", "scverse"]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Literal, cast

Expand Down Expand Up @@ -87,13 +88,25 @@ def test_override(settings: DummySettings) -> None:
assert settings.field_bool is True
assert settings.field_bool is False


def test_override_error(settings: DummySettings) -> None:
with pytest.raises(ValidationError):
with settings.override(field_int_range=3, field_no_docstring=1.1):
pass
assert settings.field_no_docstring == 42
assert settings.field_int_range == 1


@pytest.mark.parametrize("temp", [True, False], ids=["temporary", "permanent"])
def test_reset(settings: DummySettings, temp: bool) -> None:
default = settings.field_bool
settings.field_bool = not default
undo_reset = settings.reset("field_bool")
with undo_reset if temp else nullcontext():
assert settings.field_bool is default
assert settings.field_bool is (not default if temp else default)


@pytest.mark.parametrize("docstring_style", ["google", "numpy", "scverse"], indirect=True)
def test_docs(docstring_style: Literal["google", "numpy"], settings: DummySettings) -> None:
parser = GoogleDocstring if docstring_style == "google" else NumpyDocstring
Expand Down
Loading