Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions renderers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
build_trajectory_step,
create_renderer,
create_renderer_pool,
extract_message_tool_names,
is_multimodal,
reject_assistant_in_extension,
trim_to_turn_close,
Expand Down Expand Up @@ -168,6 +169,7 @@ def __dir__() -> list[str]:
"config_from_name",
"create_renderer",
"create_renderer_pool",
"extract_message_tool_names",
"is_multimodal",
"reject_assistant_in_extension",
"trim_to_turn_close",
Expand Down
90 changes: 90 additions & 0 deletions renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import queue
import threading
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
Expand Down Expand Up @@ -117,6 +118,68 @@ class Message(TypedDict, total=False):
reasoning_content: str


def extract_message_tool_names(messages: list[Message]) -> list[str | None]:
"""Per-message tool function names parallel to ``message_roles``.

Returns one entry per message: the function name for ``role="tool"``
messages, ``None`` for every other message. Length matches the
input list.

For tool messages the name is taken from ``msg["name"]`` when set
(caller-provided), otherwise recovered by joining
``msg["tool_call_id"]`` against any prior assistant's
``tool_calls[i].function.name`` in the same list. Tool messages
whose issuing assistant lives outside the provided list (e.g. on
a :meth:`Renderer.bridge_to_next_turn` call where ``new_messages``
covers only the new turn) resolve to ``None``.

Pure metadata: this never mutates the caller's messages and has
no effect on the rendered token stream. It runs independently of
the render path so the renderer can populate the field on
:class:`RenderedTokens` without breaking HF byte parity for tool
messages that carry no ``name``. Callers who *also* want the
function name to appear in the rendered scaffold (e.g. GPT-OSS
Harmony's ``functions.{name}`` prefix) must attach ``name`` to
their tool messages before calling :meth:`Renderer.render`
themselves — renderers don't synthesize ``name`` into the input,
only into this metadata field.

Trainers join this list with :attr:`RenderedTokens.message_indices`
to recover per-token tool attribution — the canonical use case is
SFT on tool response bodies while RL acts only on assistant tokens
(tool body tokens get a constant positive advantage so the model
learns to anticipate tool outputs without learning to emit
``<|tool_response>`` itself).

Per-message rather than per-token because the data is naturally
per-message — storing it per-token would duplicate the same
string across every body token of the same tool message.
"""
lookup: dict[str, str] = {}
for m in messages:
if not isinstance(m, Mapping) or m.get("role") != "assistant":
continue
for tc in m.get("tool_calls") or []:
if not isinstance(tc, Mapping):
continue
tc_id = tc.get("id")
fn = tc.get("function")
tc_name = fn.get("name") if isinstance(fn, Mapping) else None
if isinstance(tc_id, str) and isinstance(tc_name, str):
lookup[tc_id] = tc_name
out: list[str | None] = []
for m in messages:
if not isinstance(m, Mapping) or m.get("role") != "tool":
out.append(None)
continue
name = m.get("name")
if not (isinstance(name, str) and name):
tc_id = m.get("tool_call_id")
name = lookup.get(tc_id) if isinstance(tc_id, str) else None
out.append(name if isinstance(name, str) and name else None)
return out


# ---------------------------------------------------------------------------
# Renderer data types
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -208,6 +271,32 @@ class RenderedTokens:
renderer doesn't provide the signal. ``DefaultRenderer`` leaves it
empty for the same reason.

``message_tool_names`` is the per-message tool function name list,
parallel to ``message_roles`` (same length). For tool-role
messages it carries the function name — either taken from
``msg["name"]`` (caller-provided) or recovered by joining
``msg["tool_call_id"]`` against a prior assistant's
``tool_calls[i].function.name`` in the rendered slice. Every
other message is ``None``, as are tool messages whose issuing
assistant lives outside the rendered slice (e.g. on a
:meth:`Renderer.bridge_to_next_turn` call where ``new_messages``
covers only the new turn).

