docs/OperationsAPI.yaml (6 additions, 0 deletions)

@@ -27,6 +27,12 @@ paths:
       tags:
         - "operations"
       parameters:
+        - name: search_query
+          in: query
+          description: General search query to match against feed stable id, feed name and feed provider.
+          required: False
+          schema:
+            type: string
        - name: operation_status
          in: query
          description: Filter feeds by operational status.
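To illustrate the new parameter, a client call might look like the sketch below. The host and endpoint path are assumptions (the diff only shows the parameter definitions); the query parameter names and the "total" field in the response come from this PR.

# Hypothetical client call; host and path are assumed, not taken from the spec.
import requests

resp = requests.get(
    "https://operations-api.example.com/feeds",  # assumed endpoint path
    params={"search_query": "provider B", "operation_status": "published"},
)
resp.raise_for_status()
print(resp.json().get("total"))  # GetFeeds200Response carries total/offset/limit/feeds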
functions-python/helpers/query_helper.py (15 additions, 3 deletions)

@@ -3,7 +3,7 @@
 from datetime import datetime
 from typing import Type
 
-from sqlalchemy import and_, func
+from sqlalchemy import and_, func, or_
 from sqlalchemy.orm import Session, joinedload
 from sqlalchemy.orm.query import Query
 
@@ -75,6 +75,7 @@ def get_eager_loading_options(model: Type[Feed]):
 
 def get_feeds_query(
     db_session: Session,
+    search_query: str | None = None,
     operation_status: str | None = None,
     data_type: str | None = None,
     limit: int | None = None,
@@ -86,6 +87,7 @@
 
     Args:
         db_session: SQLAlchemy session
+        search_query: Optional general search query
         operation_status: Optional filter for operational status (wip or published)
         data_type: Optional filter for feed type (gtfs or gtfs_rt)
         limit: Maximum number of items to return
@@ -103,17 +105,27 @@
     )
     conditions = []
 
-    if data_type is None:
+    if data_type is None or len(data_type.strip()) == 0:
         conditions.append(model.data_type.in_(["gtfs", "gtfs_rt"]))
         logging.info("Added filter to exclude gbfs feeds")
     else:
         conditions.append(model.data_type == data_type)
         logging.info("Added data_type filter: %s", data_type)
 
-    if operation_status:
+    if operation_status and operation_status.strip():
         conditions.append(model.operational_status == operation_status)
         logging.info("Added operational_status filter: %s", operation_status)
 
+    if search_query and search_query.strip():
+        search_pattern = f"%{search_query.strip()}%"
+        conditions.append(
+            or_(
+                model.stable_id.ilike(search_pattern),
+                model.feed_name.ilike(search_pattern),
+                model.provider.ilike(search_pattern),
+            )
+        )
+        logging.info("Added search_query filter: %s", search_query)
     query = db_session.query(model)
     logging.info("Created base query with model %s", model.__name__)
 
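As a usage sketch of the updated helper: the parameter names match the diff above, while the import path, session handling, and model argument are assumptions. Since ilike() is case-insensitive and the helper wraps the stripped term in %...%, a query like " Provider B " matches stable ids, feed names, or providers containing "provider b".

# Hypothetical usage; the module path below is an assumption.
from helpers.query_helper import get_feeds_query

def find_published_gtfs_feeds(db_session, feed_model, text):
    # Builds the filtered query and executes it; feed_model stands in for the
    # Feed/Gtfsfeed ORM class passed as the model argument.
    query = get_feeds_query(
        db_session=db_session,
        search_query=text,
        operation_status="published",
        data_type="gtfs",
        limit=10,
        offset=0,
        model=feed_model,
    )
    return query.all()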
functions-python/operations_api/function_config.json (2 additions, 2 deletions)

@@ -18,8 +18,8 @@
     }
   ],
   "ingress_settings": "ALLOW_ALL",
-  "max_instance_request_concurrency": 1,
-  "max_instance_count": 5,
+  "max_instance_request_concurrency": 100,
+  "max_instance_count": 10,
   "min_instance_count": 0,
   "available_cpu": 1,
   "build_settings": {
(file name not shown)
@@ -109,8 +109,9 @@ def assert_no_existing_feed_url(producer_url: str, db_session: Session):
         )
 
     @with_db_session
-    async def get_feeds(
+    def handle_get_feeds(

    [Inline comment from the PR author on this line:]
    This is necessary because the with_db_session annotation doesn't properly close sessions when applied to async functions. I'll update the other endpoints in a separate PR to avoid polluting this one with "unrelated" changes.

         self,
+        search_query: Optional[str] = None,
         operation_status: Optional[str] = None,
         data_type: Optional[str] = None,
         offset: str = "0",
@@ -122,8 +123,21 @@
             limit_int = int(limit) if limit else 50
             offset_int = int(offset) if offset else 0
 
+            # filtered but unpaginated for total
+            total_query = get_feeds_query(
+                db_session=db_session,
+                search_query=search_query,
+                operation_status=operation_status,
+                data_type=data_type,
+                limit=None,
+                offset=None,
+                model=Feed,
+            )
+            total = total_query.count()
+
             query = get_feeds_query(
                 db_session=db_session,
+                search_query=search_query,
                 operation_status=operation_status,
                 data_type=data_type,
                 limit=limit_int,
@@ -133,14 +147,10 @@
 
             logging.info("Executing query with data_type: %s", data_type)
 
-            total = query.count()
             feeds = query.all()
             logging.info("Retrieved %d feeds from database", len(feeds))
 
-            feed_list = []
-            for feed in feeds:
-                feed_list.append(OperationFeedImpl.from_orm(feed))
-
+            feed_list = [OperationFeedImpl.from_orm(feed) for feed in feeds]
             response = GetFeeds200Response(
                 total=total, offset=offset_int, limit=limit_int, feeds=feed_list
             )
@@ -153,6 +163,20 @@
                 status_code=500, detail=f"Internal server error: {str(e)}"
             )
 
+    async def get_feeds(
+        self,
+        search_query: Optional[str] = None,
+        operation_status: Optional[str] = None,
+        data_type: Optional[str] = None,
+        offset: str = "0",
+        limit: str = "50",
+        db_session: Session = None,
+    ) -> GetFeeds200Response:
+        """Get a list of feeds with optional filtering and pagination."""
+        return self.handle_get_feeds(
+            search_query, operation_status, data_type, offset, limit
+        )
+
     @with_db_session
     async def get_gtfs_feed(
         self,
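To make the author's comment on handle_get_feeds concrete, here is a minimal sketch of one way a synchronous session-managing decorator can mismanage the session lifecycle on an async function: calling the wrapped coroutine function only builds a coroutine, so the wrapper's cleanup runs before the endpoint body ever executes. This is illustrative only and is not the project's actual with_db_session implementation.

# Illustrative only; the real with_db_session decorator is assumed to differ in detail.
import asyncio

def with_db_session_sketch(func):
    def wrapper(*args, **kwargs):
        session = "open-session"          # stand-in for creating a SQLAlchemy session
        try:
            # If func is an async def, this merely constructs a coroutine;
            # none of the endpoint body has run yet.
            return func(*args, db_session=session, **kwargs)
        finally:
            print("closing session")      # stand-in for session.close()
    return wrapper

@with_db_session_sketch
async def endpoint(db_session=None):
    # Runs only when the coroutine is awaited, after the wrapper has already
    # executed its finally block above.
    print("endpoint body runs, session is", db_session)

asyncio.run(endpoint())
# Output order: "closing session" first, then the endpoint body.

In the diff, the change sidesteps this by keeping the decorated function synchronous (handle_get_feeds) and having the async get_feeds delegate to it.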
functions-python/operations_api/tests/conftest.py (5 additions, 3 deletions)

@@ -38,8 +38,9 @@
     stable_id="mdb-41",
     status="active",
     feed_contact_email="feed_contact_email",
-    provider="provider",
+    provider="provider A",
     entitytypes=[Entitytype(name="vp")],
+    operational_status="published",
 )
 
 feed_mdb_40 = Gtfsfeed(
@@ -56,7 +57,7 @@
     stable_id="mdb-40",
     status="active",
     feed_contact_email="feed_contact_email",
-    provider="provider",
+    provider="provider B",
     gtfs_rt_feeds=[feed_mdb_41],
     operational_status="wip",
@@ -74,8 +75,9 @@
     stable_id="mdb-400",
     status="active",
     feed_contact_email="feed_contact_email",
-    provider="provider",
+    provider="provider C",
     gtfs_rt_feeds=[],
+    operational_status="published",
 )
 
 # Test license objects used by LicensesApiImpl tests
(file name not shown)
@@ -87,21 +87,21 @@ async def test_get_feeds_pagination():
     api = OperationsApiImpl()
 
     response = await api.get_feeds(limit=1)
-    assert response.total == 1
+    assert response.total == 3
     assert response.limit == 1
     assert response.offset == 0
     assert len(response.feeds) == 1
     first_feed = response.feeds[0]
 
     response = await api.get_feeds(offset=1, limit=1)
-    assert response.total == 1
+    assert response.total == 3
     assert response.limit == 1
     assert response.offset == 1
     assert len(response.feeds) == 1
     assert response.feeds[0].stable_id != first_feed.stable_id
 
     response = await api.get_feeds(offset=3)
-    assert response.total == 0
+    assert response.total == 3
     assert response.limit == 50
     assert response.offset == 3
     assert len(response.feeds) == 0
@@ -149,26 +149,20 @@ async def test_get_feeds_combined_filters():
 
     base_response = await api.get_feeds()
     assert base_response is not None
-    print(f"\nTotal feeds in database: {len(base_response.feeds)}")
 
     gtfs_response = await api.get_feeds(data_type="gtfs")
     assert gtfs_response is not None
-    print(f"Total GTFS feeds: {len(gtfs_response.feeds)}")
-    for feed in gtfs_response.feeds:
-        print(f"GTFS Feed: {feed.stable_id}, status: {feed.operational_status}")
 
     wip_response = await api.get_feeds(operation_status="wip")
     assert wip_response is not None
-    print(f"Total WIP feeds: {len(wip_response.feeds)}")
-    for feed in wip_response.feeds:
-        print(f"WIP Feed: {feed.stable_id}, type: {feed.data_type}")
 
-    response = await api.get_feeds(data_type="gtfs", operation_status="wip")
+    response = await api.get_feeds(data_type="gtfs", operation_status="published")
     assert response is not None
     wip_gtfs_feeds = response.feeds
-    print(f"Total WIP GTFS feeds: {len(wip_gtfs_feeds)}")
 
-    assert len(wip_gtfs_feeds) == 0
+    assert len(wip_gtfs_feeds) == 1
+    assert wip_gtfs_feeds[0].data_type == "gtfs"
+    assert wip_gtfs_feeds[0].operational_status == "published"
 
     response = await api.get_feeds(data_type="gtfs", limit=1, offset=1)
     assert response is not None
@@ -231,3 +225,35 @@ async def test_get_feeds_unpublished_with_data_type():
     for feed in rt_response.feeds:
         assert feed.operational_status == "unpublished"
         assert feed.data_type == "gtfs_rt"
+
+
+@pytest.mark.asyncio
+async def test_get_feeds_search_query():
+    """
+    Test get_feeds endpoint with search query filter.
+    Should return only feeds matching the search query.
+    """
+    api = OperationsApiImpl()
+
+    response = await api.get_feeds(search_query="RT")
+    assert response is not None
+    assert response.total == 1
+    assert len(response.feeds) == 1
+    assert response.feeds[0].feed_name == "London Transit Commission(RT"
+
+    response = await api.get_feeds(search_query=" Provider B ")
+    assert response is not None
+    assert response.total == 1
+    assert len(response.feeds) == 1
+    assert response.feeds[0].provider == "provider B"
+
+    response = await api.get_feeds(search_query="mdb-41")
+    assert response is not None
+    assert response.total == 1
+    assert len(response.feeds) == 1
+    assert response.feeds[0].stable_id == "mdb-41"
+
+    response = await api.get_feeds(search_query="mdb")
+    assert response is not None
+    assert response.total == 3
+    assert len(response.feeds) == 3