Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/openai/lib/_tools.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,43 @@
from __future__ import annotations

from typing import Any, Dict, cast
from typing import Any, Dict, Iterable, List, cast

import pydantic

from ._pydantic import to_strict_json_schema
from .._types import Omit
from .._utils import is_given
from ..types.chat import ChatCompletionFunctionToolParam
from ..types.shared_params import FunctionDefinition
from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam

_WEB_SEARCH_TOOL_TYPES = frozenset(
{"web_search", "web_search_2025_08_26", "web_search_preview", "web_search_preview_2025_03_11"}
)


def _apply_web_search_default_location_tools(
tools: Iterable[Any] | Omit,
) -> Iterable[Any] | Omit:
"""For web_search tools that lack user_location, inject user_location with type='approximate'.

This prevents the server from defaulting to a US-based location when no
user_location is specified, which is unexpected behavior for developers
outside the US.
"""
if not is_given(tools):
return tools

result: List[Any] = []
changed = False
for tool in tools:
if isinstance(tool, dict) and tool.get("type") in _WEB_SEARCH_TOOL_TYPES and "user_location" not in tool:
tool = {**tool, "user_location": {"type": "approximate"}}
changed = True
result.append(tool)

return result if changed else tools
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Return collected tools to avoid exhausting iterators

If tools is a one-shot iterable (e.g., generator/iterator), this function consumes it in the for loop and then returns the original iterable when changed is false. In that case the caller (all responses.* create/stream paths now using this helper) receives an already-exhausted iterable, so no tools are sent at all. This regression appears whenever users pass non-reiterable Iterable[ToolParam] values that either contain no web-search tool or only web-search tools that already have user_location.

Useful? React with 👍 / 👎.



