Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@
log = logging.getLogger(__name__)
escaper = TokenEscaper()

# Minimum redis-py version for hash field expiration support
_HASH_FIELD_EXPIRATION_MIN_VERSION = (5, 1, 0)


def supports_hash_field_expiration() -> bool:
"""
Check if the installed redis-py version supports hash field expiration commands.

Hash field expiration (HEXPIRE, HTTL, HPERSIST, etc.) was added in redis-py 5.1.0
and requires Redis server 7.4+.

Returns:
True if redis-py >= 5.1.0 and has the hexpire method, False otherwise.
"""
try:
import redis as redis_lib

version_str = getattr(redis_lib, "__version__", "0.0.0")
version_parts = tuple(int(x) for x in version_str.split(".")[:3])
if version_parts >= _HASH_FIELD_EXPIRATION_MIN_VERSION:
# Also check that the method actually exists
return hasattr(redis_lib.asyncio.Redis, "hexpire")
return False
except (ValueError, AttributeError):
return False


def convert_datetime_to_timestamp(obj):
"""Convert datetime objects to Unix timestamps for storage."""
Expand Down Expand Up @@ -1879,13 +1905,15 @@ def __init__(self, default: Any = ..., **kwargs: Any) -> None:
index = kwargs.pop("index", None)
full_text_search = kwargs.pop("full_text_search", None)
vector_options = kwargs.pop("vector_options", None)
expire = kwargs.pop("expire", None)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.sortable = sortable
self.case_sensitive = case_sensitive
self.index = index
self.full_text_search = full_text_search
self.vector_options = vector_options
self.expire = expire


class RelationshipInfo(Representation):
Expand Down Expand Up @@ -1996,8 +2024,27 @@ def Field(
index: Union[bool, UndefinedType] = Undefined,
full_text_search: Union[bool, UndefinedType] = Undefined,
vector_options: Optional[VectorFieldOptions] = None,
expire: Optional[int] = None,
**kwargs: Unpack[_FromFieldInfoInputs],
) -> Any:
"""
Create a field with Redis OM specific options.

Args:
default: Default value for the field.
primary_key: Whether this field is the primary key.
sortable: Whether this field should be sortable in queries.
case_sensitive: Whether string matching should be case-sensitive.
index: Whether this field should be indexed for searching.
full_text_search: Whether to enable full-text search on this field.
vector_options: Vector field configuration for similarity search.
expire: TTL in seconds for this field (HashModel only, requires Redis 7.4+).
When set, the field will automatically expire after save().
**kwargs: Additional Pydantic field options.

Returns:
FieldInfo instance with the configured options.
"""
field_info = FieldInfo(
**kwargs,
default=default,
Expand All @@ -2007,6 +2054,7 @@ def Field(
index=index,
full_text_search=full_text_search,
vector_options=vector_options,
expire=expire,
)
return field_info

Expand Down Expand Up @@ -2631,12 +2679,62 @@ def __init_subclass__(cls, **kwargs):
f"HashModels cannot index dataclass fields. Field: {name}"
)

def _get_field_expirations(
self, field_expirations: Optional[Dict[str, int]] = None
) -> Dict[str, int]:
"""
Collect field expirations from Field(expire=N) defaults and overrides.

Args:
field_expirations: Optional dict of {field_name: ttl_seconds} to override defaults.

Returns:
Dict of field names to TTL in seconds.
"""
expirations: Dict[str, int] = {}

# Collect default expirations from Field(expire=N)
for name, field in self.model_fields.items():
field_info = field
# Handle metadata-wrapped FieldInfo
if (
not isinstance(field, FieldInfo)
and hasattr(field, "metadata")
and len(field.metadata) > 0
and isinstance(field.metadata[0], FieldInfo)
):
field_info = field.metadata[0]

expire = getattr(field_info, "expire", None)
if expire is not None:
expirations[name] = expire

# Override with explicit field_expirations
if field_expirations:
expirations.update(field_expirations)

return expirations

