Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions tests/test_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,124 @@ def invalid_type_tool() -> list:

result = await env.call_tool("invalid_type_tool", {}, "call_0")
assert result["content"] == "[{'type': 'audio', 'data': 'base64data'}]"


class TestToolEnvNonDictArgs:
"""Tests for ToolEnv handling of non-dict tool arguments (issue #562).

When a model produces double-encoded JSON, json.loads succeeds but returns
a str instead of a dict. ToolEnv should handle this gracefully rather than
crashing the RL training run.
"""

@pytest.mark.asyncio
async def test_tool_env_non_dict_args_stops_when_configured(
self, mock_client, sample_chat_dataset, make_input
):
"""Test that ToolEnv stops rollout when tool args are not a dict and ValueError is in stop_errors."""

class StrictToolEnv(vf.ToolEnv):
def __init__(self, **kwargs):
super().__init__(
tools=[square_tool], stop_errors=[ValueError], **kwargs
)

env = StrictToolEnv(
client=mock_client,
model="test-model",
dataset=sample_chat_dataset,
parser=vf.Parser(),
rubric=vf.Rubric(),
)

# Create a tool call with double-encoded JSON arguments (json.loads
# succeeds but returns str instead of dict)
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)

double_encoded_args = json.dumps(json.dumps({"x": 4}))
tool_call = ChatCompletionMessageToolCall(
id="call_0",
type="function",
function=Function(
name="square_tool",
arguments=double_encoded_args,
),
)

mock_client.add_response(
messages=[{"role": "user", "content": "Square 4"}],
response="Using tool",
tool_calls=[tool_call],
)

state = await env.rollout(
input=make_input(
prompt=[{"role": "user", "content": "Square 4"}], answer="", task=""
),
client=mock_client,
model="test-model",
)

assert state.get("error") is not None
assert isinstance(state["error"], vf.ToolParseError)
assert isinstance(state["error"], vf.ToolError)
assert state["is_completed"] is True
assert state["stop_condition"] == "has_error"

@pytest.mark.asyncio
async def test_tool_env_non_dict_args_graceful_fallback(
self, mock_client, sample_chat_dataset, make_input
):
"""Test that ToolEnv gracefully returns error message when tool args are
not a dict and no stop_errors are configured."""

env = vf.ToolEnv(
tools=[square_tool],
client=mock_client,
model="test-model",
dataset=sample_chat_dataset,
parser=vf.Parser(),
rubric=vf.Rubric(),
)

from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)

# Simulate double-encoded JSON: json.loads returns a str, not a dict
double_encoded_args = json.dumps(json.dumps({"x": 4}))
tool_call_bad = ChatCompletionMessageToolCall(
id="call_0",
type="function",
function=Function(
name="square_tool",
arguments=double_encoded_args,
),
)

mock_client.add_response(
messages=[{"role": "user", "content": "Square 4"}],
response="Using tool",
tool_calls=[tool_call_bad],
)

# Second response (after error feedback) completes normally
mock_client.set_default_response("Done")

state = await env.rollout(
input=make_input(
prompt=[{"role": "user", "content": "Square 4"}], answer="", task=""
),
client=mock_client,
model="test-model",
)

# Should NOT crash - error is returned as a tool message, training continues
completion = state["completion"]
tool_messages = [m for m in completion if m.get("role") == "tool"]
assert len(tool_messages) >= 1
assert "Expected tool arguments to be a dict" in tool_messages[0]["content"]
7 changes: 6 additions & 1 deletion verifiers/envs/tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ async def env_response(
tool_call_id: str = tool_call.id
try:
tool_name: str = tool_call.name
tool_args: dict = json.loads(tool_call.arguments)
parsed_args = json.loads(tool_call.arguments)
if not isinstance(parsed_args, dict):
raise ValueError(
f"Expected tool arguments to be a dict, got {type(parsed_args).__name__}: {parsed_args}"
)
tool_args: dict = parsed_args
except Exception as e:
if self._should_stop_for_error(e):
raise vf.ToolParseError from e
Expand Down