Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from collections.abc import Callable
from collections.abc import Mapping
from typing import Any
Expand All @@ -30,6 +31,8 @@
QueryFn = Callable[[beam.Row], str]
ConditionValueFn = Callable[[beam.Row], list[Any]]

_LOGGER = logging.getLogger(__name__)


def _validate_bigquery_metadata(
table_name, row_restriction_template, fields, condition_value_fn, query_fn):
Expand Down Expand Up @@ -87,6 +90,7 @@ def __init__(
query_fn: Optional[QueryFn] = None,
min_batch_size: int = 1,
max_batch_size: int = 10000,
throw_exception_on_empty_results: bool = True,
**kwargs,
):
"""
Expand Down Expand Up @@ -145,6 +149,7 @@ def __init__(
self.query_template = (
"SELECT %s FROM %s WHERE %s" %
(self.select_fields, self.table_name, self.row_restriction_template))
self.throw_exception_on_empty_results = throw_exception_on_empty_results
self.kwargs = kwargs
self._batching_kwargs = {}
if not query_fn:
Expand All @@ -157,10 +162,13 @@ def __enter__(self):
def _execute_query(self, query: str):
try:
results = self.client.query(query=query).result()
row_list = [dict(row.items()) for row in results]
if not row_list:
return None
if self._batching_kwargs:
return [dict(row.items()) for row in results]
return row_list
else:
return [dict(row.items()) for row in results][0]
return row_list[0]
except BadRequest as e:
raise BadRequest(
f'Could not execute the query: {query}. Please check if '
Expand Down Expand Up @@ -204,11 +212,21 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
query = raw_query.format(*values)

responses_dict = self._execute_query(query)
for response in responses_dict:
response_row = beam.Row(**response)
response_key = self.create_row_key(response_row)
if response_key in requests_map:
responses.append((requests_map[response_key], response_row))
unmatched_requests = requests_map.copy()
if responses_dict:
for response in responses_dict:
response_row = beam.Row(**response)
response_key = self.create_row_key(response_row)
if response_key in unmatched_requests:
req = unmatched_requests.pop(response_key)
responses.append((req, response_row))
if unmatched_requests:
if self.throw_exception_on_empty_results:
raise ValueError(f"no matching row found for query: {query}")
else:
_LOGGER.warning('no matching row found for query: %s', query)
for req in unmatched_requests.values():
responses.append((req, beam.Row()))
return responses
else:
request_dict = request._asdict()
Expand All @@ -223,6 +241,12 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
# construct the query.
query = self.query_template.format(*values)
response_dict = self._execute_query(query)
if response_dict is None:
if self.throw_exception_on_empty_results:
raise ValueError(f"no matching row found for query: {query}")
else:
_LOGGER.warning('no matching row found for query: %s', query)
return request, beam.Row()
return request, beam.Row(**response_dict)

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,147 @@ def test_bigquery_enrichment_with_redis(self):
assert_that(pcoll_cached, equal_to(expected_rows))
BigQueryEnrichmentHandler.__call__ = actual

def test_bigquery_enrichment_no_results_throws_exception(self):
requests = [
beam.Row(id=999, name='X'), # This ID does not exist
]
handler = BigQueryEnrichmentHandler(
project=self.project,
row_restriction_template="id = {}",
table_name=self.table_name,
fields=['id'],
throw_exception_on_empty_results=True,
)

with self.assertRaisesRegex(ValueError, "no matching row found for query"):
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (test_pipeline | beam.Create(requests) | Enrichment(handler))

def test_bigquery_enrichment_no_results_graceful(self):
requests = [
beam.Row(id=999, name='X'), # This ID does not exist
beam.Row(id=1000, name='Y'), # This ID does not exist
]
# When no results are found and not throwing, Enrichment yields original.
expected_rows = requests

handler = BigQueryEnrichmentHandler(
project=self.project,
row_restriction_template="id = {}",
table_name=self.table_name,
fields=['id'],
min_batch_size=1,
max_batch_size=100,
throw_exception_on_empty_results=False,
)

with TestPipeline(is_integration_test=True) as test_pipeline:
pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
assert_that(pcoll, equal_to(expected_rows))

def test_bigquery_enrichment_no_results_partial_graceful_batched(self):
requests = [
beam.Row(id=1, name='A'), # This ID exists
beam.Row(id=1000, name='Y'), # This ID does not exist
]
# When no results are found and not throwing, Enrichment yields original.
expected_rows = [
beam.Row(id=1, name='A', quantity=2, distribution_center_id=3),
beam.Row(id=1000,
name='Y'), # This ID does not exist so remains unchanged
]

handler = BigQueryEnrichmentHandler(
project=self.project,
row_restriction_template="id = {}",
table_name=self.table_name,
fields=['id'],
min_batch_size=2,
max_batch_size=100,
throw_exception_on_empty_results=False,
)

with TestPipeline(is_integration_test=True) as test_pipeline:
pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
assert_that(pcoll, equal_to(expected_rows))

def test_bigquery_enrichment_no_results_graceful_batched(self):
requests = [
beam.Row(id=999, name='X'), # This ID does not exist
beam.Row(id=1000, name='Y'), # This ID does not exist
]
# When no results are found and not throwing, Enrichment yields original.
expected_rows = requests

handler = BigQueryEnrichmentHandler(
project=self.project,
row_restriction_template="id = {}",
table_name=self.table_name,
fields=['id'],
min_batch_size=2,
max_batch_size=100,
throw_exception_on_empty_results=False,
)

with TestPipeline(is_integration_test=True) as test_pipeline:
pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
assert_that(pcoll, equal_to(expected_rows))

def test_bigquery_enrichment_no_results_with_query_fn_throws_exception(self):
requests = [
beam.Row(id=999, name='X'), # This ID does not exist
]
# This query_fn will return no results
fn = functools.partial(query_fn, self.table_name)
handler = BigQueryEnrichmentHandler(
project=self.project,
query_fn=fn,
throw_exception_on_empty_results=True,
)

with self.assertRaisesRegex(ValueError, "no matching row found for query"):
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (test_pipeline | beam.Create(requests) | Enrichment(handler))

def test_bigquery_enrichment_no_results_with_query_fn_graceful(self):
requests = [
beam.Row(id=999, name='X'), # This ID does not exist
beam.Row(id=1000, name='Y'), # This ID does not exist
]
# When no results are found and not throwing, Enrichment yields original.
expected_rows = requests

# This query_fn will return no results
fn = functools.partial(query_fn, self.table_name)
handler = BigQueryEnrichmentHandler(
project=self.project,
query_fn=fn,
throw_exception_on_empty_results=False,
)

with TestPipeline(is_integration_test=True) as test_pipeline:
pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
assert_that(pcoll, equal_to(expected_rows))

def test_bigquery_enrichment_partial_results_throws_exception_batched(self):
requests = [
beam.Row(id=1, name='A'), # This ID exists
beam.Row(id=1000, name='Y'), # This ID does not exist
]
handler = BigQueryEnrichmentHandler(
project=self.project,
row_restriction_template="id = {}",
table_name=self.table_name,
fields=['id'],
min_batch_size=2,
max_batch_size=100,
throw_exception_on_empty_results=True,
)

with self.assertRaisesRegex(ValueError, "no matching row found for query"):
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (test_pipeline | beam.Create(requests) | Enrichment(handler))


if __name__ == '__main__':
unittest.main()
Loading