Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
slice_logger: SliceLogger,
message_repository: MessageRepository,
partition_reader: PartitionReader,
max_concurrent_partition_generators: Optional[int] = None,
):
"""
This class is responsible for handling items from a concurrent stream read process.
Expand All @@ -55,15 +56,29 @@ def __init__(
:param slice_logger: SliceLogger instance
:param message_repository: MessageRepository instance
:param partition_reader: PartitionReader instance
:param max_concurrent_partition_generators: Maximum number of partition generators allowed
to run concurrently. None means no limit. When set, should be less than the number of
workers in multi-worker mode so at least one worker slot is always available for
partition reading, preventing thread pool starvation. In single-threaded mode
(num_workers=1) the value may equal num_workers; ConcurrentSource.create() handles
this distinction. ConcurrentSource.read() passes this value explicitly.
"""
self._stream_name_to_instance = {s.name: s for s in stream_instances_to_read_from}
self._record_counter = {}
self._streams_to_running_partitions: Dict[str, Set[Partition]] = {}
for stream in stream_instances_to_read_from:
self._streams_to_running_partitions[stream.name] = set()
self._record_counter[stream.name] = 0
if (
max_concurrent_partition_generators is not None
and max_concurrent_partition_generators < 1
):
raise ValueError(
f"max_concurrent_partition_generators must be >= 1 or None, got {max_concurrent_partition_generators}"
)
self._thread_pool_manager = thread_pool_manager
self._partition_enqueuer = partition_enqueuer
self._max_concurrent_partition_generators = max_concurrent_partition_generators
self._stream_instances_to_start_partition_generation = stream_instances_to_read_from
self._streams_currently_generating_partitions: List[str] = []
self._logger = logger
Expand Down Expand Up @@ -255,6 +270,20 @@ def start_next_partition_generator(self) -> Optional[AirbyteMessage]:
if not self._stream_instances_to_start_partition_generation:
return None

# Enforce the concurrent generator cap so at least one worker slot is always available
# for partition reading. Recovery is guaranteed: on_partition_generation_completed
# decrements the count before calling here, so the guard always passes there.
if (
self._max_concurrent_partition_generators is not None
and len(self._streams_currently_generating_partitions)
>= self._max_concurrent_partition_generators
):
self._logger.debug(
f"Concurrent partition generator cap ({self._max_concurrent_partition_generators}) reached "
f"({len(self._streams_currently_generating_partitions)} active). Deferring next generator start."
)
return None

# Remember initial queue size to avoid infinite loops if all streams are blocked
max_attempts = len(self._stream_instances_to_start_partition_generation)
attempts = 0
Expand Down
5 changes: 5 additions & 0 deletions airbyte_cdk/sources/concurrent_source/concurrent_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def create(
queue: Optional[Queue[QueueItem]] = None,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> "ConcurrentSource":
if initial_number_of_partitions_to_generate < 1:
raise ValueError(
f"initial_number_of_partitions_to_generate must be >= 1, got {initial_number_of_partitions_to_generate}"
)
is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
too_many_generator = (
not is_single_threaded and initial_number_of_partitions_to_generate >= num_workers
Expand Down Expand Up @@ -117,6 +121,7 @@ def read(
self._queue,
PartitionLogger(self._slice_logger, self._logger, self._message_repository),
),
max_concurrent_partition_generators=self._initial_number_partitions_to_generate,
)

# Enqueue initial partition generation tasks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,71 @@ def test_start_next_partition_generator(self):
self._partition_enqueuer.generate_partitions, self._stream
)

def test_invalid_max_concurrent_partition_generators_raises(self):
    """max_concurrent_partition_generators must be >= 1 (or None); 0 and negatives raise ValueError."""
    for invalid in (0, -1):
        # subTest isolates each invalid value so a failure on one does not mask the other.
        with self.subTest(invalid=invalid):
            with self.assertRaises(ValueError):
                ConcurrentReadProcessor(
                    [self._stream],
                    self._partition_enqueuer,
                    self._thread_pool_manager,
                    self._logger,
                    self._slice_logger,
                    self._message_repository,
                    self._partition_reader,
                    max_concurrent_partition_generators=invalid,
                )

def test_start_next_partition_generator_respects_concurrent_limit(self):
    """When the generator cap is already saturated, nothing new is started.

    Expected outcome: no status message is emitted, the pending stream stays
    queued for later, and nothing is submitted to the thread pool.
    """
    streams = [self._stream]
    processor = ConcurrentReadProcessor(
        streams,
        self._partition_enqueuer,
        self._thread_pool_manager,
        self._logger,
        self._slice_logger,
        self._message_repository,
        self._partition_reader,
        max_concurrent_partition_generators=1,
    )
    # Simulate one generator already in flight, which saturates the cap of 1.
    processor._streams_currently_generating_partitions.append(_STREAM_NAME)

    result = processor.start_next_partition_generator()

    assert result is None
    assert processor._stream_instances_to_start_partition_generation == streams
    self._thread_pool_manager.submit.assert_not_called()

def test_start_next_partition_generator_starts_when_below_limit(self):
    """With a free slot under the cap, the next pending stream's generator starts.

    One generator is in flight and the cap is 2, so the processor should emit a
    status message, mark the new stream as generating, and submit its partition
    generation to the thread pool.
    """
    pending_stream = Mock(spec=AbstractStream)
    pending_stream.name = "other_stream"
    pending_stream.block_simultaneous_read = ""
    pending_stream.as_airbyte_stream.return_value = AirbyteStream(
        name="other_stream",
        json_schema={},
        supported_sync_modes=[SyncMode.full_refresh],
    )
    processor = ConcurrentReadProcessor(
        [pending_stream],
        self._partition_enqueuer,
        self._thread_pool_manager,
        self._logger,
        self._slice_logger,
        self._message_repository,
        self._partition_reader,
        max_concurrent_partition_generators=2,
    )
    # One generator already active — still one slot free under the cap of 2.
    processor._streams_currently_generating_partitions.append(_STREAM_NAME)

    status_message = processor.start_next_partition_generator()

    assert status_message is not None
    assert "other_stream" in processor._streams_currently_generating_partitions
    self._thread_pool_manager.submit.assert_called_with(
        self._partition_enqueuer.generate_partitions, pending_stream
    )


class TestBlockSimultaneousRead(unittest.TestCase):
"""Tests for block_simultaneous_read functionality"""
Expand Down
Loading