Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ def __init__(
self._last_emission_time: float = 0.0
self._timer = Timer()

# Partitioned stream status tracking for progress estimation.
# These counters are per-sync only and intentionally NOT restored from persisted state
# (_set_initial_state does not read them back). On resume, they reset to 0.
self._num_partitions_completed: int = 0
self._is_partition_discovery_complete: bool = False
# Tracks partition keys for which observe() has been called (worker produced at least one record).
# Only len() is used in state emission; the set itself is never serialized.
self._partitions_observed: set[str] = set()

self._set_initial_state(stream_state)

# FIXME: this is a temporary field kept for the duration of the migration from declarative cursors to concurrent ones
Expand Down Expand Up @@ -217,6 +226,13 @@ def state(self) -> MutableMapping[str, Any]:
state["lookback_window"] = self._lookback_window
if self._parent_state is not None:
state["parent_state"] = self._parent_state
num_observed = len(self._partitions_observed)
state["partitioned_stream_status"] = {
"num_partitions_in_progress": max(0, num_observed - self._num_partitions_completed),
"num_partitions_completed": self._num_partitions_completed,
"num_partitions_expected": self._generated_partitions_count,
"is_partition_discovery_complete": self._is_partition_discovery_complete,
}
Comment thread
aaronsteers marked this conversation as resolved.
Comment thread
aaronsteers marked this conversation as resolved.
return state

def close_partition(self, partition: Partition) -> None:
Expand Down Expand Up @@ -322,6 +338,8 @@ def stream_slices(self) -> Iterable[StreamSlice]:
slices, self._partition_router.get_stream_state
):
yield from self._generate_slices_from_partition(partition, parent_state)
with self._lock:
self._is_partition_discovery_complete = True

