
Commit 87cd807

Fix Anthropic FIM with muxing.
In the context of muxing, the code that determines which mapper to use when routing requests to Anthropic relied on `is_fim_request` alone and did not take into account whether the endpoint that actually received the request was the legacy one (i.e. `/completions`) or the current one (i.e. `/chat/completions`). This caused the wrong mapper to be used, which led to empty text content in the FIM request. A more reliable way to determine the mapper is to look at the effective type of the parsed request, since that is the real source of truth for the translation.
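
A minimal sketch of the distinction (hypothetical model classes standing in for the real `codegate.types.openai` ones): dispatching on the parsed request type ties the mapper choice to the payload that actually arrived, whereas `is_fim_request` can be true for either endpoint.

    from dataclasses import dataclass

    # Hypothetical stand-ins for the real request models.
    @dataclass
    class LegacyCompletionRequest:      # body POSTed to /completions
        prompt: str

    @dataclass
    class ChatCompletionRequest:        # body POSTed to /chat/completions
        messages: list

    def pick_mapper(parsed, is_fim_request: bool) -> str:
        # A FIM request can arrive on either endpoint, so the flag alone
        # cannot select the translation; the parsed type can.
        if isinstance(parsed, LegacyCompletionRequest):
            return "anthropic_from_legacy_openai"
        return "anthropic_from_openai"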
1 parent dc2ceb0 commit 87cd807

9 files changed: 120 additions & 56 deletions


src/codegate/muxing/router.py

Lines changed: 9 additions & 1 deletion
@@ -138,7 +138,15 @@ async def route_to_dest_provider(
         # TODO this should be improved
         match model_route.endpoint.provider_type:
             case ProviderType.anthropic:
-                if is_fim_request:
+                # Note: despite `is_fim_request` being true, our
+                # integration tests query the `/chat/completions`
+                # endpoint, which causes the
+                # `anthropic_from_legacy_openai` to incorrectly
+                # populate the struct.
+                #
+                # Checking for the actual type is a much more
+                # reliable way of determining the right mapper.
+                if isinstance(parsed, openai.LegacyCompletionRequest):
                     completion_function = anthropic.acompletion
                     from_openai = anthropic_from_legacy_openai
                     to_openai = anthropic_to_legacy_openai

src/codegate/providers/anthropic/provider.py

Lines changed: 28 additions & 9 deletions
@@ -11,7 +11,15 @@
 from codegate.providers.anthropic.completion_handler import AnthropicCompletion
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.fim_analyzer import FIMAnalyzer
-from codegate.types.anthropic import ChatCompletionRequest, stream_generator
+from codegate.types.anthropic import (
+    ChatCompletionRequest,
+    single_message,
+    single_response,
+    stream_generator,
+)
+from codegate.types.generators import (
+    completion_handler_replacement,
+)

 logger = structlog.get_logger("codegate")

@@ -118,18 +126,29 @@ async def create_message(
         body = await request.body()

         if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
-            print(f"{create_message.__name__}: {body}")
+            print(f"{body.decode('utf-8')}")

         req = ChatCompletionRequest.model_validate_json(body)
         is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)

-        return await self.process_request(
-            req,
-            x_api_key,
-            self.base_url,
-            is_fim_request,
-            request.state.detected_client,
-        )
+        if req.stream:
+            return await self.process_request(
+                req,
+                x_api_key,
+                self.base_url,
+                is_fim_request,
+                request.state.detected_client,
+            )
+        else:
+            return await self.process_request(
+                req,
+                x_api_key,
+                self.base_url,
+                is_fim_request,
+                request.state.detected_client,
+                completion_handler=completion_handler_replacement(single_message),
+                stream_generator=single_response,
+            )


 async def dumper(stream):
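
The non-streaming branch reuses the streaming plumbing by treating the single response as a one-item stream. A self-contained sketch of that trick (illustrative names, not the actual codegate wiring):

    import asyncio
    from typing import Any, AsyncIterator

    # A one-item async generator makes a single payload look like a
    # stream to code that only knows how to consume AsyncIterators.
    async def one_item(value: Any) -> AsyncIterator[Any]:
        yield value

    async def main() -> None:
        stream = one_item({"role": "assistant", "content": "fim completion"})
        print(await anext(stream))  # exactly one payload comes out

    asyncio.run(main())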

src/codegate/providers/ollama/completion_handler.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ async def _ollama_dispatcher(  # noqa: C901
         stream = openai_stream_generator(prepend(first, stream))

     if isinstance(first, OpenAIChatCompletion):
-        stream = openai_single_response_generator(first, stream)
+        stream = openai_single_response_generator(first)

     async for item in stream:
         yield item

src/codegate/types/anthropic/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,8 @@
 from ._generators import (
     acompletion,
     message_wrapper,
+    single_message,
+    single_response,
     stream_generator,
 )
 from ._request_models import (
@@ -49,6 +51,8 @@
 __all__ = [
     "acompletion",
     "message_wrapper",
+    "single_message",
+    "single_response",
     "stream_generator",
     "AssistantMessage",
     "CacheControl",

src/codegate/types/anthropic/_generators.py

Lines changed: 58 additions & 3 deletions
@@ -12,6 +12,7 @@
     ContentBlockDelta,
     ContentBlockStart,
     ContentBlockStop,
+    Message,
     MessageDelta,
     MessageError,
     MessagePing,
@@ -27,7 +28,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
     try:
         async for chunk in stream:
             try:
-                body = chunk.json(exclude_defaults=True, exclude_unset=True)
+                body = chunk.json(exclude_unset=True)
             except Exception as e:
                 logger.error("failed serializing payload", exc_info=e)
                 err = MessageError(
@@ -37,7 +38,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
                         message=str(e),
                     ),
                 )
-                body = err.json(exclude_defaults=True, exclude_unset=True)
+                body = err.json(exclude_unset=True)
                 yield f"event: error\ndata: {body}\n\n"

             data = f"event: {chunk.type}\ndata: {body}\n\n"
@@ -55,10 +56,60 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
                 message=str(e),
             ),
         )
-        body = err.json(exclude_defaults=True, exclude_unset=True)
+        body = err.json(exclude_unset=True)
         yield f"event: error\ndata: {body}\n\n"


+async def single_response(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
+    """Wraps a single response object in an AsyncIterator. This is
+    meant to be used for non-streaming responses.
+
+    """
+    resp = await anext(stream)
+    yield resp.model_dump_json(exclude_unset=True)
+
+
+async def single_message(request, api_key, base_url, stream=None, is_fim_request=None):
+    headers = {
+        "anthropic-version": "2023-06-01",
+        "x-api-key": api_key,
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    payload = request.model_dump_json(exclude_unset=True)
+
+    if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+        print(payload)
+
+    client = httpx.AsyncClient()
+    async with client.stream(
+        "POST",
+        f"{base_url}/v1/messages",
+        headers=headers,
+        content=payload,
+        timeout=60,  # TODO this should not be hardcoded
+    ) as resp:
+        match resp.status_code:
+            case 200:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield Message.model_validate_json(text)
+            case 400 | 401 | 403 | 404 | 413 | 429:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield MessageError.model_validate_json(text)
+            case 500 | 529:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield MessageError.model_validate_json(text)
+            case _:
+                logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")
+                raise ValueError(f"unexpected status code {resp.status_code}")
+
+
 async def acompletion(request, api_key, base_url):
     headers = {
         "anthropic-version": "2023-06-01",
@@ -86,9 +137,13 @@ async def acompletion(request, api_key, base_url):
                 yield event
             case 400 | 401 | 403 | 404 | 413 | 429:
                 text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
                 yield MessageError.model_validate_json(text)
             case 500 | 529:
                 text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
                 yield MessageError.model_validate_json(text)
             case _:
                 logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")

src/codegate/types/anthropic/_request_models.py

Lines changed: 1 addition & 0 deletions
@@ -155,6 +155,7 @@ class ToolDef(pydantic.BaseModel):
     Literal["auto"],
     Literal["any"],
     Literal["tool"],
+    Literal["none"],
 ]

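For reference, the public Anthropic Messages API accepts `tool_choice` objects of type `auto`, `any`, `tool`, and `none`; the added literal lets requests that explicitly disable tool use pass validation. An illustrative payload (field names follow the public API, not necessarily codegate's internal models):

    payload = {
        "model": "claude-3-5-sonnet-latest",       # placeholder model
        "max_tokens": 256,
        "tool_choice": {"type": "none"},           # previously failed validation
        "messages": [{"role": "user", "content": "No tools, please."}],
    }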

src/codegate/types/generators.py

Lines changed: 18 additions & 28 deletions
@@ -1,37 +1,27 @@
-import os
 from typing import (
-    Any,
-    AsyncIterator,
+    Callable,
 )

-import pydantic
 import structlog

 logger = structlog.get_logger("codegate")


-# Since different providers typically use one of these formats for streaming
-# responses, we have a single stream generator for each format that is then plugged
-# into the adapter.
+def completion_handler_replacement(
+    completion_handler: Callable,
+):
+    async def _inner(
+        request,
+        base_url,
+        api_key,
+        stream=None,
+        is_fim_request=None,
+    ):
+        # Execute e.g. acompletion from Anthropic types
+        return completion_handler(
+            request,
+            api_key,
+            base_url,
+        )

-
-async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
-    """OpenAI-style SSE format"""
-    try:
-        async for chunk in stream:
-            if isinstance(chunk, pydantic.BaseModel):
-                # alternatively we might want to just dump the whole object
-                # this might even allow us to tighten the typing of the stream
-                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
-            try:
-                if os.getenv("CODEGATE_DEBUG_OPENAI") is not None:
-                    print(chunk)
-                yield f"data: {chunk}\n\n"
-            except Exception as e:
-                logger.error("failed generating output payloads", exc_info=e)
-                yield f"data: {str(e)}\n\n"
-    except Exception as e:
-        logger.error("failed generating output payloads", exc_info=e)
-        yield f"data: {str(e)}\n\n"
-    finally:
-        yield "data: [DONE]\n\n"
+    return _inner
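
A runnable sketch of what the adapter does, with a hypothetical provider function standing in for e.g. `acompletion`: it bridges the pipeline's `(request, base_url, api_key, ...)` calling convention to the provider's `(request, api_key, base_url)` order and drops the unused keyword arguments.

    import asyncio
    from codegate.types.generators import completion_handler_replacement

    # Hypothetical stand-in for a provider completion function; note its
    # (request, api_key, base_url) argument order.
    async def fake_acompletion(request, api_key, base_url):
        yield {"request": request, "api_key": api_key, "base_url": base_url}

    async def main() -> None:
        handler = completion_handler_replacement(fake_acompletion)
        # The pipeline calls handlers as (request, base_url, api_key, ...);
        # the wrapper re-orders the arguments for the wrapped function.
        stream = await handler("req", "https://example.invalid", "sk-test")
        async for item in stream:
            print(item)

    asyncio.run(main())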

src/codegate/types/ollama/_generators.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ async def stream_generator(
     try:
         async for chunk in stream:
             try:
-                body = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+                body = chunk.model_dump_json(exclude_unset=True)
                 data = f"{body}\n"

                 if os.getenv("CODEGATE_DEBUG_OLLAMA") is not None:

src/codegate/types/openai/_generators.py

Lines changed: 0 additions & 13 deletions
@@ -50,26 +50,13 @@ async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> As


 async def single_response_generator(
     first: ChatCompletion,
-    stream: AsyncIterator[ChatCompletion],
 ) -> AsyncIterator[ChatCompletion]:
     """Wraps a single response object in an AsyncIterator. This is
     meant to be used for non-streaming responses.

     """
     yield first.model_dump_json(exclude_none=True, exclude_unset=True)

-    # Note: this async for loop is necessary to force Python to return
-    # an AsyncIterator. This is necessary because of the wiring at the
-    # Provider level expecting an AsyncIterator rather than a single
-    # response payload.
-    #
-    # Refactoring this means adding a code path specific for when we
-    # expect single response payloads rather than an SSE stream.
-    async for item in stream:
-        if item:
-            logger.error("no further items were expected", item=item)
-        yield item.model_dump_json(exclude_none=True, exclude_unset=True)
-

 async def completions_streaming(request, api_key, base_url):
     if base_url is None:
