7 changes: 6 additions & 1 deletion src/cachekit/config/decorator.py
@@ -172,6 +172,9 @@ def local_function():
integrity_checking: Enable checksums for corruption detection (default: True)
All serializers use xxHash3-64 (8 bytes).
Set to False for @cache.minimal (speed-first, no integrity guarantee)
key: Custom key function for complex types. Receives (*args, **kwargs) and returns str.
Use for numpy arrays, DataFrames, or cross-language cache sharing.
Example: @cache(key=lambda arr: hashlib.blake2b(arr.tobytes()).hexdigest())
refresh_ttl_on_get: Extend TTL on cache hit
ttl_refresh_threshold: Minimum remaining TTL fraction (0.0-1.0) to trigger refresh
backend: L2 backend (RedisBackend, HTTPBackend, None for L1-only)
@@ -183,12 +186,13 @@ def local_function():
encryption: Client-side encryption configuration
"""

# Core settings (6 fields)
ttl: int | None = None
namespace: str | None = None
serializer: Union[str, SerializerProtocol] = "default" # type: ignore[assignment] # String name or protocol instance
safe_mode: bool = False
integrity_checking: bool = True # Checksums for corruption detection (xxHash3-64 for all serializers)
key: Callable[..., str] | None = None # Custom key function (escape hatch for complex types)

# Performance (2 fields)
refresh_ttl_on_get: bool = False
@@ -251,6 +255,7 @@ def to_dict(self) -> dict[str, object]:
"namespace": self.namespace,
"serializer": self.serializer,
"safe_mode": self.safe_mode,
"key": self.key,
"refresh_ttl_on_get": self.refresh_ttl_on_get,
"ttl_refresh_threshold": self.ttl_refresh_threshold,
"backend": self.backend,
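Taken together, the decorator.py changes let callers supply their own key derivation. A minimal usage sketch (the cachekit import path and function names here are illustrative assumptions, not taken from this diff):

import hashlib

import numpy as np

from cachekit import cache  # assumed entry point; not shown in this diff

# key= bypasses normal key generation for the array argument (illustrative).
@cache(ttl=300, key=lambda arr: hashlib.blake2b(arr.tobytes(), digest_size=32).hexdigest())
def expensive_transform(arr: np.ndarray) -> float:
    return float(arr.sum())

The returned digest only needs to be a stable str; the wrapper (below) prepends the namespace itself.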
28 changes: 24 additions & 4 deletions src/cachekit/decorators/wrapper.py
@@ -412,6 +412,15 @@ def create_cache_wrapper(
deployment_uuid = config.encryption.deployment_uuid
master_key = config.encryption.master_key

# Custom key function (escape hatch for complex types)
custom_key_func = config.key
else:
custom_key_func = None

# Defensive guard: ensure custom_key_func is bound before the closures below capture it
if "custom_key_func" not in dir():
custom_key_func = None

# Fast mode: Disable monitoring overhead, keep performance features
use_circuit_breaker = circuit_breaker and not fast_mode
use_adaptive_timeout = adaptive_timeout and not fast_mode
@@ -541,7 +550,13 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: PLR0912

# Key generation - needed for both L1-only and L1+L2 modes
try:
# Custom key function takes priority (escape hatch for complex types)
if custom_key_func is not None:
custom_key = custom_key_func(*args, **kwargs)
if not isinstance(custom_key, str):
raise TypeError(f"key function must return str, got {type(custom_key).__name__}")
cache_key = f"{namespace or 'default'}:{custom_key}"
elif fast_mode:
# Minimal key generation - no string formatting overhead
from ..hash_utils import cache_key_hash

@@ -878,12 +893,17 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
cache_key = None
func_start_time: float | None = None # Initialize for exception handlers
try:
# Custom key function takes priority (escape hatch for complex types)
if custom_key_func is not None:
custom_key = custom_key_func(*args, **kwargs)
if not isinstance(custom_key, str):
raise TypeError(f"key function must return str, got {type(custom_key).__name__}")
cache_key = f"{namespace or 'default'}:{custom_key}"
elif fast_mode:
# Ultra-fast key generation for hot paths (10-50μs savings)
from ..hash_utils import cache_key_hash

cache_namespace = namespace or "default"
args_kwargs_str = str(args) + str(kwargs)
cache_key = cache_namespace + ":" + func_hash + ":" + cache_key_hash(args_kwargs_str)
else:
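The sync and async wrappers now share the same precedence: custom key function first, then fast mode, then the full generator. A standalone sketch of that selection order (cache_key_hash is stood in by blake2b and func_hash by a literal; both stand-ins are assumptions for illustration):

import hashlib

def build_cache_key(args: tuple, kwargs: dict, *, namespace: str | None = None,
                    custom_key_func=None, fast_mode: bool = False,
                    func_hash: str = "deadbeef") -> str:
    if custom_key_func is not None:
        # Escape hatch: the caller-provided key wins, but must be a str.
        custom_key = custom_key_func(*args, **kwargs)
        if not isinstance(custom_key, str):
            raise TypeError(f"key function must return str, got {type(custom_key).__name__}")
        return f"{namespace or 'default'}:{custom_key}"
    if fast_mode:
        # Fast path: hash the stringified args/kwargs, mirroring the wrapper above.
        digest = hashlib.blake2b((str(args) + str(kwargs)).encode(), digest_size=8).hexdigest()
        return f"{namespace or 'default'}:{func_hash}:{digest}"
    raise NotImplementedError("full CacheKeyGenerator path sketched in key_generator.py below")

assert build_cache_key((1, 2), {}, custom_key_func=lambda a, b: f"{a}-{b}") == "default:1-2"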
216 changes: 203 additions & 13 deletions src/cachekit/key_generator.py
@@ -3,10 +3,25 @@
from __future__ import annotations

import hashlib
import sys
from datetime import datetime
from decimal import Decimal
from enum import Enum
from pathlib import Path, PurePath
from typing import Any, Callable, NoReturn, cast
from uuid import UUID

import msgpack

# Constants for constrained array support (per round-table review 2025-12-18)
ARRAY_MAX_BYTES = 100_000 # 100KB per array
ARRAY_AGGREGATE_MAX = 5_000_000 # 5MB total across all args
SUPPORTED_ARRAY_DTYPES = {"int32", "int64", "float32", "float64"}
DTYPE_MAP = {"int32": "i32", "int64": "i64", "float32": "f32", "float64": "f64"}


class CacheKeyGenerator:
"""Generates consistent cache keys from function calls.
@@ -96,9 +111,12 @@ def _blake2b_hash(self, args: tuple, kwargs: dict) -> str:
Raises:
TypeError: If args/kwargs contain unsupported types (custom objects, numpy arrays, etc.)
"""
# Track aggregate array bytes for DoS prevention
array_bytes_seen: list[int] = [0]

# Step 1: Normalize recursively
normalized_args = [self._normalize(arg) for arg in args]
normalized_kwargs = {k: self._normalize(v) for k, v in sorted(kwargs.items())}
normalized_args = [self._normalize(arg, array_bytes_seen) for arg in args]
normalized_kwargs = {k: self._normalize(v, array_bytes_seen) for k, v in sorted(kwargs.items())}

# Step 2: Serialize with MessagePack
try:
@@ -112,27 +130,199 @@ def _blake2b_hash(self, args: tuple, kwargs: dict) -> str:
# Step 3: Hash with Blake2b-256
return hashlib.blake2b(msgpack_bytes, digest_size=32).hexdigest()

def _normalize(self, obj: Any, _array_bytes_seen: list[int] | None = None) -> Any:
"""Normalize object for deterministic MessagePack encoding.

CRITICAL: Cross-language compatible types ONLY per Protocol v1.1.

Supported types (per round-table review 2025-12-18):
- Primitives: int, str, bytes, bool, None, float
- Collections: dict (sorted keys), list, tuple
- Extended: Path, UUID, Decimal, Enum, datetime (UTC only)
- Arrays: numpy.ndarray (1D, ≤100KB, i32/i64/f32/f64)

Args:
obj: Object to normalize
_array_bytes_seen: Internal tracker for aggregate array size (DoS prevention)

Returns:
Normalized object safe for MessagePack serialization

Raises:
TypeError: For unsupported types with helpful guidance
"""
# Initialize aggregate tracker if not provided
if _array_bytes_seen is None:
_array_bytes_seen = [0]

# === COLLECTIONS (recursive) ===
if isinstance(obj, dict):
# Recursively normalize dict with sorted keys
return {k: self._normalize(v, _array_bytes_seen) for k, v in sorted(obj.items())}

# Recursively normalize collections (tuple→list)
if isinstance(obj, (list, tuple)):
return [self._normalize(x, _array_bytes_seen) for x in obj]

# === FLOAT (cross-language compat) ===
if isinstance(obj, float):
# CRITICAL: Normalize -0.0 → 0.0 for cross-language compatibility
return 0.0 if obj == 0.0 else obj

# === EXTENDED TYPES ===

# Path: normalize to POSIX format for cross-platform consistency
if isinstance(obj, (Path, PurePath)):
return obj.as_posix()

# UUID: standard string format
if isinstance(obj, UUID):
return str(obj)

# Decimal: exact string representation
if isinstance(obj, Decimal):
return str(obj)

# Enum: use value (recursively normalize in case value is complex)
if isinstance(obj, Enum):
return self._normalize(obj.value, _array_bytes_seen)

# datetime: UTC only, reject naive datetimes
if isinstance(obj, datetime):
if obj.tzinfo is None:
raise TypeError(
"Naive datetime not allowed in cache keys (timezone ambiguity). "
"Use timezone-aware datetime: datetime(..., tzinfo=timezone.utc)"
)
return obj.isoformat()

# === NUMPY ARRAY (constrained support) ===
if self._is_numpy_array(obj):
return self._normalize_array(obj, _array_bytes_seen)

# === PRIMITIVES (pass through) ===
if isinstance(obj, (int, str, bytes, bool, type(None))):
return obj

# === UNSUPPORTED: Fail fast with helpful message ===
return self._raise_unsupported_type(obj)

def _is_numpy_array(self, obj: Any) -> bool:
"""Check if object is numpy array without importing numpy."""
return type(obj).__module__ == "numpy" and type(obj).__name__ == "ndarray"

def _normalize_array(self, arr: Any, _array_bytes_seen: list[int]) -> list[Any]:
"""Normalize numpy array with strict constraints.

Constraints (per round-table review 2025-12-18):
- 1D only (cross-language simplicity)
- ≤100KB (memory safety)
- 4 dtypes: i32, i64, f32, f64 (cross-language compatibility)
- Little-endian byte order (platform determinism)
- 256-bit Blake2b hash (collision resistance)
- Version prefix for future protocol changes

Args:
arr: numpy.ndarray to normalize
_array_bytes_seen: Aggregate byte counter for DoS prevention

Returns:
List of ["__array_v1__", shape_list, dtype_str, content_hash]
(list format for MessagePack compatibility with strict_types=True)

Raises:
TypeError: If array doesn't meet constraints
"""
import numpy as np

# Constraint 1: Size limit per array
if arr.nbytes > ARRAY_MAX_BYTES:
raise TypeError(
f"Array too large ({arr.nbytes:,} bytes, max {ARRAY_MAX_BYTES:,}). Use key= parameter for large arrays."
)

# Constraint 2: Aggregate size limit (DoS prevention)
_array_bytes_seen[0] += arr.nbytes
if _array_bytes_seen[0] > ARRAY_AGGREGATE_MAX:
raise TypeError(
f"Total array size exceeds {ARRAY_AGGREGATE_MAX:,} bytes. Use key= parameter for batch array operations."
)

# Constraint 3: 1D only
if arr.ndim != 1:
raise TypeError(
f"Only 1D arrays supported in cache keys (got {arr.ndim}D). "
f"Use key= parameter for multidimensional arrays, or flatten with arr.ravel()."
)

# Constraint 4: Supported dtypes only
dtype_name = arr.dtype.name
if dtype_name not in SUPPORTED_ARRAY_DTYPES:
raise TypeError(
f"Unsupported array dtype '{dtype_name}'. "
f"Supported: {', '.join(sorted(SUPPORTED_ARRAY_DTYPES))}. "
f"Cast with arr.astype(np.float64) or use key= parameter."
)

# Ensure C-contiguous memory layout
arr = np.ascontiguousarray(arr)

# Force little-endian byte order for cross-platform determinism
if arr.dtype.byteorder not in ("=", "<", "|"):
arr = arr.astype(arr.dtype.newbyteorder("<"))
elif arr.dtype.byteorder == "=" and sys.byteorder == "big":
arr = arr.astype(arr.dtype.newbyteorder("<"))  # astype swaps bytes; ndarray.newbyteorder was removed in NumPy 2.0

# 256-bit Blake2b hash (per security review)
content_hash = hashlib.blake2b(arr.tobytes(), digest_size=32).hexdigest()

# Standardized dtype string for cross-language compatibility
dtype_str = DTYPE_MAP[dtype_name]

# Version prefix for protocol evolution
# Return as list (not tuple) for MessagePack compatibility with strict_types=True
# Shape converted to list as well
return ["__array_v1__", list(arr.shape), dtype_str, content_hash]

def _raise_unsupported_type(self, obj: Any) -> NoReturn:
"""Raise helpful TypeError for unsupported types.

Args:
obj: The unsupported object

Raises:
TypeError: Always, with guidance on how to handle the type
"""
type_name = type(obj).__module__ + "." + type(obj).__qualname__

# Specific guidance for numpy arrays that don't meet constraints
if "numpy" in type_name and "ndarray" in type_name:
raise TypeError(
"numpy array doesn't meet cache key constraints. "
"Requirements: 1D, ≤100KB, dtype in (i32, i64, f32, f64). "
"Use key= parameter for other arrays."
)

if "pandas" in type_name:
raise TypeError(
"pandas objects not supported as cache key arguments "
"(Parquet serialization is non-deterministic). "
"Recommended patterns:\n"
" 1. Pass identifier, return DataFrame: @cache def load(id: int) -> pd.DataFrame\n"
" 2. Use explicit key: @cache(key=lambda df: hashlib.blake2b(pd.util.hash_pandas_object(df, index=True).values.tobytes()).hexdigest())"
)

if isinstance(obj, (set, frozenset)):
raise TypeError(
"set/frozenset not supported in cache keys (mixed-type sorting crashes). "
"Convert to sorted list: sorted(list(your_set))"
)

raise TypeError(
f"Unsupported type '{type_name}' for cache key. "
f"Supported: dict, list, tuple, int, float, str, bytes, bool, None, "
f"Path, UUID, Decimal, Enum, datetime (UTC), 1D numpy arrays (≤100KB, i32/i64/f32/f64). "
f"For custom types, use key= parameter."
)

def _normalize_key(self, key: str) -> str:
"""Normalize key to ensure it's valid for cache backends.

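The normalization rules above can be exercised end to end. A sketch assuming numpy is installed; it drives the private _blake2b_hash directly, which is for illustration only:

from datetime import datetime, timezone
from decimal import Decimal
from pathlib import Path
from uuid import UUID

import numpy as np

from cachekit.key_generator import CacheKeyGenerator  # module path per this diff

gen = CacheKeyGenerator()

# Extended types normalize deterministically: Path -> POSIX str, UUID/Decimal -> str,
# aware datetime -> ISO 8601. Equal inputs therefore produce equal digests.
args = (Path("data") / "file.csv", UUID(int=1), Decimal("1.50"),
        datetime(2025, 1, 1, tzinfo=timezone.utc))
assert gen._blake2b_hash(args, {}) == gen._blake2b_hash(args, {})

# A 1D array within limits hashes via the ["__array_v1__", shape, dtype, hash] encoding.
small = np.arange(10, dtype=np.int64)  # 80 bytes, far below ARRAY_MAX_BYTES
print(gen._blake2b_hash((small,), {}))

# Constraint violations fail fast with guidance toward the key= parameter.
for bad in (np.zeros((2, 2)),               # 2D -> TypeError
            np.zeros(1, dtype=np.float16),  # unsupported dtype -> TypeError
            datetime(2025, 1, 1),           # naive datetime -> TypeError
            {1, 2, 3}):                     # set -> TypeError
    try:
        gen._blake2b_hash((bad,), {})
    except TypeError as exc:
        print(type(bad).__name__, "->", exc)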