Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions packages/nvidia_nat_core/src/nat/data_models/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import abc
import datetime
import json
import typing
import uuid
from abc import abstractmethod
Expand Down Expand Up @@ -520,15 +521,52 @@ def get_stream_data(self) -> str:


class ResponsePayloadOutput(BaseModel, ResponseSerializable):
"""SSE wrapper for the workflow's output on the ``/generate/full`` endpoint.

payload: typing.Any
The wire contract for ``/generate/full`` is::

def get_stream_data(self) -> str:
data: {"value": "<answer string>"}

if (isinstance(self.payload, BaseModel)):
return f"data: {self.payload.model_dump_json()}\n\n"
paired with ``intermediate_data: ...`` lines for step events. Consumers
(notably ``nat.plugins.eval.runtime.remote_workflow``) rely on the
``value`` key to extract the workflow's final answer; see the test
fixture in ``packages/nvidia_nat_eval/tests/eval/test_remote_evaluate.py``
which simulates the server emitting exactly this shape.

``payload`` accepts whatever the workflow yields from its single-output
function or stream chunks: a primitive (``str``, ``int``, ...), a
Pydantic ``BaseModel`` (e.g. ``ChatResponseChunk`` for ReAct-style
workflows that opt into token streaming), or any JSON-serializable
object. ``get_stream_data`` normalizes that to a string answer and
emits the canonical ``{"value": ...}`` envelope so the wire shape stays
stable across workflow types.
"""

return f"data: {self.payload}\n\n"
payload: typing.Any

def get_stream_data(self) -> str:
payload = self.payload

if isinstance(payload, ChatResponseChunk):
content = payload.choices[0].delta.content or "" if payload.choices else ""
return f"data: {json.dumps({'value': content})}\n\n"

if isinstance(payload, ChatResponse):
content = payload.choices[0].message.content or "" if payload.choices else ""
return f"data: {json.dumps({'value': content})}\n\n"

if isinstance(payload, BaseModel):
# If the payload already exposes a ``value`` field it is already in
# the canonical envelope shape (e.g. NAT's auto-derived
# ``OutputArgsSchema`` from ``DecomposedType.get_pydantic_schema``,
# or any user model that opts into the wire contract). Emit the
# model's dump as-is so we don't double-wrap into
# ``{"value": "{\"value\": ...}"}``.
if "value" in type(payload).model_fields:
return f"data: {payload.model_dump_json()}\n\n"
return f"data: {json.dumps({'value': payload.model_dump_json()})}\n\n"

return f"data: {json.dumps({'value': str(payload)})}\n\n"


class ResponseATIFStep(BaseModel, ResponseSerializable):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,105 @@ async def test_chat_response_chunk_to_websocket_message():
assert isinstance(nat_chat_repsonse_chunk_to_system_response, WebSocketSystemResponseTokenMessage)


# -- ResponsePayloadOutput.get_stream_data() wire-format contract -------------
#
# Pins the SSE shape ``/generate/full`` emits. The eval client at
# ``nat.plugins.eval.runtime.remote_workflow`` extracts the final answer via
# ``chunk_data.get("value")``, so every payload type wrapped in
# ``ResponsePayloadOutput`` must serialize to ``data: {"value": <str>}\n\n``.
# This regressed when ``react_agent`` gained a ``_stream_fn`` (PR #1851):
# ``ChatResponseChunk`` payloads were emitted as raw OpenAI envelopes with no
# top-level ``value`` field, so the eval client extracted ``None`` and every
# eval run scored zero.


@pytest.mark.parametrize(
"payload, expected_value",
[
pytest.param("21", "21", id="str"),
pytest.param(21, "21", id="int-stringified"),
pytest.param(ChatResponseChunk.create_streaming_chunk("21"), "21", id="chat-response-chunk"),
pytest.param(
ChatResponseChunk.create_streaming_chunk("", finish_reason="stop"),
"",
id="chat-response-chunk-keepalive",
),
pytest.param(
ChatResponse(id="x",
object="chat.completion",
created=datetime.datetime.now(datetime.UTC),
choices=[
ChatResponseChoice(
index=0, message=ChoiceMessage(content="42", role="assistant"), finish_reason="stop")
],
usage=Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)),
"42",
id="chat-response",
),
],
)
def test_response_payload_output_emits_value_envelope(payload, expected_value):
"""Each supported payload type serializes to ``data: {"value": <str>}\\n\\n``.

Covers strings, primitives (stringified), ``ChatResponseChunk`` (the
streaming-agent regression case introduced by PR #1851, including
role-only/finish-only keepalive chunks), and ``ChatResponse``.
"""
sse = ResponsePayloadOutput(payload=payload).get_stream_data()
assert sse.startswith("data: ") and sse.endswith("\n\n")
assert json.loads(sse[len("data: "):-2]) == {"value": expected_value}


