Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, Literal, cast, overload

from openai import AsyncOpenAI, AsyncStream, Omit, omit
from openai import AsyncOpenAI, AsyncStream, NotGiven, Omit, omit
from openai.types import ChatModel
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
Expand Down Expand Up @@ -45,6 +45,20 @@
from ..model_settings import ModelSettings


def _is_openai_omitted_value(value: Any) -> bool:
    """Return True when *value* is an OpenAI SDK omission sentinel.

    Both ``Omit`` and ``NotGiven`` mark a keyword argument the caller never
    set, so such a value does not count as a real ``create_kwargs`` entry.
    """
    return isinstance(value, (Omit, NotGiven))


# Keys whose first-class create_kwargs entry is reserved by the SDK and must
# never be overridden via ModelSettings.extra_args, even when the SDK passes
# the value as an OpenAI omit sentinel. ``stream`` is the canonical example:
# get_response() pins it to ``omit`` and then expects a non-streaming
# ChatCompletion, so allowing extra_args to flip it to True would cause the
# OpenAI client to return an async stream that the non-streaming code path
# cannot consume.
_RESERVED_CHAT_COMPLETIONS_KEYS = frozenset({"stream"})


class OpenAIChatCompletionsModel(Model):
_OFFICIAL_OPENAI_SUPPORTED_INPUT_CONTENT_TYPES = frozenset(
{"input_text", "input_image", "input_audio", "input_file"}
Expand Down Expand Up @@ -423,8 +437,15 @@ async def _fetch_response(
"extra_body": model_settings.extra_body,
"metadata": self._non_null_or_omit(model_settings.metadata),
}
extra_args = model_settings.extra_args or {}
duplicate_extra_arg_keys = sorted(
set(create_kwargs).intersection(model_settings.extra_args or {})
k
for k in extra_args
if k in create_kwargs
and (
k in _RESERVED_CHAT_COMPLETIONS_KEYS
or not _is_openai_omitted_value(create_kwargs[k])
)
)
if duplicate_extra_arg_keys:
if len(duplicate_extra_arg_keys) == 1:
Expand All @@ -436,7 +457,7 @@ async def _fetch_response(
raise TypeError(
f"chat.completions.create() got multiple values for keyword arguments {keys}"
)
create_kwargs.update(model_settings.extra_args or {})
create_kwargs.update(extra_args)

ret = await self._get_client().chat.completions.create(**create_kwargs)

Expand Down
135 changes: 135 additions & 0 deletions tests/models/test_openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,138 @@ def __init__(self):
assert ChatCmplHelpers.get_store_param(client, model_settings) is True, (
"Should respect explicitly set store=True"
)


def _build_chat_completions_dummy_client() -> tuple[Any, Any]:
    """Build a ``(completions, client)`` pair for driving _fetch_response.

    The fake client exposes just enough of the AsyncOpenAI surface:
    ``client.chat.completions.create`` records the keyword arguments it was
    called with (on ``completions.kwargs``) and returns a minimal, valid
    ChatCompletion so the non-streaming code path can complete.
    """

    class _RecordingCompletions:
        def __init__(self) -> None:
            # Captured kwargs from the most recent create() call.
            self.kwargs: dict[str, Any] = {}

        async def create(self, **kwargs: Any) -> Any:
            self.kwargs = kwargs
            message = ChatCompletionMessage(role="assistant", content="ok")
            return ChatCompletion(
                id="resp-id",
                created=0,
                model="fake",
                object="chat.completion",
                choices=[Choice(index=0, finish_reason="stop", message=message)],
            )

    class _FakeClient:
        def __init__(self, completions: _RecordingCompletions) -> None:
            self.chat = type("_Chat", (), {"completions": completions})()
            self.base_url = httpx.URL("http://fake")

    recorder = _RecordingCompletions()
    return recorder, _FakeClient(recorder)


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_fetch_response_chat_completions_allows_extra_arg_when_explicit_arg_is_omitted() -> (
    None
):
    """A create_kwargs entry still holding the OpenAI omit sentinel means the
    user never set that first-class field, so an extra_args key of the same
    name is not a genuine duplicate and must be forwarded untouched.
    """

    completions, fake_client = _build_chat_completions_dummy_client()
    model = OpenAIChatCompletionsModel(
        model="gpt-4", openai_client=cast(AsyncOpenAI, fake_client)
    )
    settings = ModelSettings(extra_args={"reasoning_effort": "high"})
    with generation_span(disabled=True) as span:
        await model._fetch_response(
            system_instructions=None,
            input="hi",
            model_settings=settings,
            tools=[],
            output_schema=None,
            handoffs=[],
            span=span,
            tracing=ModelTracing.DISABLED,
            stream=False,
        )

    # The pass-through key must reach the underlying create() call intact.
    assert completions.kwargs["reasoning_effort"] == "high"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_fetch_response_chat_completions_rejects_duplicate_extra_args_keys() -> None:
    """Supplying the same key via a first-class ModelSettings field AND via
    extra_args is a real conflict and must still raise.
    """

    _completions, fake_client = _build_chat_completions_dummy_client()
    model = OpenAIChatCompletionsModel(
        model="gpt-4", openai_client=cast(AsyncOpenAI, fake_client)
    )
    settings = ModelSettings(temperature=0.5, extra_args={"temperature": 0.7})
    with generation_span(disabled=True) as span, pytest.raises(
        TypeError, match="multiple values.*temperature"
    ):
        await model._fetch_response(
            system_instructions=None,
            input="hi",
            model_settings=settings,
            tools=[],
            output_schema=None,
            handoffs=[],
            span=span,
            tracing=ModelTracing.DISABLED,
            stream=False,
        )


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_fetch_response_chat_completions_rejects_stream_via_extra_args_in_non_streaming_call() -> (  # noqa: E501
    None
):
    """``stream`` is an SDK-reserved key. On the non-streaming get_response
    path, create_kwargs["stream"] is pinned to the OpenAI omit sentinel and a
    plain ChatCompletion is expected back; if ``extra_args={"stream": True}``
    slipped past the duplicate check, the OpenAI client would instead return
    an async stream the non-streaming code path cannot consume.
    """

    _completions, fake_client = _build_chat_completions_dummy_client()
    model = OpenAIChatCompletionsModel(
        model="gpt-4", openai_client=cast(AsyncOpenAI, fake_client)
    )
    settings = ModelSettings(extra_args={"stream": True})
    with generation_span(disabled=True) as span, pytest.raises(
        TypeError, match="multiple values.*stream"
    ):
        await model._fetch_response(
            system_instructions=None,
            input="hi",
            model_settings=settings,
            tools=[],
            output_schema=None,
            handoffs=[],
            span=span,
            tracing=ModelTracing.DISABLED,
            stream=False,
        )


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_fetch_response_chat_completions_rejects_stream_via_extra_args_in_streaming_call() -> (  # noqa: E501
    None
):
    """``stream`` stays reserved on the streaming path too. Here
    ``stream=True`` puts a real value (``True``) into create_kwargs, so even
    the original intersection check would have flagged the collision — this
    test pins the invariant explicitly for both directions.
    """

    _completions, fake_client = _build_chat_completions_dummy_client()
    model = OpenAIChatCompletionsModel(
        model="gpt-4", openai_client=cast(AsyncOpenAI, fake_client)
    )
    settings = ModelSettings(extra_args={"stream": False})
    with generation_span(disabled=True) as span, pytest.raises(
        TypeError, match="multiple values.*stream"
    ):
        await model._fetch_response(
            system_instructions=None,
            input="hi",
            model_settings=settings,
            tools=[],
            output_schema=None,
            handoffs=[],
            span=span,
            tracing=ModelTracing.DISABLED,
            stream=True,
        )
Loading