This is pure metadata, computed by :func:`extract_message_tool_names`
independently of the render path: populating it never touches the
rendered token stream, so HF chat-template byte parity is
preserved for tool messages carrying no ``name``. Callers who
*also* want the function name to appear in the rendered scaffold
(e.g. GPT-OSS Harmony's ``functions.{name}`` prefix) must attach
``name`` to their tool messages before calling
:meth:`Renderer.render` themselves.

Trainers join this with ``message_indices`` to build per-tool
selective loss masks (SFT on tool response bodies of a specific
tool while RL acts on assistant tokens). Empty
``message_tool_names`` (``[]``) means the renderer doesn't
provide the signal.

``multi_modal_data`` is populated by multimodal renderers (e.g.
``Qwen3VLRenderer``) when image / video content parts are present;
text-only renderers leave it as ``None``.
Expand All @@ -218,6 +307,7 @@ class RenderedTokens:
sampled_mask: list[bool] = field(default_factory=list)
is_content: list[bool] = field(default_factory=list)
message_roles: list[str] = field(default_factory=list)
message_tool_names: list[str | None] = field(default_factory=list)
multi_modal_data: "MultiModalData | None" = None

def tokens_per_message(
Expand Down
3 changes: 3 additions & 0 deletions renderers/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RenderedTokens,
ToolSpec,
attribute_text_segments,
extract_message_tool_names,
reject_assistant_in_extension,
trim_to_turn_close,
)
Expand Down Expand Up @@ -247,6 +248,7 @@ def emit_text_segments(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
)

def render_ids(
Expand Down Expand Up @@ -390,6 +392,7 @@ def emit_text(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

# ------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions renderers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
extract_message_tool_names,
)
from renderers.configs import DefaultRendererConfig
from renderers.parsers import (
Expand Down Expand Up @@ -141,6 +142,7 @@ def render(
token_ids=token_ids,
message_indices=message_indices,
message_roles=message_roles,
message_tool_names=extract_message_tool_names(messages),
)

def _apply(self, messages, *, tools=None, add_generation_prompt=False) -> list[int]:
Expand Down
3 changes: 3 additions & 0 deletions renderers/glm45.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RenderedTokens,
ToolSpec,
attribute_text_segments,
extract_message_tool_names,
reject_assistant_in_extension,
should_preserve_past_thinking,
)
Expand Down Expand Up @@ -265,6 +266,7 @@ def emit_text_segments(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
)

def render_ids(
Expand Down Expand Up @@ -445,6 +447,7 @@ def emit_text_segments(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

def _render_assistant(
Expand Down
3 changes: 3 additions & 0 deletions renderers/glm5.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RenderedTokens,
ToolSpec,
attribute_text_segments,
extract_message_tool_names,
reject_assistant_in_extension,
should_preserve_past_thinking,
)
Expand Down Expand Up @@ -281,6 +282,7 @@ def emit_text_segments(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
)

def render_ids(
Expand Down Expand Up @@ -456,6 +458,7 @@ def emit_text_segments(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

def _render_assistant(
Expand Down
3 changes: 3 additions & 0 deletions renderers/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
extract_message_tool_names,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
Expand Down Expand Up @@ -465,6 +466,7 @@ def emit_harmony_message(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
)

def render_ids(
Expand Down Expand Up @@ -594,6 +596,7 @@ def bridge_to_next_turn(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

# ── message conversion ───────────────────────────────────────────────────
Expand Down
3 changes: 3 additions & 0 deletions renderers/kimi_k2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ParsedResponse,
RenderedTokens,
ToolSpec,
extract_message_tool_names,
reject_assistant_in_extension,
trim_to_turn_close,
)
Expand Down Expand Up @@ -305,6 +306,7 @@ def emit_text(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in caller_messages],
message_tool_names=extract_message_tool_names(caller_messages),
)

def render_ids(
Expand Down Expand Up @@ -454,6 +456,7 @@ def emit_text(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

def _render_assistant(
Expand Down
5 changes: 5 additions & 0 deletions renderers/kimi_k25.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
RenderedTokens,
ToolCallParseStatus,
ToolSpec,
extract_message_tool_names,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
Expand Down Expand Up @@ -946,6 +947,7 @@ def emit_image(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
multi_modal_data=mm_data,
)

Expand Down Expand Up @@ -1188,13 +1190,15 @@ def emit_image(
merged_items.setdefault(modality, []).extend(vals)

bridge_roles = [m.get("role") or "" for m in new_messages]
bridge_tool_names = extract_message_tool_names(new_messages)
if not (merged_hashes or merged_placeholders or merged_items):
return RenderedTokens(
token_ids=tokens,
message_indices=indices,
sampled_mask=sampled,
is_content=content_mask,
message_roles=bridge_roles,
message_tool_names=bridge_tool_names,
)

mm_data = MultiModalData(
Expand All @@ -1208,6 +1212,7 @@ def emit_image(
sampled_mask=sampled,
is_content=content_mask,
message_roles=bridge_roles,
message_tool_names=bridge_tool_names,
multi_modal_data=mm_data,
)

Expand Down
3 changes: 3 additions & 0 deletions renderers/laguna_xs2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
RenderedTokens,
ToolSpec,
attribute_text_segments,
extract_message_tool_names,
reject_assistant_in_extension,
)
from renderers.configs import LagunaXS2RendererConfig
Expand Down Expand Up @@ -275,6 +276,7 @@ def emit_text_segments(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
)

def render_ids(
Expand Down Expand Up @@ -426,6 +428,7 @@ def emit_text_segments(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

def _render_assistant(
Expand Down
3 changes: 3 additions & 0 deletions renderers/minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RenderedTokens,
ToolSpec,
attribute_text_segments,
extract_message_tool_names,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
Expand Down Expand Up @@ -278,6 +279,7 @@ def emit_token_overlap_body(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in messages],
message_tool_names=extract_message_tool_names(messages),
)

def render_ids(
Expand Down Expand Up @@ -459,6 +461,7 @@ def emit_token_overlap_body(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

def _render_assistant(
Expand Down
3 changes: 3 additions & 0 deletions renderers/nemotron3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
RenderedTokens,
ToolSpec,
attribute_text_segments,
extract_message_tool_names,
reject_assistant_in_extension,
should_preserve_past_thinking,
trim_to_turn_close,
Expand Down Expand Up @@ -411,6 +412,7 @@ def emit_text_segments(
sampled_mask=sampled,
is_content=content_mask,
message_roles=[m.get("role") or "" for m in original_messages],
message_tool_names=extract_message_tool_names(original_messages),
)

def render_ids(
Expand Down Expand Up @@ -581,6 +583,7 @@ def emit_text_segments(
sampled_mask=[False] * total_len,
is_content=[False] * len(previous_ids) + ext_content,
message_roles=[m.get("role") or "" for m in new_messages],
message_tool_names=extract_message_tool_names(new_messages),
)

# ------------------------------------------------------------------
Expand Down
Loading
Loading