def test_response_payload_output_other_basemodel_serializes_to_json_string():
"""Arbitrary ``BaseModel`` payloads are JSON-encoded into ``value`` so the
wire shape stays uniformly ``{"value": <str>}`` while preserving the full
payload — consumers that need the structured form ``json.loads(value)``.
"""

class Custom(BaseModel):
answer: int
explanation: str

sse = ResponsePayloadOutput(payload=Custom(answer=21, explanation="3*7")).get_stream_data()
decoded = json.loads(sse[len("data: "):-2])
assert set(decoded.keys()) == {"value"}
assert json.loads(decoded["value"]) == {"answer": 21, "explanation": "3*7"}


def test_response_payload_output_basemodel_with_value_field_passes_through():
"""``BaseModel`` payloads that already expose a ``value`` field are emitted
as-is rather than re-wrapped. Covers NAT's auto-derived
``OutputArgsSchema`` (synthesized from ``str``/scalar workflow returns)
and any user model that already matches the canonical envelope shape —
re-wrapping would produce ``{"value": "{\\"value\\": ...}"}`` and break
SSE consumers reading ``sse.json()["value"]``.
"""

class AlreadyCanonical(BaseModel):
value: str

sse = ResponsePayloadOutput(payload=AlreadyCanonical(value="a")).get_stream_data()
decoded = json.loads(sse[len("data: "):-2])
assert decoded == {"value": "a"}


def test_response_payload_output_basemodel_with_value_and_extras_preserves_extras():
"""A ``BaseModel`` whose top-level shape includes ``value`` plus other
fields (e.g. a workflow that returns ``{"value": "21", "tokens": 42}``)
is forwarded unchanged. The eval client still extracts ``value`` correctly
via ``chunk_data.get("value")``, and richer consumers can read sibling
fields.
"""

class Rich(BaseModel):
value: str
tokens: int

sse = ResponsePayloadOutput(payload=Rich(value="21", tokens=42)).get_stream_data()
decoded = json.loads(sse[len("data: "):-2])
assert decoded == {"value": "21", "tokens": 42}


