Skip to content
3 changes: 2 additions & 1 deletion mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,8 @@ async def post_processing(
mot._meta["hf_output"] = full_output

# The ModelOutputThunk must be computed by this point.
assert mot.value is not None
if mot.value is None:
return

# Store KV cache in LRU separately (not in mot._meta) to enable proper cleanup on eviction.
# This prevents GPU memory from being held by ModelOutputThunk references.
Expand Down
9 changes: 6 additions & 3 deletions mellea/backends/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,11 @@ async def post_processing(
# OpenAI-like streamed responses potentially give you chunks of tool calls.
# As a result, we have to store data between calls and only then
# check for complete tool calls in the post_processing step.
# _meta may lack "litellm_chat_response" if the stream errored before a
# response was recorded; guard the extraction rather than KeyError-ing.
litellm_response = mot._meta.get("litellm_chat_response")
tool_chunk = (
    extract_model_tool_requests(tools, litellm_response)
    if litellm_response is not None
    else None
)
if tool_chunk is not None:
if mot.tool_calls is None:
Expand All @@ -457,7 +460,7 @@ async def post_processing(
generate_log.backend = f"litellm::{self.model_id!s}"
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
# .get(): the key may be absent after a mid-stream failure.
generate_log.model_output = mot._meta.get("litellm_chat_response")
generate_log.extra = {
"format": _format,
"tools_available": tools,
Expand Down
2 changes: 1 addition & 1 deletion mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ async def post_processing(
generate_log.backend = f"ollama::{self._get_ollama_model_id()}"
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
# .get(): the key may be absent after a mid-stream failure.
generate_log.model_output = mot._meta.get("chat_response")
generate_log.extra = {
"format": _format,
"thinking": mot._model_options.get(ModelOption.THINKING, None),
Expand Down
23 changes: 14 additions & 9 deletions mellea/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,13 @@ async def post_processing(
# check for complete tool calls in the post_processing step.
# Use the choice format for tool extraction (backward compatibility);
# both _meta keys may be absent if the stream errored before a response
# was recorded, so fall back via .get and guard the extraction.
choice_response = mot._meta.get(
    "oai_chat_response_choice", mot._meta.get("oai_chat_response")
)
tool_chunk = (
    extract_model_tool_requests(tools, choice_response)
    if choice_response is not None
    else None
)
if tool_chunk is not None:
if mot.tool_calls is None:
mot.tool_calls = {}
Expand All @@ -592,7 +596,7 @@ async def post_processing(
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
# Store the full response (includes usage info)
# .get(): the key may be absent after a mid-stream failure.
generate_log.model_output = mot._meta.get("oai_chat_response")
generate_log.extra = {
"format": _format,
"thinking": thinking,
Expand All @@ -613,12 +617,13 @@ async def post_processing(
record_token_usage,
)

# The response may be missing entirely if streaming failed early; only
# record telemetry when we actually have one.
response = mot._meta.get("oai_chat_response")
if response is not None:
    # response is a dict from model_dump(), extract usage if present
    usage = response.get("usage") if isinstance(response, dict) else None
    if usage:
        record_token_usage(span, usage)
    record_response_metadata(span, response)
# Close the span now that async operation is complete
end_backend_span(span)
# Clean up the span reference
Expand Down
3 changes: 2 additions & 1 deletion mellea/backends/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ async def post_processing(
):
"""Called when generation is done."""
# The ModelOutputThunk must be computed by this point.
assert mot.value is not None
if mot.value is None:
return

# Only scan for tools if we are not doing structured output and tool calls were provided to the model.
if _format is None and tool_calls:
Expand Down
9 changes: 7 additions & 2 deletions mellea/backends/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,12 @@ async def post_processing(
# OpenAI streamed responses give you chunks of tool calls.
# As a result, we have to store data between calls and only then
# check for complete tool calls in the post_processing step.
# _meta may lack "oai_chat_response" if the stream errored before a
# response was recorded; guard the extraction rather than KeyError-ing.
oai_response = mot._meta.get("oai_chat_response")
tool_chunk = (
    extract_model_tool_requests(tools, oai_response)
    if oai_response is not None
    else None
)
if tool_chunk is not None:
if mot.tool_calls is None:
mot.tool_calls = {}
Expand Down Expand Up @@ -509,7 +514,7 @@ async def post_processing(
generate_log.backend = f"watsonx::{self.model_id!s}"
generate_log.model_options = mot._model_options
generate_log.date = datetime.datetime.now()
# .get(): the key may be absent after a mid-stream failure.
generate_log.model_output = mot._meta.get("oai_chat_response")
generate_log.extra = {
"format": _format,
"tools_available": tools,
Expand Down
8 changes: 6 additions & 2 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,14 @@ async def astream(self) -> str:
elif isinstance(chunks[-1], Exception):
# Mark as computed so post_process runs in finally block
self._computed = True
# Pop the exception off chunks so _process never receives it; it is
# re-raised after cleanup below.
exception_to_raise = chunks.pop()

for chunk in chunks:
# Belt-and-suspenders: skip non-chunk objects that should
# have been removed above (exceptions, sentinel None).
if chunk is None or isinstance(chunk, Exception):
continue
assert self._process is not None
await self._process(self, chunk)

Expand Down
94 changes: 94 additions & 0 deletions test/core/test_astream_exception_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Regression tests for astream() exception handling.

When a backend error occurs during streaming, the Exception object lands in the
async queue. Before the fix, astream() would either pass it to _process (crash)
or post_processing would hit a KeyError on _meta keys that were never set.

These tests verify that astream() cleanly propagates the original exception
after running _post_process for telemetry cleanup.
"""

import asyncio

import pytest

from mellea.core.base import CBlock, GenerateType, ModelOutputThunk


def _make_streaming_mot():
    """Build a ModelOutputThunk configured for async streaming with stub callbacks.

    Returns a (thunk, recorded_chunks, cleanup_ran) triple: recorded_chunks
    collects every chunk handed to the _process callback, and cleanup_ran is
    an asyncio.Event set once the _post_process callback fires.
    """
    thunk = ModelOutputThunk(value=None)
    thunk._generate_type = GenerateType.ASYNC
    thunk._chunk_size = 1

    recorded_chunks: list = []
    cleanup_ran = asyncio.Event()

    async def _process(t, chunk):
        recorded_chunks.append(chunk)
        text = chunk if isinstance(chunk, str) else str(chunk)
        current = t._underlying_value
        t._underlying_value = text if current is None else current + text

    async def _post_process(t):
        cleanup_ran.set()

    thunk._process = _process
    thunk._post_process = _post_process
    return thunk, recorded_chunks, cleanup_ran


@pytest.mark.asyncio
async def test_astream_propagates_exception_from_queue():
    """Exception in the queue is re-raised after cleanup, not passed to _process.

    The explicit marker ensures the coroutine actually runs under
    pytest-asyncio's default strict mode (unmarked async tests are skipped
    with a warning); it is a harmless no-op if the project configures
    asyncio_mode=auto -- NOTE(review): confirm against the pytest config.
    """
    mot, process_calls, post_process_called = _make_streaming_mot()

    original_error = RuntimeError("backend connection lost")
    await mot._async_queue.put(original_error)

    with pytest.raises(RuntimeError, match="backend connection lost"):
        await mot.astream()

    # _process must never have seen the Exception object
    assert original_error not in process_calls
    # _post_process ran for telemetry cleanup
    assert post_process_called.is_set()


@pytest.mark.asyncio
async def test_astream_propagates_exception_after_valid_chunks():
    """Valid chunks before the exception are processed; exception still raised.

    Marked for pytest-asyncio's default strict mode (unmarked async tests
    are skipped); a harmless no-op if asyncio_mode=auto is configured.
    """
    mot, process_calls, post_process_called = _make_streaming_mot()

    await mot._async_queue.put("hello ")
    await mot._async_queue.put("world")
    await mot._async_queue.put(ValueError("mid-stream failure"))

    with pytest.raises(ValueError, match="mid-stream failure"):
        await mot.astream()

    # Valid chunks were processed, in order
    assert process_calls == ["hello ", "world"]
    # Accumulated value reflects only the valid chunks
    assert mot._underlying_value == "hello world"
    # Cleanup still ran
    assert post_process_called.is_set()


@pytest.mark.asyncio
async def test_astream_skips_none_and_exception_in_chunk_loop():
    """Belt-and-suspenders: stray None/Exception objects in the middle of the
    chunk list are skipped rather than passed to _process.

    Marked for pytest-asyncio's default strict mode (unmarked async tests
    are skipped); a harmless no-op if asyncio_mode=auto is configured.
    """
    mot, process_calls, _ = _make_streaming_mot()

    await mot._async_queue.put("good chunk")
    await mot._async_queue.put(None)

    mot._action = CBlock("test")

    result = await mot.astream()

    assert process_calls == ["good chunk"]
    assert mot.is_computed()
    assert result is not None