11 changes: 11 additions & 0 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -178,6 +178,8 @@ def __init__(
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
length_fn: Optional[Callable[[Any], int]] = None,
bucket_boundaries: Optional[list[int]] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
**kwargs):
@@ -190,6 +192,11 @@ def __init__(
before emitting; used in streaming contexts.
max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
element_size_fn: a function that returns the size (weight) of an element.
length_fn: a callable mapping an element to its length. When set with
max_batch_duration_secs, enables length-aware bucketed keying so
elements of similar length are batched together.
bucket_boundaries: sorted list of positive boundary values for length
bucketing. Requires length_fn.
large_model: set to true if your model is large enough to run into
memory pressure if you load multiple copies.
model_copies: The exact number of models that you would like loaded
@@ -209,6 +216,10 @@ def __init__(
self._batching_kwargs['max_batch_weight'] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn
if length_fn is not None:
self._batching_kwargs['length_fn'] = length_fn
if bucket_boundaries is not None:
self._batching_kwargs['bucket_boundaries'] = bucket_boundaries
self._large_model = large_model
self._model_copies = model_copies
self._share_across_processes = large_model or (model_copies is not None)
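For orientation, a minimal usage sketch of the new arguments. It assumes, as the tests below imply, that a ModelHandler subclass forwards its constructor keyword arguments to the __init__ extended above; ToyLengthAwareHandler, its stand-in model, and the token-id inputs are hypothetical.

from typing import Iterable, Sequence

import apache_beam as beam
from apache_beam.ml.inference import base


class ToyLengthAwareHandler(base.ModelHandler[Sequence[int], int, object]):
  """Hypothetical handler; forwards batching kwargs to the base __init__."""
  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def load_model(self) -> object:
    return object()  # stand-in for a real model

  def run_inference(self, batch, model, inference_args=None) -> Iterable[int]:
    # With bucketing enabled, elements in `batch` tend to have similar
    # lengths, so padding to the longest element wastes little work.
    return [len(example) for example in batch]


handler = ToyLengthAwareHandler(
    max_batch_duration_secs=5,        # required for the stateful batching path
    length_fn=len,                    # e.g. token count per element
    bucket_boundaries=[16, 32, 64])   # similar-length elements share a bucket

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([list(range(n)) for n in (3, 20, 50, 90)])
      | base.RunInference(handler)
      | beam.Map(print))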
37 changes: 37 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -2278,6 +2278,43 @@ def test_max_batch_duration_secs_only(self):

self.assertEqual(kwargs, {'max_batch_duration_secs': 60})

def test_length_fn_and_bucket_boundaries(self):
"""length_fn and bucket_boundaries are passed through to kwargs."""
handler = FakeModelHandlerForBatching(
length_fn=len, bucket_boundaries=[16, 32, 64])
kwargs = handler.batch_elements_kwargs()

self.assertIs(kwargs['length_fn'], len)
self.assertEqual(kwargs['bucket_boundaries'], [16, 32, 64])

def test_length_fn_only(self):
"""length_fn alone is passed through without bucket_boundaries."""
handler = FakeModelHandlerForBatching(length_fn=len)
kwargs = handler.batch_elements_kwargs()

self.assertIs(kwargs['length_fn'], len)
self.assertNotIn('bucket_boundaries', kwargs)

def test_bucket_boundaries_without_length_fn(self):
"""Passing bucket_boundaries without length_fn should fail in BatchElements.

Note: ModelHandler.__init__ doesn't validate this; the error is raised
by BatchElements when batch_elements_kwargs are used."""
handler = FakeModelHandlerForBatching(bucket_boundaries=[10, 20])
kwargs = handler.batch_elements_kwargs()
# The kwargs are stored, but BatchElements will reject them
self.assertEqual(kwargs['bucket_boundaries'], [10, 20])
self.assertNotIn('length_fn', kwargs)

def test_batching_kwargs_none_values_omitted(self):
"""None values for length_fn and bucket_boundaries are not in kwargs."""
handler = FakeModelHandlerForBatching(
min_batch_size=5, length_fn=None, bucket_boundaries=None)
kwargs = handler.batch_elements_kwargs()
self.assertNotIn('length_fn', kwargs)
self.assertNotIn('bucket_boundaries', kwargs)
self.assertEqual(kwargs['min_batch_size'], 5)


class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def load_model(self):
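FakeModelHandlerForBatching is defined elsewhere in base_test.py and is not shown in this diff; the tests above only need it to forward keyword arguments to the __init__ extended in base.py. A rough, assumed shape:

class FakeModelHandlerForBatching(base.ModelHandler[int, int, FakeModel]):
  """Assumed sketch of the test fake; the real definition is not in this diff."""
  def __init__(self, **kwargs):
    super().__init__(**kwargs)  # batching kwargs flow into _batching_kwargs

  def load_model(self):
    return FakeModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(example) for example in batch]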
58 changes: 56 additions & 2 deletions sdks/python/apache_beam/transforms/util.py
@@ -20,6 +20,7 @@

# pytype: skip-file

import bisect
import collections
import contextlib
import hashlib
@@ -1209,6 +1210,28 @@ def process(self, element):
yield (self.key, element)


class WithLengthBucketKey(DoFn):
"""Keys elements with (worker_uuid, length_bucket) for length-aware
stateful batching. Elements of similar length are routed to the same
state partition, reducing padding waste."""
def __init__(self, length_fn, bucket_boundaries):
self.shared_handle = shared.Shared()
self._length_fn = length_fn
self._bucket_boundaries = bucket_boundaries

def setup(self):
self.key = self.shared_handle.acquire(
load_shared_key, "WithLengthBucketKey").key

def _get_bucket(self, length):
return bisect.bisect_left(self._bucket_boundaries, length)

def process(self, element):
length = self._length_fn(element)
bucket = self._get_bucket(length)
yield ((self.key, bucket), element)

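# Illustration (hypothetical lengths, boundaries [16, 32, 64]): bisect_left
# makes each boundary the inclusive upper edge of its bucket, with one extra
# overflow bucket past the last boundary.
#   length 10  -> bucket 0   (length <= 16)
#   length 16  -> bucket 0
#   length 40  -> bucket 2   (32 < length <= 64)
#   length 100 -> bucket 3   (overflow, == len(boundaries))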

@typehints.with_input_types(T)
@typehints.with_output_types(list[T])
class BatchElements(PTransform):
Expand Down Expand Up @@ -1268,7 +1291,18 @@ class BatchElements(PTransform):
downstream operations (mostly for testing)
record_metrics: (optional) whether or not to record beam metrics on
distributions of the batch size. Defaults to True.
length_fn: (optional) a callable mapping an element to its length (int).
When set together with max_batch_duration_secs, enables length-aware
bucketed keying on the stateful path so that elements of similar length
are routed to the same batch, reducing padding waste.
bucket_boundaries: (optional) a sorted list of positive boundary values
for length bucketing. An element falls into bucket i when
boundaries[i-1] < length <= boundaries[i] (bucket 0 covers lengths up to
boundaries[0]); lengths above the last boundary go to the overflow
bucket, index len(boundaries). Defaults to [16, 32, 64, 128, 256, 512]
when length_fn is set. Requires length_fn.
"""
_DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]

def __init__(
self,
min_batch_size=1,
@@ -1281,7 +1315,17 @@ def __init__(
element_size_fn=lambda x: 1,
variance=0.25,
clock=time.time,
record_metrics=True):
record_metrics=True,
length_fn=None,
bucket_boundaries=None):
if bucket_boundaries is not None and length_fn is None:
raise ValueError('bucket_boundaries requires length_fn to be set.')
if bucket_boundaries is not None:
if (not bucket_boundaries or any(b <= 0 for b in bucket_boundaries) or
bucket_boundaries != sorted(bucket_boundaries)):
raise ValueError(
'bucket_boundaries must be a non-empty sorted list of '
'positive values.')
self._batch_size_estimator = _BatchSizeEstimator(
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
@@ -1295,13 +1339,23 @@ def __init__(
self._element_size_fn = element_size_fn
self._max_batch_dur = max_batch_duration_secs
self._clock = clock
self._length_fn = length_fn
if length_fn is not None and bucket_boundaries is None:
self._bucket_boundaries = self._DEFAULT_BUCKET_BOUNDARIES
else:
self._bucket_boundaries = bucket_boundaries

def expand(self, pcoll):
if getattr(pcoll.pipeline.runner, 'is_streaming', False):
raise NotImplementedError("Requires stateful processing (BEAM-2687)")
elif self._max_batch_dur is not None:
coder = coders.registry.get_coder(pcoll)
return pcoll | ParDo(WithSharedKey()) | ParDo(
if self._length_fn is not None:
keying_dofn = WithLengthBucketKey(
self._length_fn, self._bucket_boundaries)
else:
keying_dofn = WithSharedKey()
return pcoll | ParDo(keying_dofn) | ParDo(
_pardo_stateful_batch_elements(
coder,
self._batch_size_estimator,
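A hedged end-to-end sketch of the new BatchElements arguments; the element values, batch sizes, and boundary choices are illustrative.

import apache_beam as beam
from apache_beam.transforms import util

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([[0] * n for n in (3, 10, 20, 40, 100, 120)])
      | util.BatchElements(
          min_batch_size=2,
          max_batch_size=8,
          max_batch_duration_secs=5,       # selects the stateful path
          length_fn=len,                   # length of each list element
          bucket_boundaries=[16, 32, 64])  # lengths > 64 share the overflow bucket
      | beam.Map(lambda batch: print([len(e) for e in batch])))

Passing bucket_boundaries without length_fn, or an unsorted or non-positive boundary list, raises ValueError in __init__, as shown in the validation block above.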