async def test_nat_intermediate_step_to_websocket_message():
"""Tests ResponseIntermediateStep can be converted to a WebSocketSystemIntermediateStepMessage"""
message_validator = MessageValidator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ async def run_workflow_remote_single(self, session: aiohttp.ClientSession, item:

response.raise_for_status()

final_response: str | None = None
# Workflows that opt into token streaming (e.g. ``react_agent``'s
# ``_stream_fn`` after PR #1851) emit one ``data: {"value": "<token>"}``
# line per chunk. Accumulating into a list and joining at the end
# reconstructs the full answer regardless of how the producer
# chunks the response — a single ``data: {"value": "<full>"}`` line
# (the NAT 1.6 / single-fn workflow shape) still works correctly
# because the list contains exactly one element.
response_chunks: list[str] = []
intermediate_steps: list[IntermediateStep] = []

async for line in response.content:
Expand All @@ -76,8 +83,9 @@ async def run_workflow_remote_single(self, session: aiohttp.ClientSession, item:
if line.startswith(DATA_PREFIX):
try:
chunk_data: dict = json.loads(line[len(DATA_PREFIX):])
if chunk_data.get("value"):
final_response = chunk_data.get("value")
value = chunk_data.get("value")
if value is not None:
response_chunks.append(value)
except json.JSONDecodeError:
logger.exception("Failed to parse generate response chunk")
continue
Expand All @@ -99,7 +107,7 @@ async def run_workflow_remote_single(self, session: aiohttp.ClientSession, item:
logger.exception("Failed to parse intermediate step")
continue

item.output_obj = final_response
item.output_obj = "".join(response_chunks) if response_chunks else None
item.trajectory = intermediate_steps
return

Expand Down
121 changes: 121 additions & 0 deletions packages/nvidia_nat_eval/tests/eval/test_remote_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from aiohttp.test_utils import TestClient
from aiohttp.test_utils import TestServer

from nat.data_models.api_server import ChatResponseChunk
from nat.data_models.api_server import ResponseIntermediateStep
from nat.data_models.api_server import ResponsePayloadOutput
from nat.data_models.evaluate_runtime import EndpointRetryConfig
from nat.data_models.evaluate_runtime import EvaluationRunConfig
from nat.plugins.eval.runtime.remote_workflow import EvaluationRemoteWorkflowHandler
Expand Down Expand Up @@ -414,3 +416,122 @@ async def test_max_retries_lower_bound_validation(invalid_value: int) -> None:
error = exc_info.value.errors()[0]
assert error["type"] == "greater_than_equal"
assert error["loc"] == ("max_retries", )


# -- Producer/consumer wire-format compatibility ------------------------------
#
# Two paired fixes that ride together so ``nat eval --endpoint`` produces
# correct output regardless of how a workflow chunks its response:
#
# 1. **Producer side** (``ResponsePayloadOutput.get_stream_data()``) emits
# ``data: {"value": "<str>"}\n\n`` for every payload type — the canonical
# shape ``test_remote_evaluate.py``'s server fixture has always documented.
# Before this fix, ``ChatResponseChunk`` payloads (yielded by
# ``react_agent``'s ``_stream_fn`` after PR #1851) were emitted as the
# chunk's full OpenAI envelope with no top-level ``value`` field, so the
# eval client extracted ``None`` for every line.
#
# 2. **Consumer side** (``EvaluationRemoteWorkflowHandler``) accumulates
# ``value`` chunks into a list and joins at the end instead of overwriting
# on each chunk. Pre-fix, the handler kept only the *last* ``value``,
# which silently truncated multi-chunk responses (e.g. NAT 1.7's
# per-token ``react_agent`` streams) to the final fragment.
#
# Together these restore NAT 1.6 end-to-end behavior for both single-chunk
# (``FunctionInfo.from_fn``-derived) and multi-chunk (token-streaming)
# workflows.


@pytest.mark.parametrize("payload_factory",
[
pytest.param(lambda answer: answer, id="str-payload"),
pytest.param(ChatResponseChunk.create_streaming_chunk, id="chat-response-chunk-payload"),
])
async def test_remote_eval_consumes_response_payload_output(rag_eval_input, payload_factory):
"""Production ``ResponsePayloadOutput.get_stream_data()`` lines round-trip
through the eval client. Covers bare-string and ``ChatResponseChunk``
payloads (the streaming-agent regression path from PR #1851).
"""
item = rag_eval_input.eval_input_items[0]
expected_answer = "the answer is 21"

async def stream_response(request):
resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
await resp.prepare(request)
sse_line = ResponsePayloadOutput(payload=payload_factory(expected_answer)).get_stream_data()
await resp.write(sse_line.encode("utf-8"))
await resp.write_eof()
return resp

app = web.Application()
app.router.add_post("/generate/full", stream_response)
server = TestServer(app)
await server.start_server()
client = TestClient(server)
await client.start_server()

handler = EvaluationRemoteWorkflowHandler(
config=EvaluationRunConfig(endpoint=str(server.make_url("")).rstrip("/"),
endpoint_timeout=5,
config_file=Path(__file__),
dataset=None,
result_json_path="",
skip_workflow=False,
skip_completed_entries=False,
reps=1),
max_concurrency=2,
)

async with client.session as session:
await handler.run_workflow_remote_single(session, item)

await client.close()
await server.close()

assert item.output_obj == expected_answer


async def test_remote_eval_accumulates_multi_chunk_value_stream(rag_eval_input):
"""Multi-chunk ``data: {"value": "<token>"}`` streams must be reconstructed
by joining all ``value`` fields in order. The empty-string token also
exercises the presence check (``value is not None`` vs. ``if value:``).
"""
item = rag_eval_input.eval_input_items[0]
tokens = ["The ", "answer ", "", "is ", "21", "."]
expected_answer = "".join(tokens)

async def stream_response(request):
resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
await resp.prepare(request)
for token in tokens:
chunk = ChatResponseChunk.create_streaming_chunk(token)
await resp.write(ResponsePayloadOutput(payload=chunk).get_stream_data().encode("utf-8"))
await resp.write_eof()
return resp

app = web.Application()
app.router.add_post("/generate/full", stream_response)
server = TestServer(app)
await server.start_server()
client = TestClient(server)
await client.start_server()

handler = EvaluationRemoteWorkflowHandler(
config=EvaluationRunConfig(endpoint=str(server.make_url("")).rstrip("/"),
endpoint_timeout=5,
config_file=Path(__file__),
dataset=None,
result_json_path="",
skip_workflow=False,
skip_completed_entries=False,
reps=1),
max_concurrency=2,
)

async with client.session as session:
await handler.run_workflow_remote_single(session, item)

await client.close()
await server.close()

assert item.output_obj == expected_answer
Loading