async def save(
self: "Model",
pipeline: Optional[redis.client.Pipeline] = None,
nx: bool = False,
xx: bool = False,
field_expirations: Optional[Dict[str, int]] = None,
) -> Optional["Model"]:
"""
Save the model to Redis.

Args:
pipeline: Optional Redis pipeline for batching commands.
nx: Only save if the key doesn't exist.
xx: Only save if the key already exists.
field_expirations: Dict of {field_name: ttl_seconds} to set field expirations.
Overrides any Field(expire=N) defaults. Requires Redis 7.4+.

Returns:
The saved model, or None if nx/xx conditions weren't met.
"""
if nx and xx:
raise ValueError("Cannot specify both nx and xx")
if pipeline and (nx or xx):
Expand Down Expand Up @@ -2666,6 +2764,12 @@ async def save(

key = self.key()

# Collect field expirations
expirations = self._get_field_expirations(field_expirations)

# Check if we're using a pipeline (pipelines don't support TTL preservation)
is_pipeline = pipeline is not None

async def _do_save(conn):
# Check nx/xx conditions (HSET doesn't support these natively)
if nx or xx:
Expand All @@ -2675,7 +2779,37 @@ async def _do_save(conn):
if xx and not exists:
return None # Key doesn't exist, xx means only update existing

# Preserve existing field TTLs before HSET (HSET removes field-level TTLs)
# See issue #753: .save() conflicts with TTL on unrelated field
# Note: TTL preservation is skipped when using pipelines because
# pipeline commands return futures, not actual values
preserved_ttls: Dict[str, int] = {}
if supports_hash_field_expiration() and not is_pipeline:
fields_to_check = [f for f in document.keys() if f != "pk"]
if fields_to_check:
current_ttls = await conn.httl(key, *fields_to_check)
if current_ttls:
for i, field_name in enumerate(fields_to_check):
if current_ttls[i] > 0: # Has a TTL
preserved_ttls[field_name] = current_ttls[i]

await conn.hset(key, mapping=document)

# Apply field expirations after HSET (requires Redis 7.4+)
# When using pipelines, we can still apply default expirations but
# can't preserve manually-set TTLs
if supports_hash_field_expiration():
for field_name in document.keys():
if field_name == "pk":
continue
# Priority: preserved TTL > explicit field_expirations > Field(expire=N) default
if field_name in preserved_ttls:
# Restore the TTL that was removed by HSET
await conn.hexpire(key, preserved_ttls[field_name], field_name)
elif field_name in expirations:
# Apply new expiration (from Field(expire=N) or field_expirations param)
await conn.hexpire(key, expirations[field_name], field_name)

return self

# TODO: Wrap any Redis response errors in a custom exception?
Expand Down Expand Up @@ -2861,6 +2995,101 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):

return schema

# =========================================================================
# Hash Field Expiration Methods (Redis 7.4+)
# =========================================================================

async def expire_field(
self,
field_name: str,
seconds: int,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> int:
"""
Set a TTL on a specific hash field.

Requires Redis 7.4+ and redis-py >= 5.1.0.

Args:
field_name: The name of the field to expire.
seconds: TTL in seconds.
nx: Only set expiry if field has no expiry.
xx: Only set expiry if field already has an expiry.
gt: Only set expiry if new expiry is greater than current.
lt: Only set expiry if new expiry is less than current.

Returns:
1 if expiry was set, -2 if field doesn't exist, 0 if conditions not met.

Raises:
NotImplementedError: If redis-py version doesn't support HEXPIRE.
"""
if not supports_hash_field_expiration():
raise NotImplementedError(
"Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+"
)

db = self.db()
key = self.key()
result = await db.hexpire(key, seconds, field_name, nx=nx, xx=xx, gt=gt, lt=lt)
# hexpire returns a list of results, one per field
return result[0] if result else -2

async def field_ttl(self, field_name: str) -> int:
"""
Get the remaining TTL of a hash field in seconds.

Requires Redis 7.4+ and redis-py >= 5.1.0.

Args:
field_name: The name of the field.

Returns:
TTL in seconds, -1 if no expiry, -2 if field doesn't exist.

Raises:
NotImplementedError: If redis-py version doesn't support HTTL.
"""
if not supports_hash_field_expiration():
raise NotImplementedError(
"Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+"
)

db = self.db()
key = self.key()
result = await db.httl(key, field_name)
# httl returns a list of results, one per field
return result[0] if result else -2

async def persist_field(self, field_name: str) -> int:
"""
Remove the expiration from a hash field.

Requires Redis 7.4+ and redis-py >= 5.1.0.

Args:
field_name: The name of the field.

Returns:
1 if expiry was removed, -1 if no expiry, -2 if field doesn't exist.

Raises:
NotImplementedError: If redis-py version doesn't support HPERSIST.
"""
if not supports_hash_field_expiration():
raise NotImplementedError(
"Hash field expiration requires redis-py >= 5.1.0 and Redis 7.4+"
)

db = self.db()
key = self.key()
result = await db.hpersist(key, field_name)
# hpersist returns a list of results, one per field
return result[0] if result else -2


class JsonModel(RedisModel, abc.ABC):
def __init_subclass__(cls, **kwargs):
Expand Down
Loading
Loading