Skip to content

Commit d972466

Browse files
Add query support for multiple collections
1 parent c4330c2 commit d972466

14 files changed

Lines changed: 530 additions & 167 deletions

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- `tilebox-datasets`: Added dataset-level `find` and `query` methods on both sync and async `DatasetClient` to query
13+
across multiple collections.
14+
15+
1016
## [0.49.0] - 2026-02-19
1117

1218
### Added

tilebox-datasets/tests/test_timeseries.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,68 @@ def test_timeseries_dataset_collection_find_not_found() -> None:
198198
mocked.collection.find("14eb91a2-a42f-421f-9397-1dab577f05a9")
199199

200200

201+
@settings(max_examples=1)
202+
@given(example_datapoints(generated_fields=True, missing_fields=True))
203+
def test_timeseries_dataset_find_multiple_collections(expected_datapoint: ExampleDatapoint) -> None:
204+
"""Test that DatasetClient.find() supports querying by mixed collection reference types."""
205+
dataset, service = _mocked_dataset()
206+
207+
named_collection = CollectionInfo(Collection(uuid4(), "named-collection"), None, None)
208+
other_collection = CollectionInfo(Collection(uuid4(), "other-collection"), None, None)
209+
210+
service.get_collections.return_value = Promise.resolve([named_collection, other_collection])
211+
message = AnyMessage(example_dataset_type_url(), expected_datapoint.SerializeToString())
212+
service.query_by_id.return_value = Promise.resolve(message)
213+
214+
datapoint_id = uuid_message_to_uuid(expected_datapoint.id)
215+
datapoint = dataset.find(
216+
datapoint_id,
217+
[
218+
named_collection.collection.name,
219+
],
220+
)
221+
222+
assert isinstance(datapoint, xr.Dataset)
223+
service.get_collections.assert_called_once_with(dataset._dataset.id, True, True)
224+
service.query_by_id.assert_called_once_with(
225+
dataset._dataset.id,
226+
[
227+
named_collection.collection.id,
228+
],
229+
datapoint_id,
230+
False,
231+
)
232+
233+
234+
@settings(max_examples=1)
235+
@given(pages=paginated_query_results())
236+
def test_timeseries_dataset_query_multiple_collections(pages: list[QueryResultPage]) -> None:
237+
"""Test that DatasetClient.query() forwards all selected collection ids to the backend query endpoint."""
238+
dataset, service = _mocked_dataset()
239+
240+
named_collection = CollectionInfo(Collection(uuid4(), "named-collection"), None, None)
241+
other_collection = CollectionInfo(Collection(uuid4(), "other-collection"), None, None)
242+
243+
service.get_collections.return_value = Promise.resolve([named_collection, other_collection])
244+
service.query.side_effect = [Promise.resolve(page) for page in pages]
245+
246+
interval = TimeInterval(datetime.now(), datetime.now() + timedelta(days=1))
247+
queried = dataset.query(
248+
collections=[
249+
named_collection.collection.name,
250+
],
251+
temporal_extent=interval,
252+
)
253+
254+
_assert_datapoints_match(queried, pages)
255+
service.get_collections.assert_called_once_with(dataset._dataset.id, True, True)
256+
first_call_args = service.query.call_args_list[0][0]
257+
assert first_call_args[0] == dataset._dataset.id
258+
assert first_call_args[1] == [
259+
named_collection.collection.id,
260+
]
261+
262+
201263
@patch("tilebox.datasets.sync.pagination.tqdm")
202264
@patch("tilebox.datasets.progress.tqdm")
203265
@settings(deadline=1000, max_examples=3) # increase deadline to 1s to not timeout because of the progress bar
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:5234a799245c656e37e16e249f6212b0c7a48021d9a2fbc7a672375d7354a57c
3-
size 10004
2+
oid sha256:7de1b3958bcc1c8aecd8170b71433153cf5b705d0b0c7ea1f424bd78b9e8f66f
3+
size 10056
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:ef04121a4e99f85a25513d0933823b1c271a367b575dece8959be080c85e456d
3-
size 903956
2+
oid sha256:a964dcaf44e3ef38283e4906269781d287e313e8ebd6834a3dda22ef76539e43
3+
size 904012
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:a5fee0adc442645fb623f35d133681f34003690f582d532f486ecd91f59cc67b
3-
size 873354
2+
oid sha256:5ffb02e44b4e9a312cdb3edc8881e5bcffaf21ac28d4df973e1f20ba882c7dfa
3+
size 875182
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:7428cc2c23ba142cdc148053f5edf9aa5d7f630cf0139c69c13701c516f78037
3-
size 8560
2+
oid sha256:da90443beb49a4e0d84fd06f08f0affab5ea703f4ce7e65585c59f3ce421a101
3+
size 8584

