44 changes: 28 additions & 16 deletions langfun/core/llms/anthropic.py
@@ -54,8 +54,9 @@ class RateLimits(lf.ModelInfo.RateLimits):
 
     @property
     def max_tokens_per_minute(self) -> int:
-      return (self.max_input_tokens_per_minute
-              + self.max_output_tokens_per_minute)
+      return (
+          self.max_input_tokens_per_minute + self.max_output_tokens_per_minute
+      )
 
 
 SUPPORTED_MODELS = [
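
This first hunk is a pure reflow, but the property it touches encodes a small piece of arithmetic worth spelling out: Anthropic publishes separate input and output tokens-per-minute limits, and the combined budget is simply their sum. A minimal self-contained sketch of the same computation, using hypothetical limit values rather than langfun's real model metadata:

```python
# Sketch of the combined rate-limit arithmetic in max_tokens_per_minute;
# the limit values below are hypothetical, not Anthropic's published numbers.
import dataclasses


@dataclasses.dataclass
class RateLimits:
  max_input_tokens_per_minute: int
  max_output_tokens_per_minute: int

  @property
  def max_tokens_per_minute(self) -> int:
    return self.max_input_tokens_per_minute + self.max_output_tokens_per_minute


limits = RateLimits(
    max_input_tokens_per_minute=400_000,
    max_output_tokens_per_minute=80_000,
)
assert limits.max_tokens_per_minute == 480_000
```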
@@ -839,9 +840,7 @@ class Anthropic(rest.REST):
   """
 
   model: pg.typing.Annotated[
-      pg.typing.Enum(
-          pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
-      ),
+      pg.typing.Enum(pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]),
       'The name of the model to use.',
   ]
 
@@ -855,10 +854,7 @@ class Anthropic(rest.REST):
 
   api_endpoint: str = 'https://api.anthropic.com/v1/messages'
 
-  api_version: Annotated[
-      str,
-      'Anthropic API version.'
-  ] = '2023-06-01'
+  api_version: Annotated[str, 'Anthropic API version.'] = '2023-06-01'
 
   thinking: Annotated[
       bool | None,
@@ -912,9 +908,7 @@ def _use_adaptive_thinking(self) -> bool:
     return self.model is not None and 'claude-opus-4-7' in self.model_id
 
   def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
   ) -> dict[str, Any]:
     """Returns the JSON input for a message."""
     request = dict()
@@ -1022,14 +1016,27 @@ def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
 
   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
     message = lf.Message.from_value(json, format='anthropic')
-    input_tokens = json['usage']['input_tokens']
-    output_tokens = json['usage']['output_tokens']
+    usage = json.get('usage', {})
+    input_tokens = usage.get('input_tokens', 0)
+    output_tokens = usage.get('output_tokens', 0)
+    cache_read_tokens = usage.get('cache_read_input_tokens', 0)
+    cache_creation_tokens = usage.get('cache_creation_input_tokens', 0)
+
+    # Anthropic's input_tokens excludes cache hits. Total prompt tokens
+    # comprise both cached and uncached segments.
+    prompt_tokens = input_tokens + cache_read_tokens + cache_creation_tokens
+
     return lf.LMSamplingResult(
         [lf.LMSample(message)],
         usage=lf.LMSamplingUsage(
-            prompt_tokens=input_tokens,
+            prompt_tokens=prompt_tokens,
             completion_tokens=output_tokens,
-            total_tokens=input_tokens + output_tokens,
+            total_tokens=prompt_tokens + output_tokens,
+            cached_prompt_tokens=cache_read_tokens,
+            completion_tokens_details={
+                'cache_creation_input_tokens': cache_creation_tokens,
+                'cache_read_input_tokens': cache_read_tokens,
+            },
         ),
     )
 
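The substantive change in this PR is the hunk above. Anthropic's `usage.input_tokens` counts only the uncached portion of the prompt; tokens served from the prompt cache (`cache_read_input_tokens`) and tokens written to it (`cache_creation_input_tokens`) are reported separately, so the old code under-counted prompt tokens on cached requests. A self-contained sketch of the corrected arithmetic, with an illustrative payload (field names follow Anthropic's Messages API usage object; the counts are made up):

```python
# Standalone sketch of the prompt-token arithmetic in result() above.
# The payload is illustrative; field names follow Anthropic's Messages API.
from typing import Any


def prompt_token_totals(usage: dict[str, Any]) -> tuple[int, int]:
  """Returns (prompt_tokens, total_tokens) for an Anthropic usage dict."""
  input_tokens = usage.get('input_tokens', 0)  # uncached prompt segment
  cache_read = usage.get('cache_read_input_tokens', 0)  # cache hits
  cache_creation = usage.get('cache_creation_input_tokens', 0)  # cache writes
  output_tokens = usage.get('output_tokens', 0)
  prompt_tokens = input_tokens + cache_read + cache_creation
  return prompt_tokens, prompt_tokens + output_tokens


usage = {
    'input_tokens': 12,
    'cache_read_input_tokens': 2048,
    'cache_creation_input_tokens': 512,
    'output_tokens': 300,
}
# 12 uncached + 2048 cache-read + 512 cache-write = 2572 prompt tokens.
assert prompt_token_totals(usage) == (2572, 2872)
```

Before this change a heavily cached request would report 12 prompt tokens instead of 2572, skewing any cost or rate-limit bookkeeping layered on `LMSamplingUsage`.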
@@ -1118,21 +1125,25 @@ class Claude35(Anthropic):
 
 class Claude35Sonnet(Claude35):
   """Claude 3.5 Sonnet model (latest)."""
+
   model = 'claude-3-5-sonnet-latest'
 
 
 class Claude35Sonnet_20241022(Claude35):  # pylint: disable=invalid-name
   """Claude 3.5 Sonnet model (10/22/2024)."""
+
   model = 'claude-3-5-sonnet-20241022'
 
 
 class Claude35Haiku(Claude35):
   """Claude 3.5 Haiku model (latest)."""
+
   model = 'claude-3-5-haiku-latest'
 
 
 class Claude35Haiku_20241022(Claude35):  # pylint: disable=invalid-name
   """Claude 3.5 Haiku model (10/22/2024)."""
+
   model = 'claude-3-5-haiku-20241022'
 
 
@@ -1182,4 +1193,5 @@ def _register_anthropic_models():
     if m.provider == 'Anthropic':
       lf.LanguageModel.register(m.model_id, Anthropic)
 
+
 _register_anthropic_models()
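
The final hunk just adds a blank line before the module-level call, but the call itself is what wires everything up: `_register_anthropic_models()` walks `SUPPORTED_MODELS` and maps each Anthropic model id to the `Anthropic` handler class. A self-contained sketch of that id-to-handler registry pattern (the names here are illustrative, not langfun's actual `LanguageModel.register` implementation):

```python
# Illustrative id -> handler registry, mirroring the pattern behind
# _register_anthropic_models(); names are hypothetical, not langfun's.
_MODEL_REGISTRY: dict[str, type] = {}


def register(model_id: str, handler_cls: type) -> None:
  """Maps a model id string to the class that can serve it."""
  _MODEL_REGISTRY[model_id] = handler_cls


class AnthropicHandler:
  def __init__(self, model: str):
    self.model = model


for model_id in ('claude-3-5-sonnet-latest', 'claude-3-5-haiku-latest'):
  register(model_id, AnthropicHandler)

# A later lookup constructs the right handler from a bare model id string.
lm = _MODEL_REGISTRY['claude-3-5-sonnet-latest'](model='claude-3-5-sonnet-latest')
assert lm.model == 'claude-3-5-sonnet-latest'
```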