Skip to content
1 change: 1 addition & 0 deletions sdks/python/apache_beam/runners/worker/operations.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ cdef class DoOperation(Operation):
cdef dict timer_specs
cdef public object input_info
cdef object fn
cdef readonly object scoped_timer_processing_state


cdef class SdfProcessSizedElements(DoOperation):
Expand Down
32 changes: 23 additions & 9 deletions sdks/python/apache_beam/runners/worker/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from apache_beam.runners.worker import opcounters
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import sideinputs
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.transforms import sideinputs as apache_sideinputs
from apache_beam.transforms import combiners
Expand Down Expand Up @@ -808,8 +809,14 @@ def __init__(
self.user_state_context = user_state_context
self.tagged_receivers = None # type: Optional[_TaggedReceivers]
# A mapping of timer tags to the input "PCollections" they come in on.
# Force clean rebuild
self.input_info = None # type: Optional[OpInputInfo]

self.scoped_timer_processing_state = statesampler.NOOP_SCOPED_STATE
if self.state_sampler:
self.scoped_timer_processing_state = self.state_sampler.scoped_state(
self.name_context,
'process-timers',
metrics_container=self.metrics_container)
# See fn_data in dataflow_runner.py
# TODO: Store all the items from spec?
self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
Expand Down Expand Up @@ -971,14 +978,21 @@ def add_timer_info(self, timer_family_id, timer_info):
self.user_state_context.add_timer_info(timer_family_id, timer_info)

def process_timer(self, tag, timer_data):
  """Fires one user timer while the timer-processing scoped state is active.

  Args:
    tag: key into ``self.timer_specs`` selecting the timer spec to fire.
    timer_data: the fired timer. Its ``user_key``, first window
      (``windows[0]``), ``fire_timestamp``, ``paneinfo`` and
      ``dynamic_timer_tag`` are forwarded to the DoFn runner.

  Raises:
    KeyError: if ``tag`` has no entry in ``self.timer_specs``.
  """
  # The scoped state is a context manager in every case (a no-op one is
  # installed when there is no state sampler), so it is entered
  # unconditionally; a truthiness check on it would never select a fallback
  # path. The timer is dispatched exactly once, inside the scope.
  with self.scoped_timer_processing_state:
    timer_spec = self.timer_specs[tag]
    self.dofn_runner.process_user_timer(
        timer_spec,
        timer_data.user_key,
        timer_data.windows[0],
        timer_data.fire_timestamp,
        timer_data.paneinfo,
        timer_data.dynamic_timer_tag)

def finish(self):
# type: () -> None
Expand Down
20 changes: 17 additions & 3 deletions sdks/python/apache_beam/runners/worker/statesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def scoped_state(
name_context: Union[str, 'common.NameContext'],
state_name: str,
io_target=None,
metrics_container: Optional['MetricsContainer'] = None
) -> statesampler_impl.ScopedState:
metrics_container: Optional['MetricsContainer'] = None,
suffix: str = '-msecs') -> statesampler_impl.ScopedState:
"""Returns a ScopedState object associated to a Step and a State.

Args:
Expand All @@ -152,7 +152,7 @@ def scoped_state(
name_context = common.NameContext(name_context)

counter_name = CounterName(
state_name + '-msecs',
state_name + suffix,
stage_name=self._prefix,
step_name=name_context.metrics_name(),
io_target=io_target)
Expand All @@ -170,3 +170,17 @@ def commit_counters(self) -> None:
for state in self._states_by_name.values():
state_msecs = int(1e-6 * state.nsecs)
state.counter.update(state_msecs - state.counter.value())


class NoOpScopedState:
  """Scoped-state stand-in that measures nothing.

  Implements the context-manager protocol and ``sampled_msecs_int`` so it can
  be entered and queried without recording any time.
  """
  def __enter__(self):
    # Nothing to start; ``with ... as x`` binds ``x`` to None.
    return None

  def __exit__(self, exc_type, exc_val, exc_tb):
    # A None (falsy) result lets any exception from the body propagate.
    return None

  def sampled_msecs_int(self):
    """Returns 0, since no time is ever sampled by this state."""
    return 0


# Module-level instance of the no-op state.
NOOP_SCOPED_STATE = NoOpScopedState()
116 changes: 116 additions & 0 deletions sdks/python/apache_beam/runners/worker/statesampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from apache_beam.runners.worker import statesampler
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.internal import pickler
from apache_beam.transforms import core

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,6 +131,118 @@ def test_sampler_transition_overhead(self):
# debug mode).
self.assertLess(overhead_us, 20.0)

@retry(reraise=True, stop=stop_after_attempt(3))
def test_timer_sampler(self):
  """Checks that time spent in a 'process-timers' state is sampled.

  A 'process-timers' scoped state is held for a fixed sleep, after which the
  committed 'process-timers-msecs' counter must be within the margin of error
  of the sleep duration. Value checks only run with the fast sampler.
  """
  # Set up state sampler.
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler(
      'timer', counter_factory, sampling_period_ms=1)

  # Duration of the timer processing.
  state_duration_ms = 100
  margin_of_error = 0.25

  sampler.start()
  with sampler.scoped_state('step1', 'process-timers'):
    time.sleep(state_duration_ms / 1000)
  sampler.stop()
  sampler.commit_counters()

  if not statesampler.FAST_SAMPLER:
    # The slow sampler does not implement sampling, so we won't test it.
    return

  # Test that sampled state timings are close to their expected values.
  c = CounterName(
      'process-timers-msecs', step_name='step1', stage_name='timer')
  expected_counter_values = {
      c: state_duration_ms,
  }

  counters = list(counter_factory.get_counters())
  # Guard against a vacuous pass: with no counters the loop below would
  # perform no assertions at all.
  self.assertTrue(counters, 'No counters were created by the state sampler.')

  for counter in counters:
    self.assertIn(counter.name, expected_counter_values)
    expected_value = expected_counter_values[counter.name]
    actual_value = counter.value()
    deviation = float(abs(actual_value - expected_value)) / expected_value
    _LOGGER.info('Sampling deviation from expectation: %f', deviation)
    self.assertGreater(actual_value, expected_value * (1.0 - margin_of_error))
    self.assertLess(actual_value, expected_value * (1.0 + margin_of_error))

@retry(reraise=True, stop=stop_after_attempt(3))
def test_process_timers_metric_is_recorded(self):
  """
  Verifies that holding a 'process-timers' scoped state under a running
  sampler yields a 'process-timers-msecs' counter whose value is close to
  the time spent in that state.
  """
  sleep_ms = 100
  tolerance = 0.25

  # Real sampler backed by a real counter factory.
  factory = CounterFactory()
  sampler = statesampler.StateSampler(
      'test_stage', factory, sampling_period_ms=1)

  # Spend a known amount of time inside the scoped state.
  sampler.start()
  with sampler.scoped_state('test_step', 'process-timers'):
    time.sleep(sleep_ms / 1000.0)
  sampler.stop()
  sampler.commit_counters()

  # Counter values are only checked for the fast sampler.
  if not statesampler.FAST_SAMPLER:
    return

  expected_counter_name = CounterName(
      'process-timers-msecs', step_name='test_step', stage_name='test_stage')

  # First counter with the expected name, or None when there is no match.
  matching = next(
      (
          candidate for candidate in factory.get_counters()
          if candidate.name == expected_counter_name),
      None)

  self.assertIsNotNone(
      matching,
      f"The expected counter '{expected_counter_name}' was not created.")

  # The sampled value must fall strictly inside the tolerance band.
  sampled_ms = matching.value()
  lower_bound = sleep_ms * (1.0 - tolerance)
  upper_bound = sleep_ms * (1.0 + tolerance)
  self.assertGreater(
      sampled_ms, lower_bound, "The timer metric was lower than expected.")
  self.assertLess(
      sampled_ms, upper_bound, "The timer metric was higher than expected.")

def test_do_operation_with_sampler(self):
  """
  Verifies that an operation created with a state sampler does not keep the
  shared no-op object as its timer-processing scoped state.
  """
  # Pickled 5-tuple with a plain DoFn first and None in the other slots.
  serialized_fn = pickler.dumps((core.DoFn(), None, None, None, None))
  do_fn_spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[],
      input=None,
      side_inputs=[],
      output_coders=[])

  sampler = statesampler.StateSampler(
      'test_stage', CounterFactory(), sampling_period_ms=1)

  # Build the operation through the factory function, passing the sampler.
  operation = operations.create_operation(
      name_context='test_op',
      spec=do_fn_spec,
      counter_factory=CounterFactory(),
      state_sampler=sampler)

  timer_state = operation.scoped_timer_processing_state
  self.assertIsNot(timer_state, statesampler.NOOP_SCOPED_STATE)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
Loading