src/strands/models/anthropic.py (6 changes: 3 additions & 3 deletions)

@@ -58,8 +58,8 @@ class AnthropicConfig(BaseModelConfig, total=False):
         params: Additional model parameters (e.g., temperature).
             For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
         use_native_token_count: Whether to use the native Anthropic count_tokens API.
-            When True (default), count_tokens() calls the Anthropic API for accurate counts.
-            When False, skips the API call and uses the local estimator.
+            When True, count_tokens() calls the Anthropic API for accurate counts.
+            When False (default), skips the API call and uses the local estimator.
     """
 
     max_tokens: Required[int]
@@ -398,7 +398,7 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
-        if self.config.get("use_native_token_count") is False:
+        if self.config.get("use_native_token_count") is not True:
             return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
 
         try:
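Worth calling out why the guard flipped from `is False` to `is not True`: `dict.get` returns None when the key is unset, and `None is False` evaluates to False, so the old guard still routed unconfigured models to the native API. Checking `is not True` makes an unset key, None, and False all take the estimator path, which is what actually enforces the new opt-in default. A standalone sketch of the three cases (plain Python, independent of the strands code):

    # Standalone illustration: dict.get returns None for a missing key,
    # and identity checks against True/False treat None as neither.
    for config in ({}, {"use_native_token_count": False}, {"use_native_token_count": True}):
        value = config.get("use_native_token_count")
        old_skips_api = value is False      # missing key -> False: old guard still called the API
        new_skips_api = value is not True   # missing key -> True: new guard uses the local estimator
        print(config, old_skips_api, new_skips_api)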
src/strands/models/bedrock.py (6 changes: 3 additions & 3 deletions)

@@ -118,8 +118,8 @@ class BedrockConfig(BaseModelConfig, total=False):
         temperature: Controls randomness in generation (higher = more random)
         top_p: Controls diversity via nucleus sampling (alternative to temperature)
         use_native_token_count: Whether to use the native Bedrock CountTokens API.
-            When True (default), count_tokens() calls the Bedrock API for accurate counts.
-            When False, skips the API call and uses the local estimator.
+            When True, count_tokens() calls the Bedrock API for accurate counts.
+            When False (default), skips the API call and uses the local estimator.
     """
 
     additional_args: dict[str, Any] | None
@@ -798,7 +798,7 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
-        if self.config.get("use_native_token_count") is False:
+        if self.config.get("use_native_token_count") is not True:
             return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
 
         model_id: str = self.config["model_id"]
src/strands/models/gemini.py (6 changes: 3 additions & 3 deletions)

@@ -50,8 +50,8 @@ class GeminiConfig(BaseModelConfig, total=False):
             For a complete list of supported tools, see
             https://ai.google.dev/api/caching#Tool
         use_native_token_count: Whether to use the native Gemini count_tokens API.
-            When True (default), count_tokens() calls the Gemini API for accurate counts.
-            When False, skips the API call and uses the local estimator.
+            When True, count_tokens() calls the Gemini API for accurate counts.
+            When False (default), skips the API call and uses the local estimator.
     """
 
     model_id: Required[str]
@@ -461,7 +461,7 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
-        if self.config.get("use_native_token_count") is False:
+        if self.config.get("use_native_token_count") is not True:
             return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
 
         try:
src/strands/models/llamacpp.py (6 changes: 3 additions & 3 deletions)

@@ -126,8 +126,8 @@ class LlamaCppConfig(BaseModelConfig, total=False):
             - slot_id: Slot ID for parallel inference
             - samplers: Custom sampler order
         use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint.
-            When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts.
-            When False, skips the API call and uses the local estimator.
+            When True, count_tokens() calls the server's tokenize endpoint for accurate counts.
+            When False (default), skips the API call and uses the local estimator.
     """
 
     model_id: str
@@ -537,7 +537,7 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
-        if self.config.get("use_native_token_count") is False:
+        if self.config.get("use_native_token_count") is not True:
             return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
 
         try:
src/strands/models/openai_responses.py (6 changes: 3 additions & 3 deletions)

@@ -137,8 +137,8 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
             When True, the server stores conversation history and the client does not need to
             send the full message history with each request. Defaults to False.
         use_native_token_count: Whether to use the native OpenAI input_tokens.count API.
-            When True (default), count_tokens() calls the OpenAI API for accurate counts.
-            When False, skips the API call and uses the local estimator.
+            When True, count_tokens() calls the OpenAI API for accurate counts.
+            When False (default), skips the API call and uses the local estimator.
     """
 
     model_id: str
@@ -242,7 +242,7 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
-        if self.config.get("use_native_token_count") is False:
+        if self.config.get("use_native_token_count") is not True:
             return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
 
         try:
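With all five providers updated, native token counting is now opt-in everywhere. A minimal usage sketch, assuming the constructor signatures shown in this diff; the model ID and message shape below are illustrative, not taken from this PR:

    import asyncio

    from strands.models.bedrock import BedrockModel

    async def main() -> None:
        messages = [{"role": "user", "content": [{"text": "Hello!"}]}]

        # New default: count_tokens() uses the local estimator, no network call.
        local = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")
        print(await local.count_tokens(messages=messages))

        # Explicit opt-in: count_tokens() calls the Bedrock CountTokens API.
        native = BedrockModel(
            model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
            use_native_token_count=True,
        )
        print(await native.count_tokens(messages=messages))

    asyncio.run(main())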
tests/strands/models/test_anthropic.py (13 changes: 12 additions & 1 deletion)

@@ -1072,7 +1072,7 @@ class TestCountTokens:
     @pytest.fixture
     def model_with_client(self, anthropic_client, model_id, max_tokens):
         _ = anthropic_client
-        return AnthropicModel(model_id=model_id, max_tokens=max_tokens)
+        return AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=True)
 
     @pytest.fixture
     def messages(self):
@@ -1175,3 +1175,14 @@ async def test_skip_native_api_when_use_native_token_count_false(
         anthropic_client.messages.count_tokens.assert_not_called()
         assert isinstance(result, int)
         assert result >= 0
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_by_default(self, anthropic_client, model_id, max_tokens, messages):
+        _ = anthropic_client
+        model = AnthropicModel(model_id=model_id, max_tokens=max_tokens)
+
+        result = await model.count_tokens(messages=messages)
+
+        anthropic_client.messages.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
tests/strands/models/test_bedrock.py (19 changes: 15 additions & 4 deletions)

@@ -3343,7 +3343,7 @@ def clean_cache(self):
     @pytest.fixture
     def model_with_client(self, bedrock_client, model_id):
         _ = bedrock_client
-        return BedrockModel(model_id=model_id)
+        return BedrockModel(model_id=model_id, use_native_token_count=True)
 
     @pytest.fixture
     def messages(self):
@@ -3459,7 +3459,7 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess
 
     @pytest.mark.asyncio
     async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages):
-        model = BedrockModel(model_id="unsupported-cache-test-model")
+        model = BedrockModel(model_id="unsupported-cache-test-model", use_native_token_count=True)
         bedrock_client.count_tokens.side_effect = ClientError(
             {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}},
             "CountTokens",
@@ -3475,7 +3475,7 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages):
 
     @pytest.mark.asyncio
     async def test_caches_model_id_when_access_denied(self, bedrock_client, messages):
-        model = BedrockModel(model_id="access-denied-cache-test-model")
+        model = BedrockModel(model_id="access-denied-cache-test-model", use_native_token_count=True)
         bedrock_client.count_tokens.side_effect = ClientError(
             {
                 "Error": {
@@ -3523,7 +3523,7 @@ async def test_access_denied_logs_warning_with_full_error(
 
     @pytest.mark.asyncio
     async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
-        model = BedrockModel(model_id="transient-error-test-model")
+        model = BedrockModel(model_id="transient-error-test-model", use_native_token_count=True)
         bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error")
 
         await model.count_tokens(messages=messages)
@@ -3543,3 +3543,14 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c
         bedrock_client.count_tokens.assert_not_called()
         assert isinstance(result, int)
         assert result >= 0
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_by_default(self, bedrock_client, model_id, messages):
+        _ = bedrock_client
+        model = BedrockModel(model_id=model_id)
+
+        result = await model.count_tokens(messages=messages)
+
+        bedrock_client.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
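For reviewers unfamiliar with the pattern in the Bedrock tests above: botocore's ClientError takes an error-response dict and an operation name, which is how the unsupported-model and access-denied paths are simulated without a real endpoint. A compact standalone example of the same construction:

    from botocore.exceptions import ClientError

    # Same construction the tests above use to simulate a CountTokens failure.
    error = ClientError(
        {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}},
        "CountTokens",
    )
    print(error.response["Error"]["Code"])  # -> "ValidationException"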
tests/strands/models/test_gemini.py (13 changes: 12 additions & 1 deletion)

@@ -1144,7 +1144,7 @@ def gemini_client(self):
     @pytest.fixture
     def model(self, gemini_client):
         _ = gemini_client
-        return GeminiModel(model_id="m1")
+        return GeminiModel(model_id="m1", use_native_token_count=True)
 
     @pytest.fixture
     def messages(self):
@@ -1239,3 +1239,14 @@ async def test_skip_native_api_when_use_native_token_count_false(self, gemini_cl
         gemini_client.aio.models.count_tokens.assert_not_called()
         assert isinstance(result, int)
         assert result >= 0
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_by_default(self, gemini_client, messages):
+        _ = gemini_client
+        model = GeminiModel(model_id="m1")
+
+        result = await model.count_tokens(messages=messages)
+
+        gemini_client.aio.models.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
tests/strands/models/test_llamacpp.py (13 changes: 12 additions & 1 deletion)

@@ -713,7 +713,7 @@ class TestCountTokens:
 
     @pytest.fixture
     def model(self):
-        return LlamaCppModel(base_url="http://localhost:8080")
+        return LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=True)
 
     @pytest.fixture
     def messages(self):
@@ -814,3 +814,14 @@ async def test_skip_native_api_when_use_native_token_count_false(self, messages)
         model.client.post.assert_not_called()
         assert isinstance(result, int)
         assert result >= 0
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_by_default(self, messages):
+        model = LlamaCppModel(base_url="http://localhost:8080")
+        model.client.post = AsyncMock()
+
+        result = await model.count_tokens(messages=messages)
+
+        model.client.post.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
tests/strands/models/test_openai_responses.py (13 changes: 12 additions & 1 deletion)

@@ -1224,7 +1224,7 @@ def openai_client(self):
     @pytest.fixture
     def model(self, openai_client):
         _ = openai_client
-        return OpenAIResponsesModel(model_id="gpt-4o")
+        return OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=True)
 
     @pytest.fixture
     def messages(self):
@@ -1329,6 +1329,17 @@ async def test_skip_native_api_when_use_native_token_count_false(self, openai_cl
         assert isinstance(result, int)
         assert result >= 0
 
+    @pytest.mark.asyncio
+    async def test_skip_native_api_by_default(self, openai_client, messages):
+        _ = openai_client
+        model = OpenAIResponsesModel(model_id="gpt-4o")
+
+        result = await model.count_tokens(messages=messages)
+
+        openai_client.responses.input_tokens.count.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
+
 
 # =============================================================================
 # Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel