Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _validate_bigquery_metadata(
(not fields and not condition_value_fn)):
raise ValueError(
"Please provide exactly one of `fields` or "
"`condition_value_fn`")
"`condition_value_fn` for matching responses to requests.")


class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]],
Expand Down Expand Up @@ -106,17 +106,15 @@ def __init__(
project: Google Cloud project ID for the BigQuery table.
table_name (str): Fully qualified BigQuery table name
in the format `project.dataset.table`.
row_restriction_template (str): A template string for the `WHERE` clause
in the BigQuery query with placeholders (`{}`) to dynamically filter
rows based on input data.
row_restriction_template (str): A string for the `WHERE` clause
in the BigQuery query. Used as-is without any formatting.
fields: (Optional[list[str]]) List of field names present in the input
`beam.Row`. These are used to construct the WHERE clause
(if `condition_value_fn` is not provided).
`beam.Row`. Used for matching responses to requests.
column_names: (Optional[list[str]]) Names of columns to select from the
BigQuery table. If not provided, all columns (`*`) are selected.
condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function
that takes a `beam.Row` and returns a list of value to populate in the
placeholder `{}` of `WHERE` clause in the query.
that takes a `beam.Row` and returns a list of values. Used for matching
responses to requests.
query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a
`beam.Row` and returns a complete BigQuery SQL query string.
min_batch_size (int): Minimum number of rows to batch together when
Expand Down Expand Up @@ -187,29 +185,25 @@ def create_row_key(self, row: beam.Row):

def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
if isinstance(request, list):
values = []
responses = []
requests_map: dict[Any, Any] = {}
batch_size = len(request)
raw_query = self.query_template
if batch_size > 1:
batched_condition_template = ' or '.join(
[fr'({self.row_restriction_template})'] * batch_size)
raw_query = self.query_template.replace(
self.row_restriction_template, batched_condition_template)
if self.fields and len(self.fields) > 0:
unique_values = set()
field_name = self.fields[0]
for req in request:
req_dict = req._asdict()
unique_values.add(req_dict[field_name])
if unique_values:
conditions = [f"{field_name} = '{val}'" for val in unique_values]
raw_query = "SELECT %s FROM %s WHERE %s" % (
self.select_fields, self.table_name, " OR ".join(conditions))
else:
raw_query = self.query_template
else:
raw_query = self.query_template
for req in request:
request_dict = req._asdict()
try:
current_values = (
self.condition_value_fn(req) if self.condition_value_fn else
[request_dict[field] for field in self.fields])
except KeyError as e:
raise KeyError(
"Make sure the values passed in `fields` are the "
"keys in the input `beam.Row`." + str(e))
values.extend(current_values)
requests_map[self.create_row_key(req)] = req
query = raw_query.format(*values)
query = raw_query

responses_dict = self._execute_query(query)
unmatched_requests = requests_map.copy()
Expand All @@ -220,6 +214,11 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
if response_key in unmatched_requests:
req = unmatched_requests.pop(response_key)
responses.append((req, response_row))
if unmatched_requests and responses_dict:
response_row = beam.Row(**responses_dict[0])
for req in unmatched_requests.values():
responses.append((req, response_row))
unmatched_requests.clear()
if unmatched_requests:
if self.throw_exception_on_empty_results:
raise ValueError(f"no matching row found for query: {query}")
Expand All @@ -229,17 +228,12 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
responses.append((req, beam.Row()))
return responses
else:
request_dict = request._asdict()
if self.query_fn:
# if a query_fn is provided, it returns the complete SQL query string
# for this request, so it is used as-is without any template formatting.
query = self.query_fn(request)
else:
values = (
self.condition_value_fn(request) if self.condition_value_fn else
list(map(request_dict.get, self.fields)))
# construct the query.
query = self.query_template.format(*values)
query = self.query_template
response_dict = self._execute_query(query)
if response_dict is None:
if self.throw_exception_on_empty_results:
Expand Down
82 changes: 42 additions & 40 deletions sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,46 @@ pipelines:
project: "apache-beam-testing"
temp_location: "{TEMP_DIR}"

# - pipeline:
# type: chain
# transforms:
# - type: Create
# name: Data
# config:
# elements:
# - {label: '11a', name: 'S1'}
# - {label: '37a', name: 'S2'}
# - {label: '389a', name: 'S3'}
# - type: Enrichment
# name: Enriched
# config:
# enrichment_handler: 'BigQuery'
# handler_config:
# project: apache-beam-testing
# table_name: "{BQ_TABLE}"
# fields: ['label']
# row_restriction_template: "label = '37a'"
# timeout: 30
# - type: MapToFields
# config:
# language: python
# fields:
# label:
# callable: 'lambda x: x.label'
# output_type: string
# rank:
# callable: 'lambda x: x.rank'
# output_type: integer
# name:
# callable: 'lambda x: x.name'
# output_type: string
- pipeline:
type: chain
transforms:
- type: Create
name: Data
config:
elements:
- {label: '11a', name: 'S1'}
- {label: '37a', name: 'S2'}
- {label: '389a', name: 'S3'}
- type: Enrichment
name: Enriched
config:
enrichment_handler: 'BigQuery'
handler_config:
project: apache-beam-testing
table_name: "{BQ_TABLE}"
fields: ['label']
row_restriction_template: "label = '37a'"
timeout: 30

- type: MapToFields
config:
language: python
fields:
label:
callable: 'lambda x: x.label'
output_type: string
rank:
callable: 'lambda x: x.rank'
output_type: integer
name:
callable: 'lambda x: x.name'
output_type: string

# - type: AssertEqual
# config:
# elements:
# - {label: '37a', rank: 1, name: 'S2'}
# options:
# yaml_experimental_features: [ 'Enrichment' ]
- type: AssertEqual
config:
elements:
- {label: '11a', rank: 0, name: 'S1'}
- {label: '37a', rank: 1, name: 'S2'}
- {label: '389a', rank: 2, name: 'S3'}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we changing this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test expectation was incorrect. Before the fix, all inputs got the same result. After the fix, each input correctly matches its own table row, so we now expect all 3 enriched rows with the correct ranks (0, 1, 2).

Copy link
Collaborator

@shunping shunping Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I looked at this the other day, I found it is a flaky test.

https://github.com/apache/beam/actions/runs/19349133623/job/55356733649

In the run above, one was successful, while the others were not.

The perm-red started around 11/13, but I did not see any related change there.
#35198 (comment)

Even if we have a fix, we have to run it multiple times to make sure it won't be flaky again.

Copy link
Collaborator

@shunping shunping Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, according to the original docstring of row_restriction_template, the line row_restriction_template: "label = '37a'" in the yaml file is supposed to filter the rows. Right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, row_restriction_template: "label = '37a'" should filter the rows. The fix will use it as-is, so only matching rows get enriched, and the test expectation should align with this. I'll also run the test multiple times to verify stability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uploading image.png…

options:
yaml_experimental_features: [ 'Enrichment' ]
Loading