Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, ParamSpec, TypeVar, Union
from warnings import warn

from ._types import RedisClient, S3Client
Expand All @@ -31,6 +31,9 @@
from .metrics import CacheMetrics, MetricsContext
from .util import parse_bytes

_P = ParamSpec("_P")
_R = TypeVar("_R")

MAX_WORKERS_ENVAR_NAME = "CACHIER_MAX_WORKERS"
DEFAULT_MAX_WORKERS = 8
ZERO_TIMEDELTA = timedelta(seconds=0)
Expand Down Expand Up @@ -221,7 +224,7 @@ def cachier(
allow_non_static_methods: Optional[bool] = None,
enable_metrics: bool = False,
metrics_sampling_rate: float = 1.0,
):
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Wrap as a persistent, stale-free memoization decorator.

The positional and keyword arguments to the wrapped function must be
Expand Down Expand Up @@ -400,7 +403,7 @@ def cachier(
else:
raise ValueError("specified an invalid core: %s" % backend)

def _cachier_decorator(func):
def _cachier_decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
core.set_func(func)

# Guard: raise TypeError when decorating an instance method unless
Expand Down Expand Up @@ -513,7 +516,7 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds):
from .config import _global_params

if ignore_cache or not _global_params.caching_enabled:
return func(args[0], **kwargs) if core.func_is_method else func(**kwargs)
return func(args[0], **kwargs) if core.func_is_method else func(**kwargs) # type: ignore[call-arg]

with MetricsContext(cache_metrics) as _mctx:
key, entry = core.get_entry((), kwargs)
Expand Down Expand Up @@ -629,7 +632,7 @@ async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds):
from .config import _global_params

if ignore_cache or not _global_params.caching_enabled:
return await func(args[0], **kwargs) if core.func_is_method else await func(**kwargs)
return await func(args[0], **kwargs) if core.func_is_method else await func(**kwargs) # type: ignore[call-arg,misc]

with MetricsContext(cache_metrics) as _mctx:
key, entry = await core.aget_entry((), kwargs)
Expand Down Expand Up @@ -699,14 +702,14 @@ async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds):
if is_coroutine:

@wraps(func)
async def func_wrapper(*args, **kwargs):
return await _call_async(*args, **kwargs)
async def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return await _call_async(*args, **kwargs) # type: ignore[arg-type]

else:

@wraps(func)
def func_wrapper(*args, **kwargs):
return _call(*args, **kwargs)
def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return _call(*args, **kwargs) # type: ignore[arg-type]

def _clear_cache():
"""Clear the cache."""
Expand Down Expand Up @@ -751,13 +754,13 @@ def _precache_value(*args, value_to_cache, **kwds):
kwargs = _convert_args_kwargs(func, _is_method=core.func_is_method, args=args, kwds=kwds)
return core.precache_value((), kwargs, value_to_cache)

func_wrapper.clear_cache = _clear_cache
func_wrapper.clear_being_calculated = _clear_being_calculated
func_wrapper.aclear_cache = _aclear_cache
func_wrapper.aclear_being_calculated = _aclear_being_calculated
func_wrapper.cache_dpath = _cache_dpath
func_wrapper.precache_value = _precache_value
func_wrapper.metrics = cache_metrics # Expose metrics object
return func_wrapper
func_wrapper.clear_cache = _clear_cache # type: ignore[attr-defined]
func_wrapper.clear_being_calculated = _clear_being_calculated # type: ignore[attr-defined]
func_wrapper.aclear_cache = _aclear_cache # type: ignore[attr-defined]
func_wrapper.aclear_being_calculated = _aclear_being_calculated # type: ignore[attr-defined]
func_wrapper.cache_dpath = _cache_dpath # type: ignore[attr-defined]
func_wrapper.precache_value = _precache_value # type: ignore[attr-defined]
func_wrapper.metrics = cache_metrics # type: ignore[attr-defined]
return func_wrapper # type: ignore[return-value]

return _cachier_decorator
2 changes: 2 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pytest-rerunfailures # for retrying flaky tests
coverage
pytest-cov
birch
# type checking
mypy
# to be able to run `python setup.py checkdocs`
collective.checkdocs
pygments
Expand Down
234 changes: 234 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""Tests that the @cachier decorator preserves function type signatures.

These tests invoke mypy programmatically and assert that decorated functions retain their parameter types and return
types as seen by static analysis.

"""

import textwrap

import pytest

mypy_api = pytest.importorskip("mypy.api", reason="mypy is required for typing tests")


def _run_mypy(code: str) -> tuple[list[str], list[str]]:
"""Run mypy on a code snippet and return (notes, errors).

