95 changes: 95 additions & 0 deletions tests/test_multiturn_env.py
@@ -459,3 +459,98 @@ async def test_responses_stored_in_state(self, mock_multiturn_env):
        for response in state["responses"]:
            assert hasattr(response, "choices")
            assert len(response.choices) > 0

    @pytest.mark.asyncio
    async def test_strip_think_string_content_preserves_tail_and_tools(
        self, mock_openai_client, sample_chat_dataset
    ):
        """Ensure only text up to </think> is removed; tool_calls and tool messages remain."""
        from tests.conftest import SimpleMultiTurnEnv

        env = SimpleMultiTurnEnv(
            client=mock_openai_client,
            model="test-model",
            dataset=sample_chat_dataset,
            parser=Parser(),
            rubric=Rubric(),
            exclude_think=True,
        )

        prompt = [{"role": "user", "content": "What is 2+2?"}]
        state = await env.init_state(
            prompt=prompt,
            completion=[],
            answer="",
            task="default",
            info={},
            example_id=0,
        )

        assistant_msg = {
            "role": "assistant",
            "content": "<think>\nprivate reasoning</think>\n\nCall tool A",
            "tool_calls": [
                {
                    "id": "id1",
                    "type": "function",
                    "function": {"name": "toolA", "arguments": "{}"},
                }
            ],
        }
        tool_msg = {"role": "tool", "content": "resultA", "tool_call_id": "id1"}
        state["completion"].extend([assistant_msg, tool_msg])

        ctx = await env.get_context_messages(state)
        assert isinstance(ctx, list)

        assert ctx[0] == prompt[0]
        assert ctx[1]["role"] == "assistant"
        assert ctx[1]["content"] == "Call tool A"
        assert ctx[1].get("tool_calls") == assistant_msg["tool_calls"]
        assert ctx[2] == tool_msg

    @pytest.mark.asyncio
    async def test_no_think_content_is_passthrough(
        self, mock_openai_client, sample_chat_dataset
    ):
        """If no </think> present, assistant content remains unchanged."""
        from tests.conftest import SimpleMultiTurnEnv

        env = SimpleMultiTurnEnv(
            client=mock_openai_client,
            model="test-model",
            dataset=sample_chat_dataset,
            parser=Parser(),
            rubric=Rubric(),
            exclude_think=True,
        )

        prompt = [{"role": "user", "content": "Q"}]
        state = await env.init_state(
            prompt=prompt,
            completion=[],
            answer="",
            task="default",
            info={},
            example_id=0,
        )

        assistant_msg = {
            "role": "assistant",
            "content": "No CoT here, proceed to tool",
            "tool_calls": [
                {
                    "id": "id3",
                    "type": "function",
                    "function": {"name": "toolC", "arguments": "{}"},
                }
            ],
        }
        tool_msg = {"role": "tool", "content": "resultC", "tool_call_id": "id3"}
        state["completion"].extend([assistant_msg, tool_msg])

        ctx = await env.get_context_messages(state)
        assert isinstance(ctx, list)
        assert ctx[1]["content"] == assistant_msg["content"]
        assert ctx[1].get("tool_calls") == assistant_msg["tool_calls"]
        assert ctx[2] == tool_msg
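
For reference, the two cases exercised above reduce to the single regex used by the implementation in the file below. A minimal standalone sketch of that behavior (the `strip_think` helper name is illustrative, not part of this PR):

import re

def strip_think(text: str) -> str:
    # Drop everything up to and including the last </think>, then trim leading
    # whitespace. Mirrors _strip_prefix_up_to_close below; the real code only
    # applies the regex when "</think>" is present in the content.
    return re.sub(r"(?s)^.*</think>", "", text).lstrip()

# Reasoning prefix is removed; the tail after </think> is kept.
assert strip_think("<think>\nprivate reasoning</think>\n\nCall tool A") == "Call tool A"
# No </think> tag, so the content passes through unchanged.
assert strip_think("No CoT here, proceed to tool") == "No CoT here, proceed to tool"
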
47 changes: 45 additions & 2 deletions verifiers/envs/multiturn_env.py
@@ -20,9 +20,15 @@


class MultiTurnEnv(Environment):
-    def __init__(self, max_turns: int = -1, **kwargs):
+    def __init__(
+        self,
+        max_turns: int = -1,
+        exclude_think: bool = False,
+        **kwargs,
+    ):
        super().__init__(**kwargs)
        self.max_turns = max_turns
        self.exclude_think = exclude_think

    async def prompt_too_long(self, state: State) -> bool:
        return state.get("prompt_too_long", False)
@@ -49,8 +55,45 @@ async def env_response(
"""
pass

    @staticmethod
    def _process_assistant_message(msg: ChatMessage) -> ChatMessage:
        import re
        from copy import deepcopy

        def _strip_prefix_up_to_close(text: str) -> str:
            return re.sub(r"(?s)^.*</think>", "", text).lstrip()

        new_msg: ChatMessage = deepcopy(msg)
        new_msg["role"] = msg.get("role", "assistant")

        content = msg.get("content")
        if content is None:
            new_msg["content"] = ""
            return new_msg

        if "</think>" in content:
            new_msg["content"] = _strip_prefix_up_to_close(content)
        else:
            new_msg["content"] = content

        return new_msg

    async def get_context_messages(self, state: State) -> Messages:
-        return state["prompt"] + state["completion"]
+        if not self.exclude_think:
+            return state["prompt"] + state["completion"]

        prompt_msgs = state["prompt"]
        completion_msgs = state["completion"]

        processed_completion: list[ChatMessage] = []
        for m in completion_msgs:
            role = m.get("role")
            if role == "assistant":
                processed_completion.append(self._process_assistant_message(m))
            else:
                processed_completion.append(m)

        return prompt_msgs + processed_completion

    async def rollout(
        self,
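
Taken together, the new pieces behave roughly like the following self-contained sketch (`build_context` is a hypothetical stand-in for `get_context_messages` with `exclude_think=True`, operating on plain message dicts, not an API in this PR). Tool messages and `tool_calls` survive untouched because only the `content` field of assistant messages is rewritten:

import re
from copy import deepcopy

def build_context(prompt, completion, exclude_think=True):
    # With exclude_think off, the context is simply prompt + completion,
    # matching the original one-line implementation.
    if not exclude_think:
        return prompt + completion
    processed = []
    for msg in completion:
        if msg.get("role") != "assistant":
            processed.append(msg)  # tool/user messages pass through as-is
            continue
        new_msg = deepcopy(msg)  # keeps tool_calls and other fields intact
        content = msg.get("content")
        if content is None:
            new_msg["content"] = ""
        elif "</think>" in content:
            new_msg["content"] = re.sub(r"(?s)^.*</think>", "", content).lstrip()
        processed.append(new_msg)
    return prompt + processed

prompt = [{"role": "user", "content": "What is 2+2?"}]
completion = [{"role": "assistant", "content": "<think>scratch work</think>\n\n4"}]
assert build_context(prompt, completion)[1]["content"] == "4"  # reasoning stripped
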