Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 8 additions & 64 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from typing import List, Union, TypeVar, Optional

from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload, contains_eager, selectinload, Session
from sqlalchemy.orm import contains_eager, selectinload, Session
from sqlalchemy.orm.query import Query

from feeds.impl.datasets_api_impl import DatasetsApiImpl
from feeds.impl.error_handling import raise_http_error, raise_http_validation_error, convert_exception
from shared.db_models.entity_type_enum import EntityType
from shared.db_models.feed_impl import FeedImpl
from shared.db_models.gbfs_feed_impl import GbfsFeedImpl
from shared.db_models.gtfs_feed_impl import GtfsFeedImpl
Expand All @@ -23,7 +21,7 @@
from shared.common.db_utils import (
get_gtfs_feeds_query,
get_gtfs_rt_feeds_query,
get_joinedload_options,
get_selectinload_options,
add_official_filter,
get_gbfs_feeds_query,
)
Expand All @@ -41,13 +39,10 @@
Gtfsdataset,
Gtfsfeed,
Gtfsrealtimefeed,
Location,
Entitytype,
)
from shared.feed_filters.feed_filter import FeedFilter
from shared.feed_filters.gtfs_dataset_filter import GtfsDatasetFilter
from shared.feed_filters.gtfs_feed_filter import LocationFilter
from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter
from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter
from utils.date_utils import valid_iso_date
from utils.logger import get_logger

Expand Down Expand Up @@ -120,7 +115,7 @@ def get_feeds(
# Results are sorted by provider
feed_query = feed_query.order_by(FeedOrm.provider, FeedOrm.stable_id)
# Ensure license relationship is available to the model conversion without extra queries
feed_query = feed_query.options(*get_joinedload_options(), selectinload(FeedOrm.license))
feed_query = feed_query.options(*get_selectinload_options(), selectinload(FeedOrm.license))
if limit is not None:
feed_query = feed_query.limit(limit)
if offset is not None:
Expand Down Expand Up @@ -251,11 +246,10 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.outerjoin(Location, Gtfsrealtimefeed.locations)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(),
selectinload(Gtfsrealtimefeed.entitytypes),
selectinload(Gtfsrealtimefeed.gtfs_feeds),
*get_selectinload_options(),
)
).all()

Expand Down Expand Up @@ -299,61 +293,11 @@ def get_gtfs_rt_feeds(

return self._get_response(feed_query, GtfsRTFeedImpl)

entity_types_list = entity_types.split(",") if entity_types else None
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

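Not related to this PR, but all of the code below was dead: it sits after the unconditional `return self._get_response(feed_query, GtfsRTFeedImpl)` above and could never run, so I removed it.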


# Validate entity types using the EntityType enum
if entity_types_list:
try:
entity_types_list = [EntityType(et.strip()).value for et in entity_types_list]
except ValueError:
raise_http_validation_error(
"Entity types must be the value 'vp,' 'sa,' or 'tu,'. "
"When provided a list values must be separated by commas."
)

gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=None,
provider__ilike=provider,
producer_url__ilike=producer_url,
entity_types=EntityTypeFilter(name__in=entity_types_list),
location=LocationFilter(
country_code=country_code,
subdivision_name__ilike=subdivision_name,
municipality__ilike=municipality,
),
)
subquery = gtfs_rt_feed_filter.filter(
select(Gtfsrealtimefeed.id)
.join(Location, Gtfsrealtimefeed.locations)
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
).subquery()
feed_query = (
db_session.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.id.in_(subquery))
.filter(
or_(
Gtfsrealtimefeed.operational_status == "published",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(),
)
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
)
feed_query = add_official_filter(feed_query, is_official)

feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsRTFeedImpl)

@staticmethod
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
"""Get the response for the feed query."""
results = feed_query.all()
response = [impl_cls.from_orm(feed) for feed in results]
return list({feed.id: feed for feed in response}.values())
return [impl_cls.from_orm(feed) for feed in results]
Comment on lines -355 to +300
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

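This reduces memory use within the function: it no longer builds an intermediate dict keyed by feed id just to de-duplicate the results, and instead returns the converted list directly.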


@with_db_session
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
Expand Down
20 changes: 12 additions & 8 deletions api/src/shared/common/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,15 @@ def get_gtfs_feeds_query(

if include_options_for_joinedload:
feed_query = feed_query.options(
joinedload(Gtfsfeed.latest_dataset)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
joinedload(Gtfsfeed.visualization_dataset),
*get_joinedload_options(),
# Use selectinload for all collection relationships to avoid a cartesian-product row
# explosion when multiple one-to-many associations are loaded simultaneously.
# joinedload on collections multiplies rows (N feeds × M locations × F features …);
# selectinload issues a separate IN-query per relationship, keeping rows at N per query.
selectinload(Gtfsfeed.latest_dataset)
.selectinload(Gtfsdataset.validation_reports)
.selectinload(Validationreport.features),
joinedload(Gtfsfeed.visualization_dataset), # scalar (many-to-one) — joinedload is safe
*get_selectinload_options(),
).order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)

feed_query = feed_query.limit(limit).offset(offset)
Expand Down Expand Up @@ -274,9 +278,9 @@ def get_gtfs_rt_feeds_query(
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published")

feed_query = feed_query.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(),
selectinload(Gtfsrealtimefeed.entitytypes),
selectinload(Gtfsrealtimefeed.gtfs_feeds),
*get_selectinload_options(),
)
feed_query = add_official_filter(feed_query, is_official)

Expand Down
Loading