1212from _tilebox .grpc .aio .producer_consumer import async_producer_consumer
1313from _tilebox .grpc .error import ArgumentError , NotFoundError
1414from tilebox .datasets .aio .pagination import with_progressbar , with_time_progress_callback , with_time_progressbar
15- from tilebox .datasets .data .collection import CollectionInfo
15+ from tilebox .datasets .data .collection import Collection , CollectionInfo
1616from tilebox .datasets .data .data_access import QueryFilters , SpatialFilter , SpatialFilterLike
1717from tilebox .datasets .data .datapoint import QueryResultPage
1818from tilebox .datasets .data .datasets import Dataset
@@ -139,6 +139,122 @@ async def delete_collection(self, collection: "str | UUID | CollectionClient") -
139139
140140 await self ._service .delete_collection (self ._dataset .id , collection_id )
141141
142+ async def find (
143+ self ,
144+ datapoint_id : str | UUID ,
145+ collections : "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | None" = None ,
146+ skip_data : bool = False ,
147+ ) -> xr .Dataset :
148+ """
149+ Find a specific datapoint in one of the specified collections by its id.
150+
151+ Args:
152+ datapoint_id: The id of the datapoint to find.
153+ collections: The collections to search in. Supports collection names, ids or collection objects.
154+ If not specified, all collections in the dataset are searched.
155+ skip_data: Whether to skip the actual data of the datapoint. If True, only
156+ datapoint metadata is returned.
157+
158+ Returns:
159+ The datapoint as an xarray dataset.
160+ """
161+ collection_ids = await self ._collection_ids (collections )
162+ try :
163+ datapoint = await self ._service .query_by_id (
164+ self ._dataset .id ,
165+ collection_ids ,
166+ as_uuid (datapoint_id ),
167+ skip_data ,
168+ )
169+ except ArgumentError :
170+ raise ValueError (f"Invalid datapoint id: { datapoint_id } is not a valid UUID" ) from None
171+ except NotFoundError :
172+ raise NotFoundError (f"No such datapoint { datapoint_id } " ) from None
173+
174+ message_type = get_message_type (datapoint .type_url )
175+ data = message_type .FromString (datapoint .value )
176+
177+ converter = MessageToXarrayConverter (initial_capacity = 1 )
178+ converter .convert (data )
179+ return converter .finalize ("time" , skip_empty_fields = skip_data ).isel (time = 0 )
180+
181+ async def query (
182+ self ,
183+ * ,
184+ collections : "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | dict[str, CollectionClient] | None" ,
185+ temporal_extent : TimeIntervalLike ,
186+ spatial_extent : SpatialFilterLike | None = None ,
187+ skip_data : bool = False ,
188+ show_progress : bool | ProgressCallback = False ,
189+ ) -> xr .Dataset :
190+ """
191+ Query datapoints in the specified collections and temporal extent.
192+
193+ Args:
194+ collections: The collections to query in. Supports collection names, ids or collection objects.
195+ If not specified, all collections in the dataset are queried.
196+ temporal_extent: The temporal extent to query data for. (Required)
197+ spatial_extent: The spatial extent to query data in. (Optional)
198+ skip_data: Whether to skip the actual data of the datapoint. If True, only
199+ datapoint metadata is returned.
200+ show_progress: Whether to show a progress bar while loading the data.
201+ If a callable is specified it is used as callback to report progress percentages.
202+
203+ Returns:
204+ Matching datapoints in the given temporal and spatial extent as an xarray dataset.
205+ """
206+ if temporal_extent is None :
207+ raise ValueError ("A temporal_extent for your query must be specified" )
208+
209+ collection_ids = await self ._collection_ids (collections )
210+ pages = _iter_query_pages (
211+ self ._service ,
212+ self ._dataset .id ,
213+ collection_ids ,
214+ temporal_extent ,
215+ spatial_extent ,
216+ skip_data ,
217+ dataset_name = self .name ,
218+ show_progress = show_progress ,
219+ )
220+ return await _convert_to_dataset (pages , skip_empty_fields = skip_data )
221+
222+ async def _collection_id (self , collection : "UUID | Collection | CollectionInfo | CollectionClient" ) -> UUID :
223+ if isinstance (collection , CollectionClient ):
224+ return collection ._collection .id
225+ if isinstance (collection , CollectionInfo ):
226+ return collection .collection .id
227+ if isinstance (collection , Collection ):
228+ return collection .id
229+ return collection
230+
231+ async def _collection_ids (
232+ self ,
233+ collections : "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | dict[str, CollectionClient] | None" ,
234+ ) -> list [UUID ]:
235+ if collections is None :
236+ return []
237+
238+ all_collections : list [CollectionInfo ] = await self ._service .get_collections (self ._dataset .id , True , True )
239+ # find all valid collection names and ids
240+ collections_by_name = {c .collection .name : c .collection .id for c in all_collections }
241+ valid_collection_ids = {c .collection .id for c in all_collections }
242+
243+ collection_ids : list [UUID ] = []
244+ for collection in collections :
245+ if isinstance (collection , str ):
246+ try :
247+ collection_ids .append (collections_by_name [collection ])
248+ except KeyError :
249+ raise ValueError (f"Collection { collection } not found in dataset { self .name } " ) from None
250+ else :
251+ collection_id = await self ._collection_id (collection )
252+ if collection_id not in valid_collection_ids :
253+ raise ValueError (f"Collection { collection_id } is not part of the dataset { self .name } " )
254+ collection_ids .append (collection_id )
255+
256+ return collection_ids
257+
142258 def __repr__ (self ) -> str :
143259 return f"{ self .name } [Timeseries Dataset]: { self ._dataset .summary } "
144260
@@ -221,7 +337,7 @@ async def find(self, datapoint_id: str | UUID, skip_data: bool = False) -> xr.Da
221337 """
222338 try :
223339 datapoint = await self ._dataset ._service .query_by_id (
224- [self ._collection .id ], as_uuid (datapoint_id ), skip_data
340+ self . _dataset . _dataset . id , [self ._collection .id ], as_uuid (datapoint_id ), skip_data
225341 )
226342 except ArgumentError :
227343 raise ValueError (f"Invalid datapoint id: { datapoint_id } is not a valid UUID" ) from None
@@ -259,8 +375,14 @@ async def _find_interval(
259375 filters = QueryFilters (temporal_extent = IDInterval .parse (datapoint_id_interval , end_inclusive = end_inclusive ))
260376
261377 async def request (page : PaginationProtocol ) -> QueryResultPage :
262- query_page = Pagination (page .limit , page .starting_after )
263- return await self ._dataset ._service .query ([self ._collection .id ], filters , skip_data , query_page )
378+ return await _query_page (
379+ self ._dataset ._service ,
380+ self ._dataset ._dataset .id ,
381+ [self ._collection .id ],
382+ filters ,
383+ skip_data ,
384+ page ,
385+ )
264386
265387 initial_page = Pagination ()
266388 pages = paginated_request (request , initial_page )
@@ -350,7 +472,16 @@ async def query(
350472 if temporal_extent is None :
351473 raise ValueError ("A temporal_extent for your query must be specified" )
352474
353- pages = self ._iter_pages (temporal_extent , spatial_extent , skip_data , show_progress = show_progress )
475+ pages = _iter_query_pages (
476+ self ._dataset ._service ,
477+ self ._dataset ._dataset .id ,
478+ [self ._collection .id ],
479+ temporal_extent ,
480+ spatial_extent ,
481+ skip_data ,
482+ dataset_name = self ._dataset .name ,
483+ show_progress = show_progress ,
484+ )
354485 return await _convert_to_dataset (pages , skip_empty_fields = skip_data )
355486
356487 async def _iter_pages (
@@ -361,29 +492,19 @@ async def _iter_pages(
361492 show_progress : bool | ProgressCallback = False ,
362493 page_size : int | None = None ,
363494 ) -> AsyncIterator [QueryResultPage ]:
364- time_interval = TimeInterval .parse (temporal_extent )
365- filters = QueryFilters (time_interval , SpatialFilter .parse (spatial_extent ) if spatial_extent else None )
366-
367- request = partial (self ._query_page , filters , skip_data )
368-
369- initial_page = Pagination (limit = page_size )
370- pages = paginated_request (request , initial_page )
371-
372- if callable (show_progress ):
373- pages = with_time_progress_callback (pages , time_interval , show_progress )
374- elif show_progress :
375- message = f"Fetching { self ._dataset .name } "
376- pages = with_time_progressbar (pages , time_interval , message )
377-
378- async for page in pages :
495+ async for page in _iter_query_pages (
496+ self ._dataset ._service ,
497+ self ._dataset ._dataset .id ,
498+ [self ._collection .id ],
499+ temporal_extent ,
500+ spatial_extent ,
501+ skip_data ,
502+ dataset_name = self ._dataset .name ,
503+ show_progress = show_progress ,
504+ page_size = page_size ,
505+ ):
379506 yield page
380507
381- async def _query_page (
382- self , filters : QueryFilters , skip_data : bool , page : PaginationProtocol | None = None
383- ) -> QueryResultPage :
384- query_page = Pagination (page .limit , page .starting_after ) if page else Pagination ()
385- return await self ._dataset ._service .query ([self ._collection .id ], filters , skip_data , query_page )
386-
387508 async def ingest (
388509 self ,
389510 data : IngestionData ,
@@ -477,6 +598,47 @@ async def delete(self, datapoints: DatapointIDs, *, show_progress: bool | Progre
477598 return num_deleted
478599
479600
601+ async def _query_page ( # noqa: PLR0913
602+ service : TileboxDatasetService ,
603+ dataset_id : UUID ,
604+ collection_ids : list [UUID ] | None ,
605+ filters : QueryFilters ,
606+ skip_data : bool ,
607+ page : PaginationProtocol | None = None ,
608+ ) -> QueryResultPage :
609+ query_page = Pagination (page .limit , page .starting_after ) if page else Pagination ()
610+ return await service .query (dataset_id , collection_ids or [], filters , skip_data , query_page )
611+
612+
613+ async def _iter_query_pages ( # noqa: PLR0913
614+ service : TileboxDatasetService ,
615+ dataset_id : UUID ,
616+ collection_ids : list [UUID ] | None ,
617+ temporal_extent : TimeIntervalLike ,
618+ spatial_extent : SpatialFilterLike | None = None ,
619+ skip_data : bool = False ,
620+ * ,
621+ dataset_name : str ,
622+ show_progress : bool | ProgressCallback = False ,
623+ page_size : int | None = None ,
624+ ) -> AsyncIterator [QueryResultPage ]:
625+ time_interval = TimeInterval .parse (temporal_extent )
626+ filters = QueryFilters (time_interval , SpatialFilter .parse (spatial_extent ) if spatial_extent else None )
627+
628+ request = partial (_query_page , service , dataset_id , collection_ids , filters , skip_data )
629+
630+ initial_page = Pagination (limit = page_size )
631+ pages = paginated_request (request , initial_page )
632+
633+ if callable (show_progress ):
634+ pages = with_time_progress_callback (pages , time_interval , show_progress )
635+ elif show_progress :
636+ pages = with_time_progressbar (pages , time_interval , f"Fetching { dataset_name } " )
637+
638+ async for page in pages :
639+ yield page
640+
641+
480642async def _convert_to_dataset (pages : AsyncIterator [QueryResultPage ], skip_empty_fields : bool = False ) -> xr .Dataset :
481643 """
482644 Convert an async iterator of QueryResultPages into a single xarray Dataset
0 commit comments