def _generate_slices_from_partition(
self, partition: StreamSlice, parent_state: Mapping[str, Any]
Expand Down Expand Up @@ -537,11 +555,11 @@ def observe(self, record: Record) -> None:
return

self._synced_some_data = True
partition_key = self._to_partition_key(record.associated_slice.partition)
self._partitions_observed.add(partition_key)
self._update_global_cursor(record_cursor)
if not self._use_global_cursor:
self._cursor_per_partition[
self._to_partition_key(record.associated_slice.partition)
].observe(record)
self._cursor_per_partition[partition_key].observe(record)

def _update_global_cursor(self, value: Any) -> None:
if (
Expand All @@ -566,6 +584,9 @@ def _cleanup_if_done(self, partition_key: str) -> None:

seq = self._partition_key_to_index.pop(partition_key)
self._processing_partitions_indexes.remove(seq)
# Ensure completed partitions are counted as observed (handles partitions with no records)
self._partitions_observed.add(partition_key)
self._num_partitions_completed += 1

logger.debug(f"Partition {partition_key} fully processed and cleaned up.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,20 @@
import requests_mock


def _strip_partitioned_stream_status(state_dict: dict) -> dict:
    """Remove ``partitioned_stream_status`` from a stream state dict, in place.

    The key is popped from ``state_dict`` itself and, recursively, from every
    dict found among the values of ``state_dict["parent_state"]``. Nothing
    else is traversed: dicts stored under other keys keep their
    ``partitioned_stream_status`` entry, and a ``parent_state`` that is not a
    dict is left untouched. Extend this helper if the emitted state shape
    gains more nesting layers.

    Args:
        state_dict: The state mapping to clean. It is mutated; pass a copy if
            the original must be preserved.

    Returns:
        The same ``state_dict`` object, to allow call chaining.
    """
    state_dict.pop("partitioned_stream_status", None)

    parent_states = state_dict.get("parent_state")
    if not isinstance(parent_states, dict):
        # Missing or non-dict parent_state: nothing nested to clean up.
        return state_dict

    for parent_state in parent_states.values():
        # Non-dict values under parent_state are skipped.
        if isinstance(parent_state, dict):
            _strip_partitioned_stream_status(parent_state)
    return state_dict


def run_mocked_test(
mock_requests,
manifest,
Expand Down Expand Up @@ -380,7 +394,25 @@ def run_mocked_test(

# Verify state
final_state = output.state_messages[-1].state.stream.stream_state
assert final_state.__dict__ == expected_state
final_state_dict = final_state.__dict__
# Validate partitioned_stream_status exists and has correct shape, then remove for comparison
partitioned_status = final_state_dict.get("partitioned_stream_status")
assert partitioned_status is not None, (
"partitioned_stream_status must always be present in state"
)
assert "num_partitions_in_progress" in partitioned_status
assert "num_partitions_completed" in partitioned_status
assert "num_partitions_expected" in partitioned_status
assert "is_partition_discovery_complete" in partitioned_status
assert partitioned_status["num_partitions_in_progress"] >= 0
assert partitioned_status["num_partitions_completed"] >= 0
assert (
partitioned_status["num_partitions_expected"]
>= partitioned_status["num_partitions_in_progress"]
+ partitioned_status["num_partitions_completed"]
)
_strip_partitioned_stream_status(final_state_dict)
assert final_state_dict == expected_state
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# Verify that each request was made exactly once
for url, _ in mock_requests:
Expand Down Expand Up @@ -1107,7 +1139,9 @@ def run_incremental_parent_state_test(
final_states = [] # To store the final state after each read

# Store the final state after the initial read
final_states.append(output.state_messages[-1].state.stream.stream_state.__dict__)
initial_final_state = output.state_messages[-1].state.stream.stream_state.__dict__.copy()
_strip_partitioned_stream_status(initial_final_state)
final_states.append(initial_final_state)

for message in output.records_and_state_messages:
if message.type.value == "RECORD":
Expand All @@ -1122,10 +1156,11 @@ def run_incremental_parent_state_test(
# Assert that the number of intermediate states is as expected
assert len(intermediate_states) - 1 == num_intermediate_states
# Assert that ensure_at_least_one_state_emitted is called before yielding the last record from the last slice
assert (
intermediate_states[-1][0].stream.stream_state.__dict__["parent_state"]
== intermediate_states[-2][0].stream.stream_state.__dict__["parent_state"]
)
last_state_dict = intermediate_states[-1][0].stream.stream_state.__dict__.copy()
_strip_partitioned_stream_status(last_state_dict)
prev_state_dict = intermediate_states[-2][0].stream.stream_state.__dict__.copy()
_strip_partitioned_stream_status(prev_state_dict)
assert last_state_dict["parent_state"] == prev_state_dict["parent_state"]

# For each intermediate state, perform another read starting from that state
for state, records_before_state in intermediate_states[:-1]:
Expand All @@ -1151,10 +1186,11 @@ def run_incremental_parent_state_test(
)

# Store the final state after each intermediate read
final_state_intermediate = [
message.state.stream.stream_state.__dict__
for message in output_intermediate.state_messages
]
final_state_intermediate = []
for message in output_intermediate.state_messages:
state_dict = message.state.stream.stream_state.__dict__.copy()
_strip_partitioned_stream_status(state_dict)
final_state_intermediate.append(state_dict)
final_states.append(final_state_intermediate[-1])

# Assert that the final state matches the expected state for all runs
Expand Down Expand Up @@ -3654,7 +3690,20 @@ def test_given_no_partitions_processed_when_close_partition_then_no_state_update
)
)

assert cursor.state == {
state = cursor.state
partitioned_status = state.pop("partitioned_stream_status", None)
assert partitioned_status is not None
assert partitioned_status["num_partitions_in_progress"] == 0
assert partitioned_status["num_partitions_completed"] == 0
assert partitioned_status["num_partitions_expected"] == 0
assert partitioned_status["is_partition_discovery_complete"] is True
# Invariant: in_progress + completed <= expected
assert (
partitioned_status["num_partitions_in_progress"]
+ partitioned_status["num_partitions_completed"]
<= partitioned_status["num_partitions_expected"]
)
assert state == {
"use_global_cursor": False,
"lookback_window": 0,
"states": [],
Expand Down Expand Up @@ -3742,6 +3791,13 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update():
cursor.ensure_at_least_one_state_emitted()

state = cursor.state
partitioned_status = state.pop("partitioned_stream_status", None)
assert partitioned_status is not None
# observe() not called in this test, so in_progress comes only from _cleanup_if_done adding to observed
assert partitioned_status["num_partitions_in_progress"] == 0
assert partitioned_status["num_partitions_completed"] == 1
assert partitioned_status["num_partitions_expected"] == 2
assert partitioned_status["is_partition_discovery_complete"] is True
assert state == {
"use_global_cursor": False,
"states": [
Expand Down Expand Up @@ -3838,6 +3894,13 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update
cursor.ensure_at_least_one_state_emitted()

state = cursor.state
partitioned_status = state.pop("partitioned_stream_status", None)
assert partitioned_status is not None
# observe() not called in this test, so in_progress comes only from _cleanup_if_done adding to observed
assert partitioned_status["num_partitions_in_progress"] == 0
assert partitioned_status["num_partitions_completed"] == 1
assert partitioned_status["num_partitions_expected"] == 2
assert partitioned_status["is_partition_discovery_complete"] is True
assert state == {
"use_global_cursor": False,
"states": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ def test_given_record_for_partition_when_read_then_update_state(caplog):
"cursor": {CURSOR_FIELD: "2022-01-01"},
},
],
"partitioned_stream_status": {
"num_partitions_in_progress": 0,
"num_partitions_completed": 2,
"num_partitions_expected": 2,
"is_partition_discovery_complete": True,
},
}


Expand Down Expand Up @@ -581,6 +587,12 @@ def test_perpartition_with_fallback(caplog):
"use_global_cursor": True,
"state": {"cursor_field": "2022-02-19"},
"lookback_window": 1,
"partitioned_stream_status": {
"num_partitions_in_progress": 0,
"num_partitions_completed": 6,
"num_partitions_expected": 6,
"is_partition_discovery_complete": True,
},
}


Expand Down Expand Up @@ -763,6 +775,12 @@ def test_per_partition_cursor_within_limit(caplog):
"cursor": {CURSOR_FIELD: "2022-03-29"},
},
],
"partitioned_stream_status": {
"num_partitions_in_progress": 0,
"num_partitions_completed": 3,
"num_partitions_expected": 3,
"is_partition_discovery_complete": True,
},
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,9 @@ def test_stream_with_incremental_and_async_retriever_with_partition_router(use_l
assert isinstance(retriever, AsyncRetriever)
stream_slicer = retriever.stream_slicer.stream_slicer
assert isinstance(stream_slicer, ConcurrentPerPartitionCursor)
assert stream_slicer.state == stream_state
actual_state = stream_slicer.state
actual_state.pop("partitioned_stream_status", None)
assert actual_state == stream_state
import json

cursor_perpartition = stream_slicer._cursor_per_partition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
InMemoryPartition,
)


def _strip_partitioned_stream_status(state_dict: dict) -> dict:
    """Recursively remove ``partitioned_stream_status`` from nested state dicts.

    The key is popped from ``state_dict`` and from every dict reachable
    through dict values at any depth. Dicts held inside lists (or other
    containers) are not visited.

    Args:
        state_dict: The state mapping to clean. It is mutated in place.

    Returns:
        The same ``state_dict`` object, to allow call chaining.
    """
    state_dict.pop("partitioned_stream_status", None)
    # The pop above happens before iteration, and recursion only mutates the
    # nested dicts, so iterating over the live values view is safe.
    for nested in state_dict.values():
        if isinstance(nested, dict):
            _strip_partitioned_stream_status(nested)
    return state_dict


parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}]
more_records = [
{"id": 10, "data": "data10", "slice": "second_parent"},
Expand Down Expand Up @@ -639,6 +649,7 @@ def test_substream_slicer_parent_state_update_with_cursor(parent_stream_config,

# Check if the parent state has been updated correctly
parent_state = partition_router.get_stream_state()
_strip_partitioned_stream_status(parent_state)
assert parent_state == expected_state


Expand Down
Loading