tilebox-datasets/tilebox/datasets/aio/dataset.py

Lines changed: 188 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from _tilebox.grpc.aio.producer_consumer import async_producer_consumer
1313
from _tilebox.grpc.error import ArgumentError, NotFoundError
1414
from tilebox.datasets.aio.pagination import with_progressbar, with_time_progress_callback, with_time_progressbar
15-
from tilebox.datasets.data.collection import CollectionInfo
15+
from tilebox.datasets.data.collection import Collection, CollectionInfo
1616
from tilebox.datasets.data.data_access import QueryFilters, SpatialFilter, SpatialFilterLike
1717
from tilebox.datasets.data.datapoint import QueryResultPage
1818
from tilebox.datasets.data.datasets import Dataset
@@ -139,6 +139,122 @@ async def delete_collection(self, collection: "str | UUID | CollectionClient") -
139139

140140
await self._service.delete_collection(self._dataset.id, collection_id)
141141

142+
async def find(
143+
self,
144+
datapoint_id: str | UUID,
145+
collections: "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | None" = None,
146+
skip_data: bool = False,
147+
) -> xr.Dataset:
148+
"""
149+
Find a specific datapoint in one of the specified collections by its id.
150+
151+
Args:
152+
datapoint_id: The id of the datapoint to find.
153+
collections: The collections to search in. Supports collection names, ids or collection objects.
154+
If not specified, all collections in the dataset are searched.
155+
skip_data: Whether to skip the actual data of the datapoint. If True, only
156+
datapoint metadata is returned.
157+
158+
Returns:
159+
The datapoint as an xarray dataset.
160+
"""
161+
collection_ids = await self._collection_ids(collections)
162+
try:
163+
datapoint = await self._service.query_by_id(
164+
self._dataset.id,
165+
collection_ids,
166+
as_uuid(datapoint_id),
167+
skip_data,
168+
)
169+
except ArgumentError:
170+
raise ValueError(f"Invalid datapoint id: {datapoint_id} is not a valid UUID") from None
171+
except NotFoundError:
172+
raise NotFoundError(f"No such datapoint {datapoint_id}") from None
173+
174+
message_type = get_message_type(datapoint.type_url)
175+
data = message_type.FromString(datapoint.value)
176+
177+
converter = MessageToXarrayConverter(initial_capacity=1)
178+
converter.convert(data)
179+
return converter.finalize("time", skip_empty_fields=skip_data).isel(time=0)
180+
181+
async def query(
182+
self,
183+
*,
184+
collections: "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | dict[str, CollectionClient] | None",
185+
temporal_extent: TimeIntervalLike,
186+
spatial_extent: SpatialFilterLike | None = None,
187+
skip_data: bool = False,
188+
show_progress: bool | ProgressCallback = False,
189+
) -> xr.Dataset:
190+
"""
191+
Query datapoints in the specified collections and temporal extent.
192+
193+
Args:
194+
collections: The collections to query in. Supports collection names, ids or collection objects.
195+
If not specified, all collections in the dataset are queried.
196+
temporal_extent: The temporal extent to query data for. (Required)
197+
spatial_extent: The spatial extent to query data in. (Optional)
198+
skip_data: Whether to skip the actual data of the datapoint. If True, only
199+
datapoint metadata is returned.
200+
show_progress: Whether to show a progress bar while loading the data.
201+
If a callable is specified it is used as callback to report progress percentages.
202+
203+
Returns:
204+
Matching datapoints in the given temporal and spatial extent as an xarray dataset.
205+
"""
206+
if temporal_extent is None:
207+
raise ValueError("A temporal_extent for your query must be specified")
208+
209+
collection_ids = await self._collection_ids(collections)
210+
pages = _iter_query_pages(
211+
self._service,
212+
self._dataset.id,
213+
collection_ids,
214+
temporal_extent,
215+
spatial_extent,
216+
skip_data,
217+
dataset_name=self.name,
218+
show_progress=show_progress,
219+
)
220+
return await _convert_to_dataset(pages, skip_empty_fields=skip_data)
221+
222+
async def _collection_id(self, collection: "UUID | Collection | CollectionInfo | CollectionClient") -> UUID:
223+
if isinstance(collection, CollectionClient):
224+
return collection._collection.id
225+
if isinstance(collection, CollectionInfo):
226+
return collection.collection.id
227+
if isinstance(collection, Collection):
228+
return collection.id
229+
return collection
230+
231+
async def _collection_ids(
232+
self,
233+
collections: "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | dict[str, CollectionClient] | None",
234+
) -> list[UUID]:
235+
if collections is None:
236+
return []
237+
238+
all_collections: list[CollectionInfo] = await self._service.get_collections(self._dataset.id, True, True)
239+
# find all valid collection names and ids
240+
collections_by_name = {c.collection.name: c.collection.id for c in all_collections}
241+
valid_collection_ids = {c.collection.id for c in all_collections}
242+
243+
collection_ids: list[UUID] = []
244+
for collection in collections:
245+
if isinstance(collection, str):
246+
try:
247+
collection_ids.append(collections_by_name[collection])
248+
except KeyError:
249+
raise ValueError(f"Collection {collection} not found in dataset {self.name}") from None
250+
else:
251+
collection_id = await self._collection_id(collection)
252+
if collection_id not in valid_collection_ids:
253+
raise ValueError(f"Collection {collection_id} is not part of the dataset {self.name}")
254+
collection_ids.append(collection_id)
255+
256+
return collection_ids
257+
142258
def __repr__(self) -> str:
143259
return f"{self.name} [Timeseries Dataset]: {self._dataset.summary}"
144260

