Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions src/confluent_kafka/src/Consumer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,8 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) {
PyObject *msglist;
rd_kafka_queue_t *rkqu = self->u.Consumer.rkqu;
CallState cs;
Py_ssize_t i, n = 0;
Py_ssize_t i, msgs_received_count = 0;
Py_ssize_t chunk_msg_count;
const int CHUNK_TIMEOUT_MS = 200; /* 200ms chunks for signal checking */
int total_timeout_ms;
int chunk_timeout_ms;
Expand Down Expand Up @@ -1156,16 +1157,15 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) {
* ThreadPool. Only use wakeable poll for
* blocking calls that need to be interruptible. */
if (total_timeout_ms >= 0 && total_timeout_ms < CHUNK_TIMEOUT_MS) {
n = (Py_ssize_t)rd_kafka_consume_batch_queue(
msgs_received_count = (Py_ssize_t)rd_kafka_consume_batch_queue(
rkqu, total_timeout_ms, rkmessages, num_messages);

if (n < 0) {
/* Error - need to restore GIL before setting error */
PyEval_RestoreThread(cs.thread_state);
if (msgs_received_count < 0) {
if (CallState_end(self, &cs))
cfl_PyErr_Format(
rd_kafka_last_error(), "%s",
rd_kafka_err2str(rd_kafka_last_error()));
free(rkmessages);
cfl_PyErr_Format(
rd_kafka_last_error(), "%s",
rd_kafka_err2str(rd_kafka_last_error()));
return NULL;
}
} else {
Expand All @@ -1178,30 +1178,40 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) {
break;
}

/* Consume with chunk timeout */
n = (Py_ssize_t)rd_kafka_consume_batch_queue(
rkqu, chunk_timeout_ms, rkmessages, num_messages);

if (n < 0) {
/* Error - need to restore GIL before setting
* error */
PyEval_RestoreThread(cs.thread_state);
/* Consume with chunk timeout, appending after
* already-accumulated messages */
chunk_msg_count =
(Py_ssize_t)rd_kafka_consume_batch_queue(
rkqu, chunk_timeout_ms,
rkmessages + msgs_received_count,
num_messages -
(unsigned int)msgs_received_count);

if (chunk_msg_count < 0) {
for (i = 0; i < msgs_received_count; i++)
rd_kafka_message_destroy(rkmessages[i]);
if (CallState_end(self, &cs))
cfl_PyErr_Format(
rd_kafka_last_error(), "%s",
rd_kafka_err2str(
rd_kafka_last_error()));
free(rkmessages);
cfl_PyErr_Format(
rd_kafka_last_error(), "%s",
rd_kafka_err2str(rd_kafka_last_error()));
return NULL;
Comment on lines 1190 to 1199
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the chunked-loop error path (chunk_msg_count < 0), the code destroys accumulated messages and restores the GIL but returns without calling CallState_end() (or otherwise clearing the thread-local CallState). This can leave per-thread state behind. Ensure the CallState is properly ended/cleaned up before returning from this error path.

Copilot uses AI. Check for mistakes.
}

/* If we got messages, exit the loop */
if (n > 0) {
msgs_received_count += chunk_msg_count;

/* If we got all requested messages, exit the loop */
if (msgs_received_count >= (Py_ssize_t)num_messages) {
break;
}
Comment on lines 1181 to 1207
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When num_messages==0, the chunked loop still performs at least one rd_kafka_consume_batch_queue() call with a 200ms chunk timeout before exiting (because the >= num_messages check happens after the call). Since the API allows num_messages=0, this should return an empty list immediately (without chunking/allocating). Consider an early fast-path for num_messages==0 before entering CallState_begin()/the chunk loop.

Copilot uses AI. Check for mistakes.

chunk_count++;

/* Check for signals between chunks */
if (check_signals_between_chunks(self, &cs)) {
for (i = 0; i < msgs_received_count; i++)
rd_kafka_message_destroy(rkmessages[i]);
free(rkmessages);
return NULL;
}
Expand All @@ -1210,17 +1220,17 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) {

/* Final GIL restore and signal check */
if (!CallState_end(self, &cs)) {
for (i = 0; i < n; i++) {
for (i = 0; i < msgs_received_count; i++) {
rd_kafka_message_destroy(rkmessages[i]);
}
free(rkmessages);
return NULL;
}

/* Create Python list from messages */
msglist = PyList_New(n);
msglist = PyList_New(msgs_received_count);

for (i = 0; i < n; i++) {
for (i = 0; i < msgs_received_count; i++) {
PyObject *msgobj = Message_new0(self, rkmessages[i]);
#ifdef RD_KAFKA_V_HEADERS
/** Have to detach headers outside Message_new0 because it
Expand Down
5 changes: 4 additions & 1 deletion src/confluent_kafka/src/Producer.c
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,10 @@ static int Producer_poll0(Handle *self, int tmout) {
r = chunk_result;
break;
}
r += chunk_result; /* Accumulate events processed */
r += chunk_result;

if (chunk_result > 0)
break;

chunk_count++;

Expand Down
138 changes: 138 additions & 0 deletions tests/integration/consumer/test_consumer_wakeable_poll_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,141 @@ def test_consume_message_delivery_with_wakeable_pattern(kafka_cluster):
assert msg.value() == expected_value, expected_msg

consumer.close()


def test_consume_accumulates_messages_across_chunks(kafka_cluster):
    """Verify consume() keeps gathering messages across its internal 200ms
    chunks instead of returning after only the first chunk when the
    wakeable-poll pattern is active."""
    topic = kafka_cluster.create_topic_and_wait_propogation('test-consume-accumulate-chunks')

    expected_count = 10

    # Publish the whole batch up front so everything is available to consume.
    producer = kafka_cluster.cimpl_producer()
    for idx in range(expected_count):
        producer.produce(topic, value=f'msg-{idx}'.encode())
    producer.flush(timeout=5.0)

    # Build a consumer subscribed to the freshly created topic.
    conf = kafka_cluster.client_conf(
        {
            'group.id': 'test-consume-accumulate',
            'socket.timeout.ms': 100,
            'session.timeout.ms': 6000,
            'auto.offset.reset': 'earliest',
        }
    )
    consumer = TestConsumer(conf)
    consumer.subscribe([topic])

    # Allow time for subscription and partition assignment to settle.
    time.sleep(2.0)

    # Ask for all 10 messages with a generous timeout.
    # Before the fix: would return < 10 (whatever arrived in the first 200ms chunk)
    # After the fix: accumulates across chunks until 10 are collected
    msglist = consumer.consume(num_messages=expected_count, timeout=10.0)

    assert len(msglist) == expected_count, (
        f"Expected {expected_count} messages but got {len(msglist)}. " f"consume() may not be accumulating across chunks."
    )

    for idx, msg in enumerate(msglist):
        assert not msg.error(), f"Message {idx} has error: {msg.error()}"

    consumer.close()


def test_consume_returns_partial_on_timeout(kafka_cluster):
    """Verify consume() hands back whatever it has gathered once the
    timeout expires, even when fewer than num_messages arrived."""
    topic = kafka_cluster.create_topic_and_wait_propogation('test-consume-partial-timeout')

    available = 3

    # Only three messages exist on the topic — far fewer than requested below.
    producer = kafka_cluster.cimpl_producer()
    for idx in range(available):
        producer.produce(topic, value=f'partial-{idx}'.encode())
    producer.flush(timeout=5.0)

    conf = kafka_cluster.client_conf(
        {
            'group.id': 'test-consume-partial',
            'socket.timeout.ms': 100,
            'session.timeout.ms': 6000,
            'auto.offset.reset': 'earliest',
        }
    )
    consumer = TestConsumer(conf)
    consumer.subscribe([topic])

    time.sleep(2.0)

    # Request far more than exists; expect a partial batch once time is up.
    begin = time.time()
    msglist = consumer.consume(num_messages=100, timeout=3.0)
    elapsed = time.time() - begin

    assert len(msglist) == available, f"Expected {available} messages (partial), got {len(msglist)}"
    # The call should block close to the full timeout waiting for more input.
    assert elapsed >= 2.0, f"Should wait near full timeout for more messages, but returned in {elapsed:.2f}s"

    for idx, msg in enumerate(msglist):
        assert not msg.error(), f"Message {idx} has error: {msg.error()}"
        assert msg.value() == f'partial-{idx}'.encode()

    consumer.close()


def test_consume_accumulates_messages_produced_in_waves(kafka_cluster):
    """Verify consume() keeps collecting messages that arrive in separate
    waves, rather than returning after only the first wave when the
    wakeable-poll pattern is active."""
    import threading

    topic = kafka_cluster.create_topic_and_wait_propogation('test-consume-waves')

    producer = kafka_cluster.cimpl_producer()

    def produce_in_waves():
        """Emit 3 waves of 4 messages each, pausing 1s before each wave."""
        for wave in range(3):
            time.sleep(1.0)
            for msg_num in range(wave * 4, wave * 4 + 4):
                producer.produce(topic, value=f'wave-{msg_num}'.encode())
            producer.flush(timeout=5.0)

    conf = kafka_cluster.client_conf(
        {
            'group.id': 'test-consume-waves',
            'socket.timeout.ms': 100,
            'session.timeout.ms': 6000,
            'auto.offset.reset': 'earliest',
        }
    )
    consumer = TestConsumer(conf)
    consumer.subscribe([topic])

    time.sleep(2.0)

    # Feed the topic from a background thread while we consume.
    producer_thread = threading.Thread(target=produce_in_waves, daemon=True)
    producer_thread.start()

    # The waves take ~3s to finish, so use a long timeout for all 10 messages.
    msglist = consumer.consume(num_messages=10, timeout=10.0)

    # All waves should have been accumulated into a single result.
    assert len(msglist) == 10, (
        f"Expected exactly 10 messages accumulated across waves, got {len(msglist)}. "
        f"consume() may be returning early on the first wave."
    )

    for msg in msglist:
        assert not msg.error(), f"Message has error: {msg.error()}"

    producer_thread.join(timeout=5.0)
    consumer.close()
35 changes: 35 additions & 0 deletions tests/integration/producer/test_producer_wakeable_poll_flush.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,41 @@ def delivery_callback(err, msg):
consumer.close()


def test_poll_returns_early_after_delivery_callback(kafka_cluster):
    """Verify poll() unblocks as soon as a delivery report is served,
    instead of sleeping for its full timeout."""
    topic = kafka_cluster.create_topic_and_wait_propogation('test-poll-early-return')

    callback_times = []

    def on_delivery(err, msg):
        # Record when the delivery report fired so the test can assert on it.
        callback_times.append(time.time())

    conf = kafka_cluster.client_conf(
        {
            'socket.timeout.ms': 100,
            'message.timeout.ms': 10000,
        }
    )
    producer = kafka_cluster.cimpl_producer(conf)

    producer.produce(topic, value=b'early-return-test', on_delivery=on_delivery)

    # Poll with a long timeout — should return early once the callback fires.
    poll_timeout = 5.0
    begin = time.time()
    events = producer.poll(timeout=poll_timeout)
    elapsed = time.time() - begin

    assert len(callback_times) == 1, "Expected delivery callback to fire"
    assert events > 0, "Expected at least 1 event served"
    assert elapsed < poll_timeout - 1.0, (
        f"poll({poll_timeout}) took {elapsed:.2f}s — should have returned "
        f"early after delivery callback, not blocked for full timeout"
    )

    producer.close()


def test_flush_message_delivery_with_wakeable_pattern(kafka_cluster):
"""Test that flush() correctly delivers messages when using wakeable pattern.

Expand Down
Loading