Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/test_prefetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,32 @@ async def test_prefetch_m2m_to_attr(db):
assert list(event.to_attr_participants_2) == [team_second]


@pytest.mark.asyncio
async def test_prefetch_m2m_annotate(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await event.participants.add(team)
event = await Event.first().prefetch_related(
Prefetch("participants", Team.annotate(count_events=Count("events")))
)
for team in event.participants:
assert team.count_events == 1


@pytest.mark.asyncio
async def test_prefetch_m2m_select_related(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await team.events.add(event)
team = await Team.first().prefetch_related(
Prefetch("events", Event.all().select_related("tournament"))
)
for event in team.events:
assert event.tournament == tournament


@pytest.mark.asyncio
async def test_prefetch_o2o_to_attr(db):
tournament = await Tournament.create(name="tournament")
Expand Down
109 changes: 38 additions & 71 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from copy import copy
from typing import TYPE_CHECKING, Any, cast

from pypika_tortoise import JoinType, Parameter, Table
from pypika_tortoise.queries import QueryBuilder
from pypika_tortoise.queries import QueryBuilder, Table
from pypika_tortoise.terms import Parameter

from tortoise.exceptions import OperationalError, UnSupportedError
from tortoise.expressions import Expression, ResolveContext
Expand All @@ -19,7 +19,6 @@
ManyToManyFieldInstance,
RelationalField,
)
from tortoise.query_utils import QueryModifier

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.backends.base.client import BaseDBAsyncClient
Expand Down Expand Up @@ -602,85 +601,53 @@ async def _prefetch_m2m_relation(
field: str,
related_query: tuple[str | None, QuerySet],
) -> Iterable[Model]:
to_attr, related_query = related_query
instance_id_set: set = {
instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list
}
to_attr, queryset = related_query

field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore

through_table = Table(field_object.through, schema=field_object.through_schema)
model_pk = self.model._meta.pk
instance_pks = [model_pk.to_db_value(instance.pk, instance) for instance in instance_list]

subquery = (
self.db.query_class.from_(through_table)
.select(
through_table[field_object.backward_key].as_("_backward_relation_key"),
through_table[field_object.forward_key].as_("_forward_relation_key"),
)
.where(through_table[field_object.backward_key].isin(instance_id_set))
related_objects = await queryset.filter(
**{f"{field_object.related_name}__in": instance_pks}
)

related_query_table = related_query.model._meta.basetable
related_pk_field = related_query.model._meta.db_pk_column
related_query.resolve_ordering(related_query.model, related_query_table, [], {})
query = (
related_query.query.join(subquery)
.on(subquery._forward_relation_key == related_query_table[related_pk_field])
.select(
subquery._backward_relation_key.as_("_backward_relation_key"),
*[related_query_table[field].as_(field) for field in related_query.fields],
relation_map: dict = {}
if related_objects:
related_pk_map: dict = {obj.pk: obj for obj in related_objects}
related_model_pk = queryset.model._meta.pk
related_pks = [related_model_pk.to_db_value(pk, None) for pk in related_pk_map]
through_table = Table(field_object.through, schema=field_object.through_schema)
backward_field = through_table[field_object.backward_key]
forward_field = through_table[field_object.forward_key]

_, (_, through_rows) = await asyncio.gather(
self.__class__(
model=queryset.model, db=self.db, prefetch_map=queryset._prefetch_map
)._execute_prefetch_queries(related_objects),
self.db.execute_query(
*(
self.db.query_class.from_(through_table)
.select(backward_field, forward_field)
.where(backward_field.isin(instance_pks))
.where(forward_field.isin(related_pks))
.get_parameterized_sql()
)
),
)
)

if related_query._q_objects:
joined_tables: list[Table] = []
modifier = QueryModifier()
for node in related_query._q_objects:
modifier &= node.resolve(
ResolveContext(
model=related_query.model,
table=related_query_table,
annotations=related_query._annotations,
custom_filters=related_query._custom_filters,
)
for row in through_rows:
backward_key_value = model_pk.to_python_value(row[field_object.backward_key])
related_object = related_pk_map.get(
related_model_pk.to_python_value(row[field_object.forward_key])
)

for join in modifier.joins:
if join[0] not in joined_tables:
query = query.join(join[0], how=JoinType.left_outer).on(join[1])
joined_tables.append(join[0])

if modifier.where_criterion:
query = query.where(modifier.where_criterion)

if modifier.having_criterion:
query = query.having(modifier.having_criterion)

_, raw_results = await self.db.execute_query(*query.get_parameterized_sql())
relations: list[tuple[Any, Any]] = []
related_object_list: list[Model] = []
model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk
for e in raw_results:
pk_values: tuple[Any, Any] = (
model_pk.to_python_value(e["_backward_relation_key"]),
related_pk.to_python_value(e[related_pk_field]),
)
relations.append(pk_values)
related_object_list.append(related_query.model._init_from_db(**e))
await self.__class__(
model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map
)._execute_prefetch_queries(related_object_list)
related_object_map = {e.pk: e for e in related_object_list}
relation_map: dict[str, list] = {}

for object_id, related_object_id in relations:
if object_id not in relation_map:
relation_map[object_id] = []
relation_map[object_id].append(related_object_map[related_object_id])
if related_object is not None:
relation_map.setdefault(backward_key_value, []).append(related_object)

for instance in instance_list:
relation_container = getattr(instance, field)
relation_container._set_result_for_query(relation_map.get(instance.pk, []), to_attr)
getattr(instance, field)._set_result_for_query(
relation_map.get(instance.pk, []), to_attr
)
return instance_list

async def _prefetch_direct_relation(
Expand Down
Loading