Parameters
----------
code : str
Python source code to type-check.

Returns
-------
tuple[list[str], list[str]]
A tuple of (note lines, error lines) from mypy output.

"""
result = mypy_api.run(
[
"-c",
textwrap.dedent(code),
"--no-error-summary",
"--hide-error-context",
]
)
stdout = result[0]
notes = []
errors = []
for line in stdout.splitlines():
if ": note:" in line:
notes.append(line)
elif ": error:" in line:
errors.append(line)
return notes, errors


class TestSyncTyping:
"""Verify that synchronous decorated functions preserve types."""

def test_return_type_preserved(self) -> None:
"""Mypy should infer the original return type through @cachier."""
notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
def my_func(x: int) -> str:
return str(x)

reveal_type(my_func(5))
""")
assert not errors
assert any('"str"' in n for n in notes)

def test_param_types_preserved(self) -> None:
"""Mypy should see the original parameter types through @cachier."""
notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
def my_func(x: int, y: str) -> list[str]:
return [y] * x

reveal_type(my_func)
""")
assert not errors
assert any("int" in n and "str" in n for n in notes)

def test_wrong_arg_type_is_error(self) -> None:
"""Mypy should reject calls with wrong argument types."""
_notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
def add(a: int, b: int) -> int:
return a + b

add("not", "ints")
""")
assert errors

def test_return_type_mismatch_is_error(self) -> None:
"""Mypy should catch assigning the result to an incompatible type."""
_notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
def get_name() -> str:
return "hello"

x: int = get_name()
""")
assert errors


class TestAsyncTyping:
"""Verify that async decorated functions preserve types."""

def test_async_return_type_preserved(self) -> None:
"""Mypy should infer the awaited return type for async functions."""
notes, errors = _run_mypy("""
import asyncio
from cachier import cachier

@cachier()
async def fetch(url: str) -> bytes:
return b"data"

async def main() -> None:
result = await fetch("http://example.com")
reveal_type(result)

asyncio.run(main())
""")
assert not errors
assert any('"bytes"' in n for n in notes)

def test_async_signature_preserved(self) -> None:
"""Mypy should see the async function as a coroutine."""
notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
async def fetch(url: str) -> bytes:
return b"data"

reveal_type(fetch)
""")
assert not errors
assert any("Coroutine" in n for n in notes)

def test_async_wrong_arg_type_is_error(self) -> None:
"""Mypy should reject calls with wrong argument types for async."""
_notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
async def fetch(url: str) -> bytes:
return b"data"

async def main() -> None:
await fetch(123)
""")
assert errors


class TestComplexSignatures:
"""Verify preservation of more complex type signatures."""

def test_optional_params(self) -> None:
"""Mypy should preserve Optional parameter types."""
notes, errors = _run_mypy("""
from typing import Optional
from cachier import cachier

@cachier()
def greet(name: str, greeting: Optional[str] = None) -> str:
return f"{greeting or 'Hello'}, {name}"

reveal_type(greet)
""")
assert not errors
assert any("str" in n for n in notes)

def test_generic_return_type(self) -> None:
"""Mypy should preserve generic return types like dict."""
notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
def make_mapping(keys: list[str], value: int) -> dict[str, int]:
return {k: value for k in keys}

reveal_type(make_mapping(["a"], 1))
""")
assert not errors
assert any("dict[str, int]" in n for n in notes)

def test_none_return_type(self) -> None:
"""Mypy should preserve None return type."""
notes, errors = _run_mypy("""
from cachier import cachier

@cachier()
def side_effect(x: int) -> None:
pass

reveal_type(side_effect(1))
""")
assert not errors
assert any('"None"' in n for n in notes)


class TestDecoratorWithArgs:
"""Verify typing works with various decorator arguments."""

def test_with_backend_arg(self) -> None:
"""Type preservation should work with explicit backend selection."""
notes, errors = _run_mypy("""
from cachier import cachier

@cachier(backend="memory")
def compute(x: float) -> float:
return x * 2.0

reveal_type(compute(1.0))
""")
assert not errors
assert any('"float"' in n for n in notes)

def test_with_stale_after_arg(self) -> None:
"""Type preservation should work with stale_after parameter."""
notes, errors = _run_mypy("""
from datetime import timedelta
from cachier import cachier

@cachier(stale_after=timedelta(hours=1))
def lookup(key: str) -> list[int]:
return [1, 2, 3]

reveal_type(lookup("x"))
""")
assert not errors
assert any("list[int]" in n for n in notes)