Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,8 @@ public interface BigQueryAnomalyDetection {
regexes = {"^[a-zA-Z0-9_-]+:[a-zA-Z0-9_]+\\.[a-zA-Z0-9_]+$"})
String getSinkTable();

@TemplateParameter.Integer(
order = 13,
optional = true,
name = "decompress_shards",
description = "Decompress Shards",
helpText =
"Number of shards for CDC Arrow batch decompression fan-out. "
+ "Spreads decompression CPU across workers. "
+ "0 disables fan-out (decode inline). Default: 400.")
Integer getDecompressShards();

@TemplateParameter.Text(
order = 14,
order = 13,
optional = true,
name = "fanout_strategy",
description = "Fanout Strategy",
Expand All @@ -170,7 +159,7 @@ public interface BigQueryAnomalyDetection {
String getFanoutStrategy();

@TemplateParameter.Integer(
order = 15,
order = 14,
optional = true,
name = "fanout",
description = "Fanout Shards",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import dataclasses
import datetime
import logging
import random
import sys
import time
import uuid
Expand All @@ -65,8 +64,6 @@
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
from apache_beam.metrics import Metrics
from apache_beam.transforms.core import WatermarkEstimatorProvider
from apache_beam.transforms import trigger as beam_trigger
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.utils import retry
from apache_beam.utils.timestamp import MAX_TIMESTAMP
Expand Down Expand Up @@ -784,10 +781,14 @@ def _split_all_streams(self, stream_names: Tuple[str, ...],
rounds of doubling.
"""
result = list(stream_names)
no_split = set()
for round_num in range(1, max_split_rounds + 1):
new_result = []
made_progress = False
for name in result:
if name in no_split:
new_result.append(name)
continue
response = self._storage_client.split_read_stream(
request=bq_storage.types.SplitReadStreamRequest(
name=name, fraction=0.5))
Expand All @@ -798,6 +799,7 @@ def _split_all_streams(self, stream_names: Tuple[str, ...],
made_progress = True
else:
new_result.append(name)
no_split.add(name)
result = new_result
_LOGGER.info(
'[Read] _split_all_streams round %d/%d: %d streams '
Expand Down Expand Up @@ -1063,7 +1065,7 @@ def _read_stream_raw(
class _DecompressArrowBatchesFn(beam.DoFn):
"""Decompress and convert raw Arrow batches to timestamped row dicts.

Receives GBK output: (shard_key, Iterable[(schema_bytes, batch_bytes)])
Receives individual (schema_bytes, batch_bytes) tuples after Reshuffle
and converts each batch to individual row dicts with event timestamps
extracted from the change_timestamp column.
"""
Expand All @@ -1072,25 +1074,24 @@ def __init__(self, change_timestamp_column: str = 'change_timestamp') -> None:

def process(
self,
element: Tuple[int, Iterable[Tuple[bytes, bytes]]]
element: Tuple[bytes, bytes]
) -> Iterable[Dict[str, Any]]:
_, batches = element
for schema_bytes, batch_bytes in batches:
schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
batch = pyarrow.ipc.read_record_batch(
pyarrow.py_buffer(batch_bytes), schema)

rows = batch.to_pylist()
for row in rows:
ts = row.get(self._change_timestamp_column)
if ts is None:
raise ValueError(
'Row missing %r column. Row keys: %s' %
(self._change_timestamp_column, list(row.keys())))
if isinstance(ts, datetime.datetime):
ts = Timestamp.from_utc_datetime(ts)
yield TimestampedValue(row, ts)
Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows))
schema_bytes, batch_bytes = element
schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
batch = pyarrow.ipc.read_record_batch(
pyarrow.py_buffer(batch_bytes), schema)

rows = batch.to_pylist()
for row in rows:
ts = row.get(self._change_timestamp_column)
if ts is None:
raise ValueError(
'Row missing %r column. Row keys: %s' %
(self._change_timestamp_column, list(row.keys())))
if isinstance(ts, datetime.datetime):
ts = Timestamp.from_utc_datetime(ts)
yield TimestampedValue(row, ts)
Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows))


# =============================================================================
Expand Down Expand Up @@ -1215,12 +1216,12 @@ class ReadBigQueryChangeHistory(beam.PTransform):
1 (one round of splitting). Set 0 to disable splitting
entirely. Set higher for very large tables where more
parallelism is needed.
decompress_shards: If set to a positive integer, the Read SDF
emits raw compressed Arrow batches instead of decoded rows.
The batches are reshuffled for fan-out and then decoded in a
separate DoFn. This spreads decompression and Arrow-to-Python
conversion CPU across more workers. If None (default), rows
are decoded inline within the Read SDF.
reshuffle_decompress: If True (default), the Read SDF emits raw
compressed Arrow batches instead of decoded rows. The batches
are reshuffled for fan-out and then decoded in a separate DoFn.
This spreads decompression and Arrow-to-Python conversion CPU
across more workers. Set to False to decode rows inline within
the Read SDF.
"""
def __init__(
self,
Expand All @@ -1239,7 +1240,7 @@ def __init__(
row_filter: Optional[str] = None,
batch_arrow_read: bool = True,
max_split_rounds: int = 1,
decompress_shards: Optional[int] = None) -> None:
reshuffle_decompress: bool = True) -> None:
super().__init__()
if bq_storage is None:
raise ImportError(
Expand Down Expand Up @@ -1274,7 +1275,7 @@ def __init__(
self._row_filter = row_filter
self._batch_arrow_read = batch_arrow_read
self._max_split_rounds = max_split_rounds
self._decompress_shards = decompress_shards
self._reshuffle_decompress = reshuffle_decompress

def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
project = self._project
Expand Down Expand Up @@ -1354,7 +1355,7 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
row_filter=self._row_filter))
| 'CommitQueryResults' >> beam.Reshuffle())

emit_raw = self._decompress_shards is not None
emit_raw = self._reshuffle_decompress

read_sdf = beam.ParDo(
_ReadStorageStreamsSDF(
Expand All @@ -1377,22 +1378,11 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
| 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn()))

if emit_raw:
# Fan out raw Arrow batches across decompress_shards workers
# via GBK, then decompress and convert to timestamped row dicts.
# Uses a discarding trigger so GBK fires per-element without
# waiting for the GlobalWindow to close.
num_shards = self._decompress_shards
# Reshuffle raw Arrow batches for fan-out, then decompress and
# convert to timestamped row dicts in a separate DoFn.
rows = (
read_outputs['rows']
| 'ShardBatches' >> beam.WithKeys(
lambda _, n=num_shards: random.randint(0, n - 1))
| 'WindowForGBK' >> beam.WindowInto(
GlobalWindows(),
trigger=beam_trigger.Repeatedly(
beam_trigger.AfterCount(1)),
accumulation_mode=(
beam_trigger.AccumulationMode.DISCARDING))
| 'GroupByShardKey' >> beam.GroupByKey()
| 'ReshuffleForFanout' >> beam.Reshuffle()
| 'DecompressBatches' >> beam.ParDo(
_DecompressArrowBatchesFn(
change_timestamp_column=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,6 @@ def _add_argparse_args(cls, parser):
help='BigQuery table to write all anomaly detection results to. '
'Format: project:dataset.table. If unset, results are not written '
'to BigQuery.')
parser.add_argument(
'--decompress_shards',
type=int,
default=1200,
help='Number of shards for CDC Arrow batch decompression fan-out. '
'Spreads decompression CPU across workers. '
'0 disables fan-out (decode inline). Default: 1200.')
parser.add_argument(
'--fanout_strategy',
default='sharded',
Expand Down Expand Up @@ -996,10 +989,7 @@ def build_pipeline(pipeline, options, metric_spec, detector):
buffer_sec=options.buffer_sec,
columns=columns,
change_type_column=change_type_col,
change_timestamp_column=change_ts_col,
decompress_shards=(
options.decompress_shards if options.decompress_shards > 0
else None))
change_timestamp_column=change_ts_col)
if stop_time is not None:
cdc_kwargs['stop_time'] = stop_time
if options.temp_dataset:
Expand Down
Loading