@@ -221,7 +337,7 @@ async def find(self, datapoint_id: str | UUID, skip_data: bool = False) -> xr.Da
221337
"""
222338
try:
223339
datapoint = await self._dataset._service.query_by_id(
224-
[self._collection.id], as_uuid(datapoint_id), skip_data
340+
self._dataset._dataset.id, [self._collection.id], as_uuid(datapoint_id), skip_data
225341
)
226342
except ArgumentError:
227343
raise ValueError(f"Invalid datapoint id: {datapoint_id} is not a valid UUID") from None
@@ -259,8 +375,14 @@ async def _find_interval(
259375
filters = QueryFilters(temporal_extent=IDInterval.parse(datapoint_id_interval, end_inclusive=end_inclusive))
260376

261377
async def request(page: PaginationProtocol) -> QueryResultPage:
262-
query_page = Pagination(page.limit, page.starting_after)
263-
return await self._dataset._service.query([self._collection.id], filters, skip_data, query_page)
378+
return await _query_page(
379+
self._dataset._service,
380+
self._dataset._dataset.id,
381+
[self._collection.id],
382+
filters,
383+
skip_data,
384+
page,
385+
)
264386

265387
initial_page = Pagination()
266388
pages = paginated_request(request, initial_page)
@@ -350,7 +472,16 @@ async def query(
350472
if temporal_extent is None:
351473
raise ValueError("A temporal_extent for your query must be specified")
352474

353-
pages = self._iter_pages(temporal_extent, spatial_extent, skip_data, show_progress=show_progress)
475+
pages = _iter_query_pages(
476+
self._dataset._service,
477+
self._dataset._dataset.id,
478+
[self._collection.id],
479+
temporal_extent,
480+
spatial_extent,
481+
skip_data,
482+
dataset_name=self._dataset.name,
483+
show_progress=show_progress,
484+
)
354485
return await _convert_to_dataset(pages, skip_empty_fields=skip_data)
355486

356487
async def _iter_pages(
@@ -361,29 +492,19 @@ async def _iter_pages(
361492
show_progress: bool | ProgressCallback = False,
362493
page_size: int | None = None,
363494
) -> AsyncIterator[QueryResultPage]:
364-
time_interval = TimeInterval.parse(temporal_extent)
365-
filters = QueryFilters(time_interval, SpatialFilter.parse(spatial_extent) if spatial_extent else None)
366-
367-
request = partial(self._query_page, filters, skip_data)
368-
369-
initial_page = Pagination(limit=page_size)
370-
pages = paginated_request(request, initial_page)
371-
372-
if callable(show_progress):
373-
pages = with_time_progress_callback(pages, time_interval, show_progress)
374-
elif show_progress:
375-
message = f"Fetching {self._dataset.name}"
376-
pages = with_time_progressbar(pages, time_interval, message)
377-
378-
async for page in pages:
495+
async for page in _iter_query_pages(
496+
self._dataset._service,
497+
self._dataset._dataset.id,
498+
[self._collection.id],
499+
temporal_extent,
500+
spatial_extent,
501+
skip_data,
502+
dataset_name=self._dataset.name,
503+
show_progress=show_progress,
504+
page_size=page_size,
505+
):
379506
yield page
380507

381-
async def _query_page(
382-
self, filters: QueryFilters, skip_data: bool, page: PaginationProtocol | None = None
383-
) -> QueryResultPage:
384-
query_page = Pagination(page.limit, page.starting_after) if page else Pagination()
385-
return await self._dataset._service.query([self._collection.id], filters, skip_data, query_page)
386-
387508
async def ingest(
388509
self,
389510
data: IngestionData,
@@ -477,6 +598,47 @@ async def delete(self, datapoints: DatapointIDs, *, show_progress: bool | Progre
477598
return num_deleted
478599

479600

601+
async def _query_page( # noqa: PLR0913
602+
service: TileboxDatasetService,
603+
dataset_id: UUID,
604+
collection_ids: list[UUID] | None,
605+
filters: QueryFilters,
606+
skip_data: bool,
607+
page: PaginationProtocol | None = None,
608+
) -> QueryResultPage:
609+
query_page = Pagination(page.limit, page.starting_after) if page else Pagination()
610+
return await service.query(dataset_id, collection_ids or [], filters, skip_data, query_page)
611+
612+
613+
async def _iter_query_pages( # noqa: PLR0913
614+
service: TileboxDatasetService,
615+
dataset_id: UUID,
616+
collection_ids: list[UUID] | None,
617+
temporal_extent: TimeIntervalLike,
618+
spatial_extent: SpatialFilterLike | None = None,
619+
skip_data: bool = False,
620+
*,
621+
dataset_name: str,
622+
show_progress: bool | ProgressCallback = False,
623+
page_size: int | None = None,
624+
) -> AsyncIterator[QueryResultPage]:
625+
time_interval = TimeInterval.parse(temporal_extent)
626+
filters = QueryFilters(time_interval, SpatialFilter.parse(spatial_extent) if spatial_extent else None)
627+
628+
request = partial(_query_page, service, dataset_id, collection_ids, filters, skip_data)
629+
630+
initial_page = Pagination(limit=page_size)
631+
pages = paginated_request(request, initial_page)
632+
633+
if callable(show_progress):
634+
pages = with_time_progress_callback(pages, time_interval, show_progress)
635+
elif show_progress:
636+
pages = with_time_progressbar(pages, time_interval, f"Fetching {dataset_name}")
637+
638+
async for page in pages:
639+
yield page
640+
641+
480642
async def _convert_to_dataset(pages: AsyncIterator[QueryResultPage], skip_empty_fields: bool = False) -> xr.Dataset:
481643
"""
482644
Convert an async iterator of QueryResultPages into a single xarray Dataset

0 commit comments

Comments
 (0)