Skip to content

Commit fe7c47d

Browse files
Merge pull request #262 from dreadnode/fix/ollama-truncation
fix: Ollama Truncation
2 parents c491fb1 + 27d48bc commit fe7c47d

10 files changed

Lines changed: 131 additions & 40 deletions

File tree

.secrets.baseline

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@
124124
},
125125
{
126126
"path": "detect_secrets.filters.heuristic.is_templated_secret"
127+
},
128+
{
129+
"path": "detect_secrets.filters.regex.should_exclude_file",
130+
"pattern": [
131+
"examples/*"
132+
]
127133
}
128134
],
129135
"results": {
@@ -133,14 +139,14 @@
133139
"filename": "docs/topics/generators.mdx",
134140
"hashed_secret": "ef5225a03e4f9cc953ab3c4dd41f5c4db7dc2e5b",
135141
"is_verified": false,
136-
"line_number": 342
142+
"line_number": 360
137143
},
138144
{
139145
"type": "Secret Keyword",
140146
"filename": "docs/topics/generators.mdx",
141147
"hashed_secret": "eb6256c862c356b375aafa760fa1851e33aa62a9",
142148
"is_verified": false,
143-
"line_number": 366
149+
"line_number": 384
144150
}
145151
],
146152
"tests/test_http_spec.py": [
@@ -153,5 +159,5 @@
153159
}
154160
]
155161
},
156-
"generated_at": "2025-07-23T00:26:51Z"
162+
"generated_at": "2025-09-03T21:56:13Z"
157163
}

docs/api/error.mdx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ max_rounds = max_rounds
6868

6969
The number of rounds which was exceeded.
7070

71+
GeneratorWarning
72+
----------------
73+
74+
Base class for all generator warnings.
75+
76+
This is used to indicate that something unexpected happened during the generator execution,
77+
but it is not critical enough to stop the execution.
78+
7179
InvalidGeneratorError
7280
---------------------
7381

docs/api/generator.mdx

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,7 +1934,9 @@ get\_generator
19341934

19351935
```python
19361936
get_generator(
1937-
identifier: str, *, params: GenerateParams | None = None
1937+
identifier: str,
1938+
*,
1939+
params: GenerateParams | dict[str, Any] | None = None,
19381940
) -> Generator
19391941
```
19401942

