2525from itertools import islice
2626import logging
2727import queue
28+ import threading
2829import warnings
2930from typing import Any , Union , Optional , Callable , Generator , List
3031
@@ -134,6 +135,21 @@ def __init__(self):
134135 # be an atomic operation in the Python language definition (enforced by
135136 # the global interpreter lock).
136137 self .done = False
138+ # To assist with testing and understanding the behavior of the
139+ # download, use this object as shared state to track how many worker
140+ # threads have started and have gracefully shutdown.
141+ self ._started_workers_lock = threading .Lock ()
142+ self .started_workers = 0
143+ self ._finished_workers_lock = threading .Lock ()
144+ self .finished_workers = 0
145+
def start(self):
    """Record that one more worker thread has started.

    The increment happens while holding ``_started_workers_lock`` so that
    updates from concurrently running workers are not lost.
    """
    with self._started_workers_lock:
        self.started_workers = self.started_workers + 1
149+
def finish(self):
    """Record that one more worker thread has finished.

    The increment happens while holding ``_finished_workers_lock`` so that
    updates from concurrently running workers are not lost.
    """
    with self._finished_workers_lock:
        self.finished_workers = self.finished_workers + 1
137153
138154
139155BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
@@ -819,25 +835,35 @@ def _bqstorage_page_to_dataframe(column_names, dtypes, page):
def _download_table_bqstorage_stream(
    download_state, bqstorage_client, session, stream, worker_queue, page_to_item
):
    """Read every page of one read stream and hand the items to the queue.

    The worker registers itself with ``download_state.start()`` before doing
    any work and always calls ``download_state.finish()`` on the way out,
    whether it completes, returns early, or raises.

    Args:
        download_state:
            Shared state object. Its ``done`` flag is polled to decide when
            to stop early, and its ``start()`` / ``finish()`` methods are
            called to track this worker.
        bqstorage_client:
            Client whose ``read_rows()`` is called with ``stream.name``.
        session:
            Read session, passed to ``reader.rows()`` only when the installed
            BigQuery Storage version still requires it.
        stream:
            The stream to read; only its ``name`` attribute is used here.
        worker_queue:
            Queue that receives one item per page.
        page_to_item:
            Callable that converts a page into the item placed on the queue.
    """
    download_state.start()
    try:
        reader = bqstorage_client.read_rows(stream.name)

        # Only pass the read session when it is still required, to avoid
        # deprecation warnings for an unnecessary argument.
        # https://github.com/googleapis/python-bigquery-storage/issues/229
        rowstream = (
            reader.rows()
            if _versions_helpers.BQ_STORAGE_VERSIONS.is_read_session_optional
            else reader.rows(session)
        )

        for page in rowstream.pages:
            item = page_to_item(page)

            # put() is given a timeout so that this worker regularly gets a
            # chance to notice ``download_state.done`` and shut down
            # gracefully, for example if the parent thread shuts down or the
            # parent generator object which collects rows from all workers
            # goes out of scope. See:
            # https://github.com/googleapis/python-bigquery/issues/2032
            while not download_state.done:
                try:
                    worker_queue.put(item, timeout=_PROGRESS_INTERVAL)
                except queue.Full:
                    # Queue still full; re-check the done flag and retry.
                    continue
                # Item was enqueued; move on to the next page.
                break
            else:
                # The download was marked done before the item could be
                # enqueued, so stop this worker.
                return
    finally:
        download_state.finish()
841867
842868
843869def _nowait (futures ):
@@ -863,6 +889,7 @@ def _download_table_bqstorage(
863889 page_to_item : Optional [Callable ] = None ,
864890 max_queue_size : Any = _MAX_QUEUE_SIZE_DEFAULT ,
865891 max_stream_count : Optional [int ] = None ,
892+ download_state : Optional [_DownloadState ] = None ,
866893) -> Generator [Any , None , None ]:
867894 """Downloads a BigQuery table using the BigQuery Storage API.
868895
@@ -890,6 +917,9 @@ def _download_table_bqstorage(
890917 is True, the requested streams are limited to 1 regardless of the
891918 `max_stream_count` value. If 0 or None, then the number of
892919 requested streams will be unbounded. Defaults to None.
920+ download_state (Optional[_DownloadState]):
921+ A threadsafe state object which can be used to observe the
922+ behavior of the worker threads created by this method.
893923
894924 Yields:
895925 pandas.DataFrame: Pandas DataFrames, one for each chunk of data
@@ -948,7 +978,8 @@ def _download_table_bqstorage(
948978
949979 # Use _DownloadState to notify worker threads when to quit.
950980 # See: https://stackoverflow.com/a/29237343/101923
951- download_state = _DownloadState ()
981+ if download_state is None :
982+ download_state = _DownloadState ()
952983
953984 # Create a queue to collect frames as they are created in each thread.
954985 #
0 commit comments