
Commit 3ee9608

Author: SentienceDEV

extend llm_provider to support vision; TS needed
1 parent 5482da0 commit 3ee9608

File tree

1 file changed: +216 −0 lines changed


sentience/llm_provider.py

@@ -81,6 +81,48 @@ def model_name(self) -> str:
         """
         pass

+    def supports_vision(self) -> bool:
+        """
+        Whether this provider supports image input for vision tasks.
+
+        Override in subclasses that support vision-capable models.
+
+        Returns:
+            True if provider supports vision, False otherwise
+        """
+        return False
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate a response with image input (for vision-capable models).
+
+        This method is used for vision fallback in assertions and visual agents.
+        Override in subclasses that support vision-capable models.
+
+        Args:
+            system_prompt: System instruction/context
+            user_prompt: User query/request
+            image_base64: Base64-encoded image (PNG or JPEG)
+            **kwargs: Provider-specific parameters (temperature, max_tokens, etc.)
+
+        Returns:
+            LLMResponse with content and token usage
+
+        Raises:
+            NotImplementedError: If provider doesn't support vision
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} does not support vision. "
+            "Use a vision-capable provider like OpenAIProvider with GPT-4o "
+            "or AnthropicProvider with Claude 3."
+        )
+

 class OpenAIProvider(LLMProvider):
     """
@@ -187,6 +229,92 @@ def supports_json_mode(self) -> bool:
         model_lower = self._model_name.lower()
         return any(x in model_lower for x in ["gpt-4", "gpt-3.5"])

+    def supports_vision(self) -> bool:
+        """GPT-4o, GPT-4-turbo, and GPT-4-vision support vision."""
+        model_lower = self._model_name.lower()
+        return any(x in model_lower for x in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision"])
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using OpenAI Vision API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature (0.0 = deterministic)
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional OpenAI API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use gpt-4o, gpt-4-turbo, or gpt-4-vision-preview."
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Vision message format with image_url
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": user_prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        )
+
+        # Build API parameters
+        api_params = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        # Merge additional parameters
+        api_params.update(kwargs)
+
+        # Call OpenAI API
+        try:
+            response = self.client.chat.completions.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "OpenAI", "generate response with image")
+
+        choice = response.choices[0]
+        usage = response.usage
+
+        return LLMResponseBuilder.from_openai_format(
+            content=choice.message.content,
+            prompt_tokens=usage.prompt_tokens if usage else None,
+            completion_tokens=usage.completion_tokens if usage else None,
+            total_tokens=usage.total_tokens if usage else None,
+            model_name=response.model,
+            finish_reason=choice.finish_reason,
+        )
+
     @property
     def model_name(self) -> str:
         return self._model_name
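
A usage sketch for the OpenAI path, assuming a vision-capable OpenAIProvider has already been constructed (its constructor is not part of this diff) and that LLMResponse exposes the generated text as .content; the encoding step uses only the standard library:

import base64

# Hypothetical usage; `provider` is an already-constructed OpenAIProvider
# backed by a vision-capable model such as gpt-4o (construction not shown here).
with open("screenshot.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("ascii")

response = provider.generate_with_image(
    system_prompt="Answer strictly from the screenshot.",
    user_prompt="Is the 'Submit' button visible and enabled?",
    image_base64=image_b64,
    max_tokens=300,
)
print(response.content)  # assumes LLMResponse carries the text as `.content`
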
@@ -277,6 +405,94 @@ def supports_json_mode(self) -> bool:
         """Anthropic doesn't have native JSON mode (requires prompt engineering)"""
         return False

+    def supports_vision(self) -> bool:
+        """Claude 3 models (Opus, Sonnet, Haiku) all support vision."""
+        model_lower = self._model_name.lower()
+        return any(x in model_lower for x in ["claude-3", "claude-3.5"])
+
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int = 1024,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using Anthropic Vision API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens to generate (required by Anthropic)
+            **kwargs: Additional Anthropic API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use Claude 3 models (claude-3-opus, claude-3-sonnet, claude-3-haiku)."
+            )
+
+        # Anthropic vision message format
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/png",
+                            "data": image_base64,
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": user_prompt,
+                    },
+                ],
+            }
+        ]
+
+        # Build API parameters
+        api_params = {
+            "model": self._model_name,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "messages": messages,
+        }
+
+        if system_prompt:
+            api_params["system"] = system_prompt
+
+        # Merge additional parameters
+        api_params.update(kwargs)
+
+        # Call Anthropic API
+        try:
+            response = self.client.messages.create(**api_params)
+        except Exception as e:
+            handle_provider_error(e, "Anthropic", "generate response with image")
+
+        content = response.content[0].text if response.content else ""
+
+        return LLMResponseBuilder.from_anthropic_format(
+            content=content,
+            input_tokens=response.usage.input_tokens if hasattr(response, "usage") else None,
+            output_tokens=response.usage.output_tokens if hasattr(response, "usage") else None,
+            model_name=response.model,
+            stop_reason=response.stop_reason,
+        )
+
     @property
     def model_name(self) -> str:
         return self._model_name
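
One caveat in this commit: both docstrings accept "PNG or JPEG", but both payloads hard-code the media type to image/png (in the data: URL for OpenAI, in media_type for Anthropic), and the Anthropic API can reject requests whose declared media_type doesn't match the actual bytes. A caller passing JPEG data would therefore want to sniff the real format first; a hypothetical helper, not part of this commit, using the PNG/JPEG magic bytes:

import base64

# Hypothetical helper, not part of this commit: sniff the actual image format
# before building a provider payload, since the methods above declare image/png.
def detect_media_type(image_b64: str) -> str:
    """Return "image/png" or "image/jpeg" based on the decoded magic bytes."""
    header = base64.b64decode(image_b64[:16])  # 16 base64 chars -> 12 raw bytes
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if header.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    raise ValueError("Unsupported image format; expected PNG or JPEG")
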