class PydanticFunctionTool(Dict[str, Any]):
"""Dictionary wrapper so we can pass the given base model
Expand Down
18 changes: 11 additions & 7 deletions src/openai/resources/responses/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
AsyncInputItemsWithStreamingResponse,
)
from ..._streaming import Stream, AsyncStream
from ...lib._tools import PydanticFunctionTool, ResponsesPydanticFunctionTool
from ...lib._tools import (
PydanticFunctionTool,
ResponsesPydanticFunctionTool,
_apply_web_search_default_location_tools,
)
from .input_tokens import (
InputTokens,
AsyncInputTokens,
Expand Down Expand Up @@ -942,7 +946,7 @@ def create(
"temperature": temperature,
"text": text,
"tool_choice": tool_choice,
"tools": tools,
"tools": _apply_web_search_default_location_tools(tools),
"top_logprobs": top_logprobs,
"top_p": top_p,
"truncation": truncation,
Expand Down Expand Up @@ -1256,7 +1260,7 @@ def parser(raw_response: Response) -> ParsedResponse[TextFormatT]:
"temperature": temperature,
"text": text,
"tool_choice": tool_choice,
"tools": tools,
"tools": _apply_web_search_default_location_tools(tools),
"top_logprobs": top_logprobs,
"top_p": top_p,
"truncation": truncation,
Expand Down Expand Up @@ -2623,7 +2627,7 @@ async def create(
"temperature": temperature,
"text": text,
"tool_choice": tool_choice,
"tools": tools,
"tools": _apply_web_search_default_location_tools(tools),
"top_logprobs": top_logprobs,
"top_p": top_p,
"truncation": truncation,
Expand Down Expand Up @@ -2941,7 +2945,7 @@ def parser(raw_response: Response) -> ParsedResponse[TextFormatT]:
"temperature": temperature,
"text": text,
"tool_choice": tool_choice,
"tools": tools,
"tools": _apply_web_search_default_location_tools(tools),
"top_logprobs": top_logprobs,
"top_p": top_p,
"truncation": truncation,
Expand Down Expand Up @@ -4597,7 +4601,7 @@ def create(
"temperature": temperature,
"text": text,
"tool_choice": tool_choice,
"tools": tools,
"tools": _apply_web_search_default_location_tools(tools),
"top_logprobs": top_logprobs,
"top_p": top_p,
"truncation": truncation,
Expand Down Expand Up @@ -4677,7 +4681,7 @@ async def create(
"temperature": temperature,
"text": text,
"tool_choice": tool_choice,
"tools": tools,
"tools": _apply_web_search_default_location_tools(tools),
"top_logprobs": top_logprobs,
"top_p": top_p,
"truncation": truncation,
Expand Down
84 changes: 84 additions & 0 deletions tests/lib/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from openai._types import Omit, omit
from openai.lib._tools import _apply_web_search_default_location_tools


class TestApplyWebSearchDefaultLocationTools:
"""Tests for _apply_web_search_default_location_tools."""

def test_no_tools_returns_omit(self) -> None:
"""When tools is Omit (not given), return it unchanged."""
result = _apply_web_search_default_location_tools(omit)
assert result is omit

def test_empty_tools_list(self) -> None:
"""An empty list of tools should be returned unchanged."""
tools: list = []
result = _apply_web_search_default_location_tools(tools)
assert result is tools # same reference, no changes

def test_non_web_search_tools_unchanged(self) -> None:
"""Tools that are not web_search should not be modified."""
tools = [
{"type": "function", "function": {"name": "my_func"}},
{"type": "code_interpreter"},
]
result = _apply_web_search_default_location_tools(tools)
assert result is tools # same reference, no changes

def test_web_search_injects_user_location(self) -> None:
"""web_search without user_location should get one injected."""
tools = [{"type": "web_search"}]
result = _apply_web_search_default_location_tools(tools)
assert result is not tools # new list created
assert result[0]["user_location"] == {"type": "approximate"}
assert result[0]["type"] == "web_search"

def test_web_search_with_existing_user_location_unchanged(self) -> None:
"""web_search that already has user_location should not be overridden."""
existing_loc = {"type": "approximate", "city": "London", "country": "GB"}
tools = [{"type": "web_search", "user_location": existing_loc}]
result = _apply_web_search_default_location_tools(tools)
assert result is tools # same reference, no changes needed
assert result[0]["user_location"] is existing_loc

def test_web_search_2025_08_26_injects(self) -> None:
tools = [{"type": "web_search_2025_08_26"}]
result = _apply_web_search_default_location_tools(tools)
assert result[0]["user_location"] == {"type": "approximate"}

def test_web_search_preview_injects(self) -> None:
tools = [{"type": "web_search_preview"}]
result = _apply_web_search_default_location_tools(tools)
assert result[0]["user_location"] == {"type": "approximate"}

def test_web_search_preview_2025_03_11_injects(self) -> None:
tools = [{"type": "web_search_preview_2025_03_11"}]
result = _apply_web_search_default_location_tools(tools)
assert result[0]["user_location"] == {"type": "approximate"}

def test_mixed_tools_only_web_search_modified(self) -> None:
"""When mixing web_search and non-web-search tools, only web_search gets modified."""
func_tool = {"type": "function", "function": {"name": "foo"}}
ws_tool = {"type": "web_search"}
tools = [func_tool, ws_tool]
result = _apply_web_search_default_location_tools(tools)
# function tool is unchanged
assert result[0] is func_tool
# web_search tool is a new dict with user_location injected
assert result[1]["user_location"] == {"type": "approximate"}
assert result[1]["type"] == "web_search"

def test_web_search_preserves_other_keys(self) -> None:
"""Injecting user_location should not drop other keys on the tool dict."""
tools = [{"type": "web_search", "extra_key": "value"}]
result = _apply_web_search_default_location_tools(tools)
assert result[0]["extra_key"] == "value"
assert result[0]["user_location"] == {"type": "approximate"}

def test_non_dict_tool_not_modified(self) -> None:
"""Non-dict tools (e.g. string shorthand) should pass through."""
tools = ["web_search"] # string, not dict
result = _apply_web_search_default_location_tools(tools)
assert result is tools # unchanged