Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 93 additions & 45 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@
import json
import logging
import sys
import uuid
from datetime import datetime
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union

from marshmallow import fields
from ...utils.uuid_utils import uuid4

from ...cache.base import BaseCache
from ...config.settings import BaseSettings
from ...core.profile import ProfileSession
from ...storage.base import BaseStorage, StorageDuplicateError, StorageNotFoundError
from ...storage.base import (
DEFAULT_PAGE_SIZE,
BaseStorage,
StorageDuplicateError,
StorageNotFoundError,
)
from ...storage.record import StorageRecord
from ..util import datetime_to_str, time_now
from ..valid import INDY_ISO8601_DATETIME_EXAMPLE, INDY_ISO8601_DATETIME_VALIDATE
from ..valid import (
INDY_ISO8601_DATETIME_EXAMPLE,
INDY_ISO8601_DATETIME_VALIDATE,
)
from .base import BaseModel, BaseModelError, BaseModelSchema

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -46,8 +54,7 @@ def match_post_filter(
return (
positive
and all(
record.get(k) and record.get(k) in alts
for k, alts in post_filter.items()
record.get(k) and record.get(k) in alts for k, alts in post_filter.items()
)
) or (
(not positive)
Expand Down Expand Up @@ -224,11 +231,12 @@ async def retrieve_by_id(
Args:
session: The profile session to use
record_id: The ID of the record to find
for_update: Whether to lock the record for update
"""

storage = session.inject(BaseStorage)
result = await storage.get_record(
cls.RECORD_TYPE, record_id, {"forUpdate": for_update, "retrieveTags": False}
cls.RECORD_TYPE, record_id, options={"forUpdate": for_update}
)
vals = json.loads(result.value)
return cls.from_storage(record_id, vals)
Expand All @@ -238,24 +246,26 @@ async def retrieve_by_tag_filter(
cls: Type[RecordType],
session: ProfileSession,
tag_filter: dict,
post_filter: dict = None,
post_filter: Optional[dict] = None,
*,
for_update=False,
) -> RecordType:
"""Retrieve a record by tag filter.

Args:
cls: The record class
session: The profile session to use
tag_filter: The filter dictionary to apply
post_filter: Additional value filters to apply matching positively,
with sequence values specifying alternatives to match (hit any)
for_update: Whether to lock the record for update
"""

storage = session.inject(BaseStorage)
rows = await storage.find_all_records(
cls.RECORD_TYPE,
cls.prefix_tag_filter(tag_filter),
options={"forUpdate": for_update, "retrieveTags": False},
options={"forUpdate": for_update},
)
found = None
for record in rows:
Expand All @@ -282,65 +292,107 @@ async def retrieve_by_tag_filter(
async def query(
cls: Type[RecordType],
session: ProfileSession,
tag_filter: dict = None,
tag_filter: Optional[dict] = None,
*,
post_filter_positive: dict = None,
post_filter_negative: dict = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
order_by: Optional[str] = None,
descending: bool = False,
post_filter_positive: Optional[dict] = None,
post_filter_negative: Optional[dict] = None,
alt: bool = False,
) -> Sequence[RecordType]:
"""Query stored records.

Args:
session: The profile session to use
tag_filter: An optional dictionary of tag filter clauses
limit: The maximum number of records to retrieve
offset: The offset to start retrieving records from
order_by: An optional field by which to order the records.
descending: Whether to order the records in descending order.
post_filter_positive: Additional value filters to apply matching positively
post_filter_negative: Additional value filters to apply matching negatively
alt: set to match any (positive=True) value or miss all (positive=False)
values in post_filter
"""

storage = session.inject(BaseStorage)
rows = await storage.find_all_records(
cls.RECORD_TYPE,
cls.prefix_tag_filter(tag_filter),
options={"retrieveTags": False},
)

tag_query = cls.prefix_tag_filter(tag_filter)
post_filter = post_filter_positive or post_filter_negative

# set flag to indicate if pagination is requested or not, then set defaults
paginated = limit is not None or offset is not None
limit = limit or DEFAULT_PAGE_SIZE
offset = offset or 0

if not post_filter and paginated:
# Only fetch paginated records if post-filter is not being applied
rows = await storage.find_paginated_records(
type_filter=cls.RECORD_TYPE,
tag_query=tag_query,
limit=limit,
offset=offset,
order_by=order_by,
descending=descending,
)
else:
rows = await storage.find_all_records(
type_filter=cls.RECORD_TYPE,
tag_query=tag_query,
order_by=order_by,
descending=descending,
)

num_results_post_filter = 0 # used if applying pagination post-filter
num_records_to_match = limit + offset # ignored if not paginated

result = []
for record in rows:
vals = json.loads(record.value)
if match_post_filter(
vals,
post_filter_positive,
positive=True,
alt=alt,
) and match_post_filter(
vals,
post_filter_negative,
positive=False,
alt=alt,
):
try:
try:
vals = json.loads(record.value)
if not post_filter: # pagination would already be applied if requested
result.append(cls.from_storage(record.id, vals))
except BaseModelError as err:
raise BaseModelError(f"{err}, for record id {record.id}")
else:
continue_processing = (
not paginated or num_results_post_filter < num_records_to_match
)
if not continue_processing:
break

post_filter_match = match_post_filter(
vals, post_filter_positive, positive=True, alt=alt
) and match_post_filter(
vals, post_filter_negative, positive=False, alt=alt
)

if not post_filter_match:
continue

if num_results_post_filter >= offset: # append only after offset
result.append(cls.from_storage(record.id, vals))

num_results_post_filter += 1
except (BaseModelError, json.JSONDecodeError, TypeError) as err:
raise BaseModelError(f"{err}, for record id {record.id}")
return result

async def save(
self,
session: ProfileSession,
*,
reason: str = None,
reason: Optional[str] = None,
log_params: Mapping[str, Any] = None,
log_override: bool = False,
event: bool = None,
event: Optional[bool] = None,
) -> str:
"""Persist the record to storage.

Args:
session: The profile session to use
reason: A reason to add to the log
log_params: Additional parameters to log
override: Override configured logging regimen, print to stderr instead
log_override: Override configured logging regimen, print to stderr instead
event: Flag to override whether the event is sent
"""

Expand All @@ -355,7 +407,7 @@ async def save(
new_record = False
else:
if not self._id:
self._id = str(uuid.uuid4())
self._id = str(uuid4())
self.created_at = self.updated_at
await storage.add_record(self.storage_record)
new_record = True
Expand All @@ -380,7 +432,7 @@ async def post_save(
session: ProfileSession,
new_record: bool,
last_state: Optional[str],
event: bool = None,
event: Optional[bool] = None,
):
"""Perform post-save actions.

Expand Down Expand Up @@ -411,7 +463,7 @@ async def delete_record(self, session: ProfileSession):
await self.emit_event(session, self.serialize())
await storage.delete_record(self.storage_record)

async def emit_event(self, session: ProfileSession, payload: Any = None):
async def emit_event(self, session: ProfileSession, payload: Optional[Any] = None):
"""Emit an event.

Args:
Expand All @@ -436,12 +488,11 @@ async def emit_event(self, session: ProfileSession, payload: Any = None):
def log_state(
cls,
msg: str,
params: dict = None,
settings: BaseSettings = None,
params: Optional[dict] = None,
settings: Optional[BaseSettings] = None,
override: bool = False,
):
"""Print a message with increased visibility (for testing)."""

if override or (
cls.LOG_STATE_FLAG and settings and settings.get(cls.LOG_STATE_FLAG)
):
Expand All @@ -454,10 +505,7 @@ def log_state(
@classmethod
def strip_tag_prefix(cls, tags: dict):
"""Strip tilde from unencrypted tag names."""

return (
{(k[1:] if "~" in k else k): v for (k, v) in tags.items()} if tags else {}
)
return {(k[1:] if "~" in k else k): v for (k, v) in tags.items()} if tags else {}

@classmethod
def prefix_tag_filter(cls, tag_filter: dict):
Expand Down
Loading
Loading