
Commit e0da839

more llm providers: GLM & Gemini

1 parent: 559c7e7

File tree: 1 file changed (+206, -0)


sentience/llm_provider.py

Lines changed: 206 additions & 0 deletions
@@ -263,6 +263,212 @@ def model_name(self) -> str:
        return self._model_name


class GLMProvider(LLMProvider):
    """
    Zhipu AI GLM provider implementation (GLM-4, GLM-4-Plus, etc.)

    Requirements:
        pip install zhipuai

    Example:
        >>> from sentience.llm_provider import GLMProvider
        >>> llm = GLMProvider(api_key="your-api-key", model="glm-4-plus")
        >>> response = llm.generate("You are a helpful assistant", "Hello!")
        >>> print(response.content)
    """

    def __init__(self, api_key: str | None = None, model: str = "glm-4-plus"):
        """
        Initialize GLM provider

        Args:
            api_key: Zhipu AI API key (or set GLM_API_KEY env var)
            model: Model name (glm-4-plus, glm-4, glm-4-air, glm-4-flash, etc.)
        """
        try:
            from zhipuai import ZhipuAI
        except ImportError:
            raise ImportError("ZhipuAI package not installed. Install with: pip install zhipuai")

        if api_key is None:
            import os

            # Fall back to the GLM_API_KEY environment variable documented above
            api_key = os.getenv("GLM_API_KEY")

        self.client = ZhipuAI(api_key=api_key)
        self._model_name = model

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.0,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Generate response using GLM API

        Args:
            system_prompt: System instruction
            user_prompt: User query
            temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
            max_tokens: Maximum tokens to generate
            **kwargs: Additional GLM API parameters

        Returns:
            LLMResponse object
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        # Build API parameters
        api_params = {
            "model": self._model_name,
            "messages": messages,
            "temperature": temperature,
        }

        if max_tokens:
            api_params["max_tokens"] = max_tokens

        # Merge additional parameters
        api_params.update(kwargs)

        # Call GLM API
        response = self.client.chat.completions.create(**api_params)

        choice = response.choices[0]
        usage = response.usage

        return LLMResponse(
            content=choice.message.content,
            prompt_tokens=usage.prompt_tokens if usage else None,
            completion_tokens=usage.completion_tokens if usage else None,
            total_tokens=usage.total_tokens if usage else None,
            model_name=response.model,
            finish_reason=choice.finish_reason,
        )

    def supports_json_mode(self) -> bool:
        """GLM-4 models support JSON mode"""
        return "glm-4" in self._model_name.lower()

    @property
    def model_name(self) -> str:
        return self._model_name
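
The generate() method above forwards any extra keyword arguments straight into the Zhipu chat completions call, so provider-specific options can ride along with the portable ones. A minimal usage sketch follows; the top_p option is an assumption about the GLM endpoint rather than something this diff exercises:

# Illustrative only: exercising the new GLMProvider with a pass-through option.
from sentience.llm_provider import GLMProvider

llm = GLMProvider(api_key="your-api-key", model="glm-4-flash")

response = llm.generate(
    system_prompt="You are a terse assistant.",
    user_prompt="Summarize PEP 8 in one sentence.",
    temperature=0.2,
    max_tokens=128,
    top_p=0.9,  # assumption: forwarded via **kwargs to the GLM chat completions API
)

print(response.content)
print(response.total_tokens)  # populated from the usage block when the API returns one
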
class GeminiProvider(LLMProvider):
    """
    Google Gemini provider implementation (Gemini 2.0, Gemini 1.5 Pro, etc.)

    Requirements:
        pip install google-generativeai

    Example:
        >>> from sentience.llm_provider import GeminiProvider
        >>> llm = GeminiProvider(api_key="your-api-key", model="gemini-2.0-flash-exp")
        >>> response = llm.generate("You are a helpful assistant", "Hello!")
        >>> print(response.content)
    """

    def __init__(self, api_key: str | None = None, model: str = "gemini-2.0-flash-exp"):
        """
        Initialize Gemini provider

        Args:
            api_key: Google API key (or set GEMINI_API_KEY or GOOGLE_API_KEY env var)
            model: Model name (gemini-2.0-flash-exp, gemini-1.5-pro, gemini-1.5-flash, etc.)
        """
        try:
            import google.generativeai as genai
        except ImportError:
            raise ImportError(
                "Google Generative AI package not installed. Install with: pip install google-generativeai"
            )

        # Configure API key
        if api_key:
            genai.configure(api_key=api_key)
        else:
            import os

            api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
            if api_key:
                genai.configure(api_key=api_key)

        self.genai = genai
        self._model_name = model
        self.model = genai.GenerativeModel(model)

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.0,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Generate response using Gemini API

        Args:
            system_prompt: System instruction
            user_prompt: User query
            temperature: Sampling temperature (0.0 = deterministic, 2.0 = very creative)
            max_tokens: Maximum tokens to generate
            **kwargs: Additional Gemini API parameters

        Returns:
            LLMResponse object
        """
        # Combine system and user prompts (Gemini doesn't have separate system role in all versions)
        full_prompt = f"{system_prompt}\n\n{user_prompt}" if system_prompt else user_prompt

        # Build generation config
        generation_config = {
            "temperature": temperature,
        }

        if max_tokens:
            generation_config["max_output_tokens"] = max_tokens

        # Merge additional parameters
        generation_config.update(kwargs)

        # Call Gemini API
        response = self.model.generate_content(full_prompt, generation_config=generation_config)

        # Extract content
        content = response.text if response.text else ""

        # Token usage (if available)
        prompt_tokens = None
        completion_tokens = None
        total_tokens = None

        if hasattr(response, "usage_metadata") and response.usage_metadata:
            prompt_tokens = response.usage_metadata.prompt_token_count
            completion_tokens = response.usage_metadata.candidates_token_count
            total_tokens = response.usage_metadata.total_token_count

        return LLMResponse(
            content=content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            model_name=self._model_name,
            finish_reason=None,  # Gemini uses different finish reason format
        )

    def supports_json_mode(self) -> bool:
        """Gemini 1.5+ models support JSON mode via response_mime_type"""
        model_lower = self._model_name.lower()
        return any(x in model_lower for x in ["gemini-1.5", "gemini-2.0"])

    @property
    def model_name(self) -> str:
        return self._model_name
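
Because generate() merges extra keyword arguments into the generation config, callers can ask for JSON output on models where supports_json_mode() returns True by passing response_mime_type. Treating that key as accepted by google-generativeai is an assumption here, not something the diff itself tests. A short sketch:

# Illustrative only: requesting JSON output from the new GeminiProvider.
from sentience.llm_provider import GeminiProvider

llm = GeminiProvider(model="gemini-1.5-flash")  # key picked up from GEMINI_API_KEY / GOOGLE_API_KEY

kwargs = {}
if llm.supports_json_mode():
    # assumption: response_mime_type lands in generation_config via **kwargs
    kwargs["response_mime_type"] = "application/json"

response = llm.generate(
    system_prompt="Reply only with JSON.",
    user_prompt='Return an object like {"status": "ok"}.',
    temperature=0.0,
    **kwargs,
)

print(response.content)
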
class LocalLLMProvider(LLMProvider):
    """
    Local LLM provider using HuggingFace Transformers
