Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/dispatch/case/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def get_cases(
expand: bool = Query(default=False),
):
"""Retrieves all cases."""
common["include_keys"] = include
pagination = search_filter_sort_paginate(model="Case", **common)

if expand:
Expand Down
33 changes: 10 additions & 23 deletions src/dispatch/database/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ def apply_filters(query, filter_spec, model_cls=None, do_auto_join=True):
return query


def get_model_map(filters: dict) -> dict:
def apply_filter_specific_joins(model: Base, filter_spec: dict, query: orm.query):
"""Applies any model specific implicitly joins."""
# this is required because by default sqlalchemy-filter's auto-join
# knows nothing about how to join many-many relationships.
model_map = {
Expand Down Expand Up @@ -370,21 +371,19 @@ def get_model_map(filters: dict) -> dict:
(SignalInstance, "EntityType"): (SignalInstance.entities, True),
(Tag, "TagType"): (Tag.tag_type, False),
}
filters = build_filters(filter_spec)

# Replace mapping if looking for commander
if "Commander" in filters:
if "Commander" in str(filter_spec):
model_map.update({(Incident, "IndividualContact"): (Incident.commander, True)})
if "Assignee" in filters:
if "Assignee" in str(filter_spec):
model_map.update({(Case, "IndividualContact"): (Case.assignee, True)})
return model_map


def apply_model_specific_joins(model: Base, models: List[str], query: orm.query):
model_map = get_model_map(models)
filter_models = get_named_models(filters)
joined_models = []

for include_model in models:
if model_map.get((model, include_model)):
joined_model, is_outer = model_map[(model, include_model)]
for filter_model in filter_models:
if model_map.get((model, filter_model)):
joined_model, is_outer = model_map[(model, filter_model)]
try:
if joined_model not in joined_models:
query = query.join(joined_model, isouter=is_outer)
Expand All @@ -395,14 +394,6 @@ def apply_model_specific_joins(model: Base, models: List[str], query: orm.query)
return query


def apply_filter_specific_joins(model: Base, filter_spec: dict, query: orm.query):
"""Applies any model specific implicitly joins."""
filters = build_filters(filter_spec)
filter_models = get_named_models(filters)

return apply_model_specific_joins(model, filter_models, query)


def composite_search(*, db_session, query_str: str, models: List[Base], current_user: DispatchUser):
"""Perform a multi-table search based on the supplied query."""
s = CompositeSearch(db_session, models)
Expand Down Expand Up @@ -546,7 +537,6 @@ def search_filter_sort_paginate(
model,
query_str: str = None,
filter_spec: str | dict | None = None,
include_keys: List[str] = None,
page: int = 1,
items_per_page: int = 5,
sort_by: List[str] = None,
Expand Down Expand Up @@ -584,9 +574,6 @@ def search_filter_sort_paginate(
else:
query = apply_filters(query, filter_spec, model_cls)

if include_keys:
query = apply_model_specific_joins(model_cls, include_keys, query)

if model == "Incident":
query = query.intersect(query_restricted)
for filter in tag_all_filters:
Expand Down
26 changes: 0 additions & 26 deletions tests/database/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,32 +154,6 @@ def test_role_based_filtering(session, incidents, user, admin_user):
assert len(admin_result["items"]) >= len(member_result["items"])


def test_include_keys_functionality(session, case, admin_user):
"""Test functionality of include_keys parameter."""
from dispatch.common.utils.views import create_pydantic_include
from dispatch.case.models import CasePagination

result = search_filter_sort_paginate(
db_session=session,
model="Case",
include_keys=["tags"],
current_user=admin_user,
role=UserRoles.admin,
)

# make sure they are renderable
include_sets = create_pydantic_include(["tags", "title"])

include_fields = {
"items": {"__all__": include_sets},
"itemsPerPage": ...,
"page": ...,
"total": ...,
}
marshalled = json.loads(CasePagination(**result).json(include=include_fields))
assert "tags" in marshalled["items"][0].keys()


# Test restricted filters
def test_restricted_incident_filter_member(session, user):
"""Tests incident filtering for member role."""
Expand Down
Loading