tests/test_openai_chat_completions_token_client.py (160 additions, 0 deletions)
@@ -270,3 +270,163 @@ async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools): # noqa:
assert len(recording_client.calls) == 1
assert recording_client.calls[0]["path"] == "/chat/completions/tokens"
assert recording_client.calls[0]["body"]["tokens"] == [10, 20]


# ---------------------------------------------------------------------------
# dynamo_chat_nvext transport (Dynamo bis/dynamo-rl)
# ---------------------------------------------------------------------------


class _StubRenderer:
"""Renderer stand-in for the dynamo_chat_nvext transport tests.

Returns deterministic ids so we can assert on body shape without pulling
in a real HuggingFace tokenizer download. ``render_ids`` returns a
fixed sequence; ``get_stop_token_ids`` returns a marker pair.
"""

def __init__(self) -> None:
self.render_calls: list[dict[str, Any]] = []

def render_ids(
self,
messages,
*,
tools=None,
add_generation_prompt: bool = False,
) -> list[int]:
self.render_calls.append(
{
"messages": messages,
"tools": tools,
"add_generation_prompt": add_generation_prompt,
}
)
# Encode the call shape into ids so tests can disambiguate the two
# bridge tokenize calls without a real tokenizer.
return [42, len(messages), int(add_generation_prompt)]

def get_stop_token_ids(self) -> list[int]:
return [99, 100]


class _DynamoTestClient(OpenAIChatCompletionsTokenClient):
"""Dynamo-transport TITO client with a stubbed renderer.

Overriding in a subclass is the cleanest way to inject the stub without
going through ``ClientConfig`` (which would require a real ``api_base_url``
and a ``setup_client`` call to construct the ``AsyncOpenAI`` client). The
recording client captures the eventual ``self.client.post(...)`` call.
"""

_stub_renderer: _StubRenderer

def __init__(self, recording_client) -> None:
super().__init__(recording_client)
self._stub_renderer = _StubRenderer()

@property
def renderer_transport(self) -> str: # type: ignore[override]
return "dynamo_chat_nvext"

def _get_renderer(self, model: str): # type: ignore[override]
return self._stub_renderer


@pytest.mark.asyncio
async def test_local_tokenize_uses_renderer_under_dynamo_transport():
"""Bridge tokenize must NOT hit any HTTP route under dynamo_chat_nvext.

Goes straight through ``_local_tokenize`` -> ``renderer.render_ids``.
The recording client would record any errant POST; we assert it sees
none.
"""
recording_client = _RecordingClient()
client = _DynamoTestClient(recording_client)

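# Default call: the client should ask the renderer for the full generation prompt.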
ids_full = await client.tokenize(
messages=[{"role": "user", "content": "u"}],
tools=None,
model="test-model",
)
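# Second call: add_generation_prompt=False is forwarded to the renderer via extra_kwargs.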
ids_base = await client.tokenize(
messages=[{"role": "user", "content": "u"}],
tools=None,
model="test-model",
extra_kwargs={"add_generation_prompt": False},
)

# Both calls hit the renderer; neither hits the wire.
assert recording_client.calls == []
assert client._stub_renderer.render_calls[0]["add_generation_prompt"] is True
assert client._stub_renderer.render_calls[1]["add_generation_prompt"] is False
# And the stub encodes that into the returned ids' last element.
assert ids_full[-1] == 1
assert ids_base[-1] == 0


@pytest.mark.asyncio
async def test_get_native_response_uses_dynamo_chat_nvext_under_transport(
monkeypatch: pytest.MonkeyPatch,
):
"""Dynamo transport must POST to /chat/completions with nvext.token_data.

Mirrors ``test_get_native_response_uses_token_route_when_prompt_ids_available``
but for the new transport.
"""
recording_client = _RecordingClient()
client = _DynamoTestClient(recording_client)

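# Stub out prompt-id reconstruction so the test controls the exact token prefix sent on the wire.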
async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools): # noqa: ANN001
return [10, 20, 30]

monkeypatch.setattr(
OpenAIChatCompletionsTokenClient, "get_prompt_ids", fake_get_prompt_ids
)

state = cast(
State,
{
"model": "test-model",
"trajectory": [
_make_step(
prompt=[{"role": "user", "content": "u1"}],
completion=[{"role": "assistant", "content": "a1"}],
prompt_ids=[1],
completion_ids=[2],
)
],
},
)
prompt = cast(Any, [{"role": "user", "content": "u2"}])

response = await client.get_native_response(
prompt=prompt,
model="test-model",
sampling_args={"max_completion_tokens": 16, "temperature": 0.5},
tools=None,
state=state,
)

assert response["ok"] is True
assert len(recording_client.calls) == 1
call = recording_client.calls[0]

# Wire-shape assertions: route, nvext.token_data, stop_token_ids,
# placeholder messages, sampling fields promoted.
assert call["path"] == "/chat/completions"
body = call["body"]
assert body["nvext"]["token_data"] == [10, 20, 30]
assert body["nvext"]["extra_fields"] == ["completion_token_ids"]
assert body["stop_token_ids"] == [99, 100]
assert body["messages"] == [{"role": "user", "content": "(token-in mode)"}]
assert body["max_completion_tokens"] == 16
assert body["temperature"] == 0.5
assert body["logprobs"] is True
assert body["stream"] is False

# No /chat/completions/tokens, no /tokenize for the dynamo transport.
assert all(
c["path"] != "/chat/completions/tokens" and not c["path"].endswith("/tokenize")
for c in recording_client.calls
)