scripts/benchmark_numpy_hash.py (new file, +102)
@@ -0,0 +1,102 @@
"""Benchmark default Cachier hashing against xxhash for large NumPy arrays."""

from __future__ import annotations

import argparse
import pickle
import statistics
import time
from typing import Any, Callable

import numpy as np

from cachier.config import _default_hash_func


def _xxhash_numpy_hash(args: tuple[Any, ...], kwds: dict[str, Any]) -> str:
"""Hash call arguments with xxhash, optimized for NumPy arrays.

Parameters
----------
args : tuple[Any, ...]
Positional arguments.
kwds : dict[str, Any]
Keyword arguments.

Returns
-------
str
xxhash hex digest.

"""
    import xxhash

    hasher = xxhash.xxh64()

    def _update(value: Any) -> None:
        # NumPy arrays: hash dtype, shape, and the raw buffer instead of pickling.
        if isinstance(value, np.ndarray):
            hasher.update(value.dtype.str.encode("utf-8"))
            hasher.update(str(value.shape).encode("utf-8"))
            hasher.update(value.tobytes(order="C"))
        else:
            hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))

    hasher.update(b"args")
    for value in args:
        _update(value)

    hasher.update(b"kwds")
    for key, value in sorted(kwds.items()):
        hasher.update(pickle.dumps(key, protocol=pickle.HIGHEST_PROTOCOL))
        _update(value)

    return hasher.hexdigest()


def _benchmark(
    hash_func: Callable[[tuple[Any, ...], dict[str, Any]], str],
    args: tuple[Any, ...],
    runs: int,
) -> float:
    durations: list[float] = []
for _ in range(runs):
start = time.perf_counter()
hash_func(args, {})
durations.append(time.perf_counter() - start)
return statistics.median(durations)


def main() -> None:
"""Run benchmark comparing cachier default hashing with xxhash."""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--elements",
type=int,
default=10_000_000,
help="Number of float64 elements in the benchmark array",
)
parser.add_argument("--runs", type=int, default=7, help="Number of benchmark runs")
parsed = parser.parse_args()

try:
import xxhash # noqa: F401
except ImportError as error:
        raise SystemExit(
            "Missing dependency: xxhash. Install with `pip install xxhash`."
        ) from error

array = np.arange(parsed.elements, dtype=np.float64)
args = (array,)

    results: dict[str, float] = {
"cachier_default": _benchmark(_default_hash_func, args, parsed.runs),
"xxhash_reference": _benchmark(_xxhash_numpy_hash, args, parsed.runs),
}

ratio = results["cachier_default"] / results["xxhash_reference"]

print(f"Array elements: {parsed.elements:,}")
print(f"Array bytes: {array.nbytes:,}")
print(f"Runs: {parsed.runs}")
print(f"cachier_default median: {results['cachier_default']:.6f}s")
print(f"xxhash_reference median: {results['xxhash_reference']:.6f}s")
print(f"ratio (cachier_default / xxhash_reference): {ratio:.2f}x")


if __name__ == "__main__":
main()
src/cachier/config.py (+92, -6)
@@ -9,13 +9,99 @@
from ._types import Backend, HashFunc, Mongetter


def _is_numpy_array(value: Any) -> bool:
"""Check whether a value is a NumPy ndarray without importing NumPy eagerly.

Parameters
----------
value : Any
The value to inspect.

Returns
-------
bool
True when ``value`` is a NumPy ndarray instance.

"""
return type(value).__module__ == "numpy" and type(value).__name__ == "ndarray"


def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
"""Update hasher with NumPy array metadata and buffer content.

Parameters
----------
hasher : hashlib._Hash
The hasher to update.
value : Any
A NumPy ndarray instance.

"""
hasher.update(b"numpy.ndarray")
hasher.update(value.dtype.str.encode("utf-8"))
hasher.update(str(value.shape).encode("utf-8"))
hasher.update(value.tobytes(order="C"))
Copilot AI commented on Feb 17, 2026:

The tobytes(order="C") call creates a copy if the array is not already C-contiguous, which is correct for ensuring consistent hashing. However, consider documenting this behavior in the function docstring, as it has performance implications for large non-contiguous arrays (e.g., sliced views, transposed arrays). This is working as intended but worth noting for users.
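To make the performance note concrete, here is a minimal sketch of the copy behavior (the shapes are illustrative):

import numpy as np

arr = np.arange(1_000_000, dtype=np.float64).reshape(1000, 1000)
view = arr.T  # transposed view: same buffer, but not C-contiguous

assert not view.flags["C_CONTIGUOUS"]

# tobytes(order="C") must first materialize a C-ordered copy of the data,
# so every hash of a large non-contiguous view pays an allocation plus a copy.
payload = view.tobytes(order="C")

# np.ascontiguousarray makes that copy explicit, and one-time, when the same
# view is hashed repeatedly.
assert np.ascontiguousarray(view).tobytes(order="C") == payload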


def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
"""Update hasher with a stable representation of a Python value.

Parameters
----------
hasher : hashlib._Hash
The hasher to update.
value : Any
Value to encode.

"""
if _is_numpy_array(value):
_hash_numpy_array(hasher, value)
return

if isinstance(value, tuple):
hasher.update(b"tuple")
for item in value:
_update_hash_for_value(hasher, item)
return

if isinstance(value, list):
hasher.update(b"list")
for item in value:
_update_hash_for_value(hasher, item)
return

if isinstance(value, dict):
hasher.update(b"dict")
for dict_key in sorted(value):
Copilot AI commented on Feb 17, 2026:

Calling sorted(value) on a dict with non-comparable keys will raise a TypeError. While dict keys in kwargs are typically strings (and thus sortable), users could pass dicts with arbitrary key types as function arguments. Consider wrapping the sort in a try-except block that falls back to sorting by the string representation of keys, or using sorted(value.items(), key=lambda x: str(x[0])), to ensure robustness.
_update_hash_for_value(hasher, dict_key)
_update_hash_for_value(hasher, value[dict_key])
return
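A sketch of the fallback suggested in the comment above on sorted(value); the helper name _sorted_dict_keys is hypothetical and not part of this diff:

def _sorted_dict_keys(value: dict) -> list:
    """Order dict keys deterministically even when key types are not comparable."""
    try:
        return sorted(value)  # fast path: homogeneous, orderable keys
    except TypeError:
        # Mixed key types (e.g., {1: ..., "a": ...}) cannot be compared directly;
        # fall back to an order based on type name and repr.
        return sorted(value, key=lambda key: (type(key).__name__, repr(key)))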

Copilot AI commented on Feb 17, 2026:

The hash function does not handle set or frozenset types explicitly. When these types are passed to cached functions, they'll be pickled via the fallback at line 80. Consider adding explicit handling for sets (converting to sorted tuples) to ensure deterministic hashing, similar to how dicts are handled with sorted keys. This would improve hash consistency and performance for set arguments.

Suggested change:

    if isinstance(value, (set, frozenset)):
        # Use a deterministic ordering of elements for hashing.
        hasher.update(b"frozenset" if isinstance(value, frozenset) else b"set")
        try:
            # Fast path: works for homogeneous, orderable element types.
            iterable = sorted(value)
        except TypeError:
            # Fallback: impose a deterministic order based on type name and repr.
            iterable = sorted(value, key=lambda item: (type(item).__name__, repr(item)))
        for item in iterable:
            _update_hash_for_value(hasher, item)
        return
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
Comment on lines +46 to +80 (Copilot AI, Feb 17, 2026):

The recursive implementation of _update_hash_for_value could cause a stack overflow with deeply nested data structures (e.g., nested lists/tuples/dicts). Consider adding a maximum recursion depth check or using an iterative approach with an explicit stack to prevent potential stack overflow errors when users pass deeply nested arguments to cached functions.
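As a hedged sketch of the iterative alternative, reusing _is_numpy_array and _hash_numpy_array from this diff (the function name below is hypothetical):

import pickle

def _update_hash_for_value_iterative(hasher, value) -> None:
    """Depth-safe variant of _update_hash_for_value using an explicit stack."""
    stack = [value]
    while stack:
        current = stack.pop()
        if _is_numpy_array(current):
            _hash_numpy_array(hasher, current)
        elif isinstance(current, (tuple, list)):
            hasher.update(b"tuple" if isinstance(current, tuple) else b"list")
            stack.extend(reversed(current))  # reverse so items pop in original order
        elif isinstance(current, dict):
            hasher.update(b"dict")
            for key in sorted(current, reverse=True):
                stack.append(current[key])  # each value pops right after its key
                stack.append(key)
        else:
            hasher.update(pickle.dumps(current, protocol=pickle.HIGHEST_PROTOCOL))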
Comment on lines +46 to +80 (Copilot AI, Feb 17, 2026):

The new hash function adds type prefixes (e.g., b"tuple", b"list", b"dict") to distinguish between different container types. This is good for correctness, but it means that hash_func((1,), {}) and hash_func([1], {}) will produce different hashes (as expected). However, this is a semantic change from the old implementation where both might have been pickled to similar byte sequences. Ensure this behavior is desired and documented, particularly that tuple (1,) and list [1] are now guaranteed to have different cache keys.
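A quick check of the guarantee described above, assuming the updated default is importable as in the benchmark script:

from cachier.config import _default_hash_func

# The b"tuple" / b"list" prefixes force distinct cache keys for (1,) vs [1].
assert _default_hash_func((1,), {}) != _default_hash_func([1], {})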


 def _default_hash_func(args, kwds):
-    # Sort the kwargs to ensure consistent ordering
-    sorted_kwargs = sorted(kwds.items())
-    # Serialize args and sorted_kwargs using pickle or similar
-    serialized = pickle.dumps((args, sorted_kwargs))
-    # Create a hash of the serialized data
-    return hashlib.sha256(serialized).hexdigest()
+    """Compute a stable hash key for function arguments.
+
+    Parameters
+    ----------
+    args : tuple
+        Positional arguments.
+    kwds : dict
+        Keyword arguments.
+
+    Returns
+    -------
+    str
+        A hex digest representing the call arguments.
+
+    """
+    hasher = hashlib.blake2b(digest_size=32)
+    hasher.update(b"args")
+    _update_hash_for_value(hasher, args)
+    hasher.update(b"kwds")
+    _update_hash_for_value(hasher, dict(sorted(kwds.items())))
+    return hasher.hexdigest()
Comment on lines 83 to +104 (Copilot AI, Feb 17, 2026):

Changing the hash algorithm from SHA256 to blake2b is a breaking change that will invalidate all existing caches. When users upgrade to this version, their cached function results will not be found because the cache keys will be different. Consider documenting this breaking change in the PR description or adding a migration guide. Alternatively, consider versioning the hash function or providing a compatibility mode that can read old cache entries.
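One possible shape for the versioning idea raised above; a sketch only, with HASH_VERSION as a hypothetical constant, reusing _update_hash_for_value from this diff:

import hashlib

HASH_VERSION = "v2"  # bump whenever the key derivation scheme changes

def _versioned_hash_func(args, kwds) -> str:
    """Prefix cache keys with a scheme version so stale entries are identifiable."""
    hasher = hashlib.blake2b(digest_size=32)
    hasher.update(b"args")
    _update_hash_for_value(hasher, args)
    hasher.update(b"kwds")
    _update_hash_for_value(hasher, dict(sorted(kwds.items())))
    # Keys from the old SHA256 scheme never match "v2-..." keys, making the
    # invalidation explicit rather than silent.
    return f"{HASH_VERSION}-{hasher.hexdigest()}"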


def _default_cache_dir():
tests/test_numpy_hash.py (new file, +43)
@@ -0,0 +1,43 @@
"""Tests for NumPy-aware default hash behavior."""

from datetime import timedelta

import pytest

from cachier import cachier

np = pytest.importorskip("numpy")


@pytest.mark.parametrize("backend", ["memory", "pickle"])
Copilot AI commented on Feb 17, 2026:

The test should be decorated with either @pytest.mark.memory or @pytest.mark.pickle for each backend being tested. According to the codebase conventions, tests should be marked with the appropriate backend marker so they can be selectively run in the CI matrix. Consider splitting this into two separate test functions, one for each backend with the appropriate marker, or using pytest.mark.parametrize with marks on each parameter value as demonstrated in tests/test_defaults.py lines 188-194.

Suggested change:

-@pytest.mark.parametrize("backend", ["memory", "pickle"])
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param("memory", marks=pytest.mark.memory),
+        pytest.param("pickle", marks=pytest.mark.pickle),
+    ],
+)
def test_default_hash_func_uses_array_content_for_cache_keys(backend, tmp_path):
"""Verify equal arrays map to a cache hit and different arrays miss."""
call_count = 0

decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
if backend == "pickle":
decorator_kwargs["cache_dir"] = tmp_path

@cachier(**decorator_kwargs)
def array_sum(values):
nonlocal call_count
call_count += 1
return int(values.sum())

arr = np.arange(100_000, dtype=np.int64)
arr_copy = arr.copy()
changed = arr.copy()
changed[-1] = -1

first = array_sum(arr)
assert call_count == 1

second = array_sum(arr_copy)
assert second == first
assert call_count == 1

third = array_sum(changed)
assert third != first
assert call_count == 2

array_sum.clear_cache()