@@ -1964,7 +1966,7 @@ You can also specify arguments to the generator by comma-separating them:
19641966
(`str`)
19651967
–The identifier string to use to get a generator.
19661968
* **`params`**
1967-
(`GenerateParams | None`, default:
1969+
(`GenerateParams | dict[str, Any] | None`, default:
19681970
`None`
19691971
)
19701972
–The generation parameters to use for the generator.
@@ -1982,8 +1984,9 @@ You can also specify arguments to the generator by comma-separating them:
19821984

19831985
<Accordion title="Source code in rigging/generator/base.py" icon="code">
19841986
```python
1985-
@lru_cache(maxsize=128)
1986-
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
1987+
def get_generator(
1988+
identifier: str, *, params: GenerateParams | dict[str, t.Any] | None = None
1989+
) -> Generator:
19871990
"""
19881991
Get a generator by an identifier string. Uses LiteLLM by default.
19891992
@@ -2080,6 +2083,8 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
20802083
if isinstance(v, str) and v.lower() in ["true", "false"]:
20812084
init_kwargs[k] = v.lower() == "true"
20822085

2086+
params = GenerateParams(**params) if isinstance(params, dict) else params
2087+
20832088
try:
20842089
merged_params = GenerateParams(**kwargs).merge_with(params)
20852090
except Exception as e:

docs/api/model.mdx

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,15 @@ def from_text(
237237
# Walk through any fields which are strings, and dedent them
238238

239239
for field_name, field_info in cls.model_fields.items():
240-
if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721
240+
if isinstance(field_info, XmlEntityInfo) and field_info.annotation is str:
241241
model.__dict__[field_name] = textwrap.dedent(
242242
model.__dict__[field_name]
243243
).strip()
244244

245245
extracted.append((model, slice_))
246246
except Exception as e: # noqa: BLE001
247247
extracted.append((e, slice_))
248-
continue
248+
continue
249249

250250
# sort back to original order
251251
extracted.sort(key=lambda x: x[1].start)
@@ -471,11 +471,7 @@ def preprocess_with_cdata(cls, content: str) -> str:
471471
}
472472
else:
473473
field_map = {
474-
(
475-
field_info.path
476-
if isinstance(field_info, XmlEntityInfo) and field_info.path
477-
else field_name
478-
): field_info
474+
cls._get_field_xml_name(field_name, field_info): field_info
479475
for field_name, field_info in cls.model_fields.items()
480476
if isinstance(field_info, XmlEntityInfo)
481477
and field_info.location == EntityLocation.ELEMENT
@@ -722,7 +718,7 @@ def xml_example(cls) -> str:
722718
isinstance(field_info, XmlEntityInfo)
723719
and field_info.location == EntityLocation.ATTRIBUTE
724720
):
725-
path = field_info.path or field_name
721+
path = cls._get_field_xml_name(field_name, field_info)
726722
example = str(next(iter(field_info.examples or []), "")).replace('"', "&quot;")
727723
attribute_parts.append(f'{path}="{example}"')
728724
else:
@@ -732,7 +728,7 @@ def xml_example(cls) -> str:
732728
lines.append(f"<{cls.__xml_tag__}{attr_string}>")
733729

734730
for field_name, field_info in element_fields.items():
735-
path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name
731+
path = cls._get_field_xml_name(field_name, field_info)
736732
description = field_info.description
737733
example = str(next(iter(field_info.examples or []), ""))
738734

@@ -912,12 +908,16 @@ def make_from_schema(
912908
for field_name, field_schema in properties.items():
913909
field_type, field_info = _process_field(field_name, field_schema)
914910

911+
# Use the field name as alias if it differs from python naming conventions
912+
alias = field_name if field_name != field_name.replace("-", "_") else None
913+
915914
fields[field_name] = (
916915
field_type,
917916
field_cls(
918917
default=... if field_name in required else None,
919918
description=field_schema.get("description", ""),
920919
title=field_schema.get("title", ""),
920+
alias=alias,
921921
**field_info,
922922
)
923923
if isinstance(field_info, dict)

docs/topics/generators.mdx

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ openai/o3-mini
2222
gemini/gemini-2.5-pro
2323
claude-4-sonnet-latest
2424
vllm_hosted/meta-llama/Llama-3.1-8B-Instruct
25-
ollama/qwen3
25+
ollama_chat/qwen3
2626
2727
openai/gpt-4,api_key=sk-1234
2828
anthropic/claude-3-7-haiku-latest,stop=output:;---,seed=1337
@@ -184,12 +184,30 @@ NAME ID SIZE PROCESSOR UNTIL
184184
qwen3:0.6b 7df6b6e09427 2.3 GB 100% GPU 4 minutes from now
185185
```
186186

187-
Using this model in Rigging is as simple as using the `ollama/` or `ollama_chat/` prefixes:
187+
<Warning>
188+
Ollama is configured with a maximum context length on the server — 4096 tokens by default. This limit does not change based on the model and must be updated through server configuration.
189+
190+
If the input messages to the API would exceed this length, Ollama will silently truncate them to fit in the context window. This behavior can cause unexpected generation results due to missing context and is very difficult to detect in Rigging.
191+
192+
We make a best effort by monitoring model responses and checking whether the reported input token count is far less than the size of the input messages we just sent. If truncation is observed, the following warning will be emitted.
193+
194+
```
195+
GeneratorWarning: Input messages may have been truncated ...
196+
```
197+
198+
When in doubt, monitor the Ollama server logs for the following:
199+
200+
```bash
201+
... msg="truncating input prompt" limit=4096 prompt=6767 keep=4 new=409
202+
```
203+
</Warning>
204+
205+
Using this model in Rigging is as simple as using the `ollama_chat/` (recommended) or `ollama/` prefixes:
188206

189207
```python
190208
import rigging as rg
191209

192-
qwen = rg.get_generator("ollama/qwen3:0.6b")
210+
qwen = rg.get_generator("ollama_chat/qwen3:0.6b")
193211

194212
chat = await qwen.chat("Hello!").run()
195213
print(chat.conversation)
@@ -211,7 +229,7 @@ If you are running the Ollama server somewhere besides localhost, just pass the
211229

212230
```python
213231
qwen = rg.get_generator(
214-
"ollama/qwen3:0.6b,api_base=http://remote-server:11434"
232+
"ollama_chat/qwen3:0.6b,api_base=http://remote-server:11434"
215233
)
216234
```
217235
</Note>

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "rigging"
3-
version = "3.3.2"
3+
version = "3.3.3"
44
description = "LLM Interaction Framework"
55
authors = ["Nick Landers <monoxgas@gmail.com>"]
66
license = "MIT"

rigging/error.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ class TokenizerWarning(Warning):
8282
"""
8383

8484

85+
class GeneratorWarning(Warning):
86+
"""
87+
Base class for all generator warnings.
88+
89+
This is used to indicate that something unexpected happened during the generator execution,
90+
but it is not critical enough to stop the execution.
91+
"""
92+
93+
8594
# System Exceptions
8695

8796

rigging/generator/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import functools
55
import inspect
66
import typing as t
7-
from functools import lru_cache
87

98
from loguru import logger
109
from pydantic import (
@@ -738,8 +737,9 @@ def encode_value(val: t.Any) -> t.Any:
738737
return identifier
739738

740739

741-
@lru_cache(maxsize=128)
742-
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
740+
def get_generator(
741+
identifier: str, *, params: GenerateParams | dict[str, t.Any] | None = None
742+
) -> Generator:
743743
"""
744744
Get a generator by an identifier string. Uses LiteLLM by default.
745745
@@ -836,6 +836,8 @@ def decode_value(value: str) -> t.Any:
836836
if isinstance(v, str) and v.lower() in ["true", "false"]:
837837
init_kwargs[k] = v.lower() == "true"
838838

839+
params = GenerateParams(**params) if isinstance(params, dict) else params
840+
839841
try:
840842
merged_params = GenerateParams(**kwargs).merge_with(params)
841843
except Exception as e:

rigging/generator/litellm_.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import datetime
44
import re
55
import typing as t
6+
import warnings
67

78
from loguru import logger
89

10+
from rigging.error import GeneratorWarning
911
from rigging.generator.base import (
1012
Fixup,
1113
GeneratedMessage,
@@ -206,7 +208,7 @@ async def supports_function_calling(self) -> bool | None:
206208
):
207209
self._supports_function_calling = True
208210
except Exception as e: # noqa: BLE001
209-
logger.warning(f"Failed to check for function calling support: {e}")
211+
logger.warning(f"Error checking for function calling support: {e}")
210212
span.set_attribute("error", str(e))
211213

212214
span.set_attribute("supports_function_calling", self._supports_function_calling)
@@ -236,6 +238,37 @@ async def _ensure_delay_between_requests(self) -> None:
236238
# This seems like a brittle feature at the moment, so we'll
237239
# leave it out for now.
238240

241+
def _warn_on_input_truncation(
242+
self, messages: list[Message], response: "GeneratedMessage"
243+
) -> None:
244+
# Ollama has a known behavior where it performs silent truncation
245+
# of input messages rather than return an error or any API indication.
246+
#
247+
# This code attempts to detect such truncation by comparing the expected
248+
# input length with the reported usage - but it's not foolproof.
249+
#
250+
# See:
251+
# - https://github.com/ollama/ollama/issues/7043
252+
# - https://github.com/ollama/ollama/issues/7987
253+
# - https://github.com/ollama/ollama/issues/4967
254+
255+
# We can't check without usage info
256+
if not response.usage:
257+
return
258+
259+
# Get a general view of how long we might expect the input prompt to be
260+
# We'll use a generous 4 chars-per-token estimate
261+
input_tokens_estimate = int(sum(len(message.content) for message in messages) / 4)
262+
263+
# Check if the response reports that accepted input tokens are less than this
264+
if response.usage.input_tokens < input_tokens_estimate:
265+
warnings.warn(
266+
f"Input messages may have been truncated - see https://github.com/ollama/ollama/issues/7043 "
267+
f"(input tokens: {response.usage.input_tokens} < estimate: {input_tokens_estimate})",
268+
GeneratorWarning,
269+
stacklevel=2,
270+
)
271+
239272
def _parse_model_response(
240273
self,
241274
response: "ModelResponse",
@@ -359,7 +392,9 @@ async def _generate_message(
359392
)
360393

361394
self._last_request_time = datetime.datetime.now(tz=datetime.timezone.utc)
362-
return self._parse_model_response(response)
395+
generated = self._parse_model_response(response)
396+
self._warn_on_input_truncation(list(messages), generated)
397+
return generated
363398

364399
async def _generate_text(self, text: str, params: GenerateParams) -> GeneratedText:
365400
import litellm

0 commit comments

Comments
 (0)