Skip to content
Open
245 changes: 143 additions & 102 deletions app/main.py

Large diffs are not rendered by default.

123 changes: 102 additions & 21 deletions app/models_mapping.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,168 @@
"""
Unified model name mapping.

This allows you to use simple names like "llama-70b" and the conductor
will automatically translate to the correct provider-specific name.
"""

# Default model mappings: unified name -> {provider name -> provider-specific
# model id}.  A provider that is absent from an entry has no mapping for that
# model.  Short aliases (e.g. "llama-70b") and their versioned names
# (e.g. "llama-3.3-70b") are deliberately kept as separate dict objects so
# that mutating one entry can never silently change the other.
DEFAULT_MODEL_MAPPING = {
    # Llama 3.3 70B
    "llama-70b": {
        "cerebras": "llama-3.3-70b",
        "nvidia": "meta/llama-3.3-70b-instruct",
        "groq": "llama-3.3-70b-versatile",
        "openrouter": "meta-llama/llama-3.3-70b-instruct",
        "sambanova": "Meta-Llama-3.3-70B-Instruct",
        "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
    },
    "llama-3.3-70b": {
        "cerebras": "llama-3.3-70b",
        "nvidia": "meta/llama-3.3-70b-instruct",
        "groq": "llama-3.3-70b-versatile",
        "openrouter": "meta-llama/llama-3.3-70b-instruct",
        "sambanova": "Meta-Llama-3.3-70B-Instruct",
        "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
    },
    # Llama 3.1 8B
    "llama-8b": {
        "cerebras": "llama3.1-8b",
        "nvidia": "meta/llama-3.1-8b-instruct",
        "groq": "llama-3.1-8b-instant",
        "openrouter": "meta-llama/llama-3.1-8b-instruct",
        "sambanova": "Meta-Llama-3.1-8B-Instruct",
        "huggingface": "meta-llama/Llama-3.1-8B-Instruct",
    },
    "llama-3.1-8b": {
        "cerebras": "llama3.1-8b",
        "nvidia": "meta/llama-3.1-8b-instruct",
        "groq": "llama-3.1-8b-instant",
        "openrouter": "meta-llama/llama-3.1-8b-instruct",
        "sambanova": "Meta-Llama-3.1-8B-Instruct",
        "huggingface": "meta-llama/Llama-3.1-8B-Instruct",
    },
    # Llama 3.1 70B
    # NOTE(review): the Cerebras 8B id above is spelled "llama3.1-8b" (no
    # hyphen after "llama") while this entry uses "llama-3.1-70b" -- confirm
    # the exact id against the Cerebras model list.
    "llama-3.1-70b": {
        "cerebras": "llama-3.1-70b",
        "nvidia": "meta/llama-3.1-70b-instruct",
        "groq": "llama-3.1-70b-versatile",
        "openrouter": "meta-llama/llama-3.1-70b-instruct",
        "sambanova": "Meta-Llama-3.1-70B-Instruct",
        "huggingface": "meta-llama/Llama-3.1-70B-Instruct",
    },
    # Llama 3.1 405B
    "llama-405b": {
        "nvidia": "meta/llama-3.1-405b-instruct",
        "sambanova": "Meta-Llama-3.1-405B-Instruct",
        "openrouter": "meta-llama/llama-3.1-405b-instruct",
    },
    # Gemini models
    "gemini-flash": {
        "gemini": "gemini-2.0-flash",
        "openrouter": "google/gemini-2.0-flash-exp:free",
    },
    "gemini-2.0-flash": {
        "gemini": "gemini-2.0-flash",
        "openrouter": "google/gemini-2.0-flash-exp:free",
    },
    "gemini-flash-lite": {
        "gemini": "gemini-2.0-flash-lite",
    },
    # Mistral models
    "mistral-small": {
        "mistral": "mistral-small-latest",
        "openrouter": "mistralai/mistral-small",
    },
    "mistral-7b": {
        "mistral": "open-mistral-7b",
        "openrouter": "mistralai/mistral-7b-instruct:free",
        "huggingface": "mistralai/Mistral-7B-Instruct-v0.3",
    },
    # DeepSeek models
    "deepseek-chat": {
        "deepseek": "deepseek-chat",
        "openrouter": "deepseek/deepseek-chat:free",
    },
    "deepseek-coder": {
        "deepseek": "deepseek-coder",
        "openrouter": "deepseek/deepseek-coder:free",
    },
    "deepseek-r1": {
        "deepseek": "deepseek-reasoner",
        "openrouter": "deepseek/deepseek-r1:free",
        "sambanova": "DeepSeek-R1",
    },
    # Cohere models
    "command-r": {
        "cohere": "command-r-08-2024",
        "openrouter": "cohere/command-r",
    },
    "command-r-plus": {
        "cohere": "command-r-plus-08-2024",
        "openrouter": "cohere/command-r-plus",
    },
    # Groq specific
    "mixtral-8x7b": {
        "groq": "mixtral-8x7b-32768",
        "openrouter": "mistralai/mixtral-8x7b-instruct",
    },
    "gemma-7b": {
        "groq": "gemma-7b-it",
        "openrouter": "google/gemma-7b-it:free",
        "huggingface": "google/gemma-7b-it",
    },
    "gemma2-9b": {
        "groq": "gemma2-9b-it",
        "openrouter": "google/gemma-2-9b-it:free",
        "huggingface": "google/gemma-2-9b-it",
    },
    # Qwen models
    "qwen-72b": {
        "sambanova": "Qwen2.5-72B-Instruct",
        "openrouter": "qwen/qwen-2.5-72b-instruct:free",
        "huggingface": "Qwen/Qwen2.5-72B-Instruct",
    },
}

# Default model to use when none specified; must be a key of
# DEFAULT_MODEL_MAPPING.
DEFAULT_MODEL = "llama-70b"


class ModelMapper:
    """Maps unified model names to provider-specific names.

    Lookups are case-insensitive and ignore surrounding whitespace; every
    key stored in ``self.mappings`` is kept in that normalized form so that
    mappings added through the constructor or ``add_mapping`` can always be
    found again by ``get_provider_model``.
    """

    def __init__(self, custom_mappings: dict = None):
        """
        Initialize with optional custom mappings from config.

        Args:
            custom_mappings: Dict of {unified_name: {provider: provider_name}}.
                A custom entry replaces the default entry of the same
                (normalized) name as a whole; it is not merged with it.
        """
        # Copy the per-model dicts as well: a shallow ``.copy()`` would share
        # them with DEFAULT_MODEL_MAPPING, so editing one mapper's entry would
        # leak into the module-level defaults and every other mapper.
        self.mappings = {
            name: dict(provider_models)
            for name, provider_models in DEFAULT_MODEL_MAPPING.items()
        }
        if custom_mappings:
            # Route through add_mapping so custom keys are normalized the
            # same way lookups are; otherwise "My-Model" could never match.
            for name, provider_models in custom_mappings.items():
                self.add_mapping(name, provider_models)

    @staticmethod
    def _normalize(name) -> str:
        """Return the canonical lookup key for a unified model name."""
        return str(name).lower().strip()

    def get_provider_model(self, unified_name: str, provider: str) -> str:
        """
        Get the provider-specific model name.

        Args:
            unified_name: The unified model name (e.g., "llama-70b").
                Falsy values (None, "") fall back to DEFAULT_MODEL.
            provider: The provider name (e.g., "cerebras", "nvidia")

        Returns:
            Provider-specific model name, or ``unified_name`` unchanged when
            no mapping exists for this name/provider pair.
        """
        if not unified_name:
            unified_name = DEFAULT_MODEL

        provider_models = self.mappings.get(self._normalize(unified_name))
        if provider_models and provider in provider_models:
            return provider_models[provider]

        # If not found, return as-is (maybe it's already provider-specific)
        return unified_name

    def get_available_models(self) -> list[str]:
        """Get list of available unified model names."""
        return list(self.mappings)

    def add_mapping(self, unified_name: str, provider_models: dict):
        """Add (or replace) a custom model mapping.

        Args:
            unified_name: Unified model name; stored lower-cased and
                stripped so it matches how lookups normalize names.
            provider_models: Dict of {provider: provider_name}.
        """
        self.mappings[self._normalize(unified_name)] = provider_models
20 changes: 17 additions & 3 deletions app/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
"""Provider implementations for TrainForgeConductor."""

from .base import BaseProvider, ProviderKey
from .cerebras import CerebrasProvider
from .nvidia import NvidiaProvider
from .groq import GroqProvider
from .gemini import GeminiProvider
from .mistral import MistralProvider
from .openrouter import OpenRouterProvider
from .deepseek import DeepSeekProvider
from .huggingface import HuggingFaceProvider
from .cohere import CohereProvider
from .sambanova import SambaNovaProvider

Comment thread
FaisalAhmedShariff marked this conversation as resolved.
# Public API of the providers package: the shared base types followed by one
# concrete provider class per backend imported above.  Kept as a single list;
# a closing bracket left behind after "NvidiaProvider" would end the list
# early and leave the remaining names outside ``__all__``.
__all__ = [
    "BaseProvider",
    "ProviderKey",
    "CerebrasProvider",
    "NvidiaProvider",
    "GroqProvider",
    "GeminiProvider",
    "MistralProvider",
    "OpenRouterProvider",
    "DeepSeekProvider",
    "HuggingFaceProvider",
    "CohereProvider",
    "SambaNovaProvider",
]
162 changes: 162 additions & 0 deletions app/providers/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Cohere provider implementation."""

import time
import httpx
import structlog

from app.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChoice,
Message,
Usage,
)
from app.providers.base import BaseProvider, ProviderKey
from app.models_mapping import ModelMapper
from app.exceptions import (
RateLimitError,
CapabilityError,
ProviderUnavailableError,
)

logger = structlog.get_logger()


def _retry_after_seconds(header_value):
    """Convert a ``Retry-After`` header value to seconds, or ``None``.

    Only the delta-seconds form is understood.  A missing header, an
    HTTP-date, or any other non-numeric / negative / non-finite value yields
    ``None`` instead of raising, so a malformed header can never replace the
    rate-limit error the caller is about to raise.
    """
    if not header_value:
        return None
    try:
        seconds = float(header_value)
    except (TypeError, ValueError):
        return None
    # The chained comparison is False for NaN and for +/-inf.
    return seconds if 0 <= seconds < float("inf") else None


class CohereProvider(BaseProvider):
    """Cohere provider — free tier available.

    Sends chat requests to ``{base_url}/chat/completions`` and converts the
    JSON reply into a ``ChatCompletionResponse``.  HTTP failures are mapped
    onto the project's ``RateLimitError`` / ``ProviderUnavailableError`` /
    ``CapabilityError`` types.
    """

    name = "cohere"

    def __init__(
        self,
        base_url: str = "https://api.cohere.com/compatibility/v1",
        model_mapper: ModelMapper = None,
    ):
        """Forward the endpoint base URL and optional model mapper to the base class."""
        super().__init__(base_url, model_mapper)

    async def chat_completion(
        self,
        key: ProviderKey,
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Execute chat completion via Cohere API.

        Args:
            key: Credentials to use; ``api_key`` goes into the bearer token
                and ``key_name`` is attached to logs and to the response.
            request: The chat request (model, messages, sampling options).

        Returns:
            A ``ChatCompletionResponse`` built from the first returned choice.

        Raises:
            RateLimitError: on HTTP 429 (``retry_after`` is taken from the
                ``Retry-After`` header when it is a valid number of seconds).
            ProviderUnavailableError: on HTTP 5xx or a request timeout
                (timeouts are reported with status code 504).
            CapabilityError: on HTTP 400; the capability is guessed from
                keywords in the response text.
            httpx.HTTPStatusError: for any other error status, re-raised
                unchanged.
        """
        # get_model_name / get_client / base_url are inherited from
        # BaseProvider (not visible here) -- presumably they resolve the
        # unified model name and return a shared httpx.AsyncClient.
        model = self.get_model_name(request.model)

        payload = {
            "model": model,
            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
        }

        if request.stop:
            payload["stop"] = request.stop

        headers = {
            "Authorization": f"Bearer {key.api_key}",
            "Content-Type": "application/json",
        }

        client = await self.get_client()

        await logger.ainfo(
            "Sending request to Cohere",
            model=model,
            key_name=key.key_name,
            messages_count=len(request.messages),
        )

        try:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                headers=headers,
            )
            response.raise_for_status()
            data = response.json()

            # Only the first choice is surfaced to the caller.
            choice = data["choices"][0]
            usage = data.get("usage", {})

            return ChatCompletionResponse(
                id=data.get("id", f"cohere-{int(time.time())}"),
                created=data.get("created", int(time.time())),
                model=model,
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        message=Message(
                            role="assistant",
                            content=choice["message"]["content"],
                        ),
                        finish_reason=choice.get("finish_reason", "stop"),
                    )
                ],
                usage=Usage(
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    total_tokens=usage.get("total_tokens", 0),
                ),
                provider=self.name,
                provider_key_name=key.key_name,
            )

        except httpx.HTTPStatusError as e:
            status_code = e.response.status_code
            response_text = e.response.text

            await logger.aerror(
                "Cohere API error",
                status_code=status_code,
                response=response_text,
                key_name=key.key_name,
            )

            if status_code == 429:
                # Parsed defensively: a bare float() would raise ValueError
                # on an HTTP-date header and hide the rate-limit error.
                raise RateLimitError(
                    provider=self.name,
                    retry_after=_retry_after_seconds(
                        e.response.headers.get("retry-after")
                    ),
                    message=f"Rate limit exceeded: {response_text[:200]}",
                ) from e

            if 500 <= status_code < 600:
                raise ProviderUnavailableError(
                    provider=self.name,
                    status_code=status_code,
                    message=f"Cohere server error: {response_text[:200]}",
                ) from e

            if status_code == 400:
                # Keyword sniffing on the error body; vision is checked
                # before tools, anything else is reported as "unknown".
                response_lower = response_text.lower()
                if "image" in response_lower or "vision" in response_lower:
                    raise CapabilityError(
                        provider=self.name,
                        capability="vision",
                        message=f"Vision not supported: {response_text[:200]}",
                    ) from e
                if "tool" in response_lower or "function" in response_lower:
                    raise CapabilityError(
                        provider=self.name,
                        capability="tool_calls",
                        message=f"Tool calling error: {response_text[:200]}",
                    ) from e
                raise CapabilityError(
                    provider=self.name,
                    capability="unknown",
                    message=f"Request error: {response_text[:200]}",
                ) from e

            # Any other status (401, 403, 404, ...) propagates unchanged.
            raise

        except httpx.TimeoutException as e:
            await logger.aerror("Cohere request timeout", error=str(e))
            raise ProviderUnavailableError(
                provider=self.name,
                status_code=504,
                message=f"Request timeout: {str(e)}",
            ) from e
        except (RateLimitError, CapabilityError, ProviderUnavailableError):
            # Project errors raised inside the try body pass through without
            # being logged a second time by the generic handler below.
            raise
        except Exception as e:
            await logger.aerror("Cohere request failed", error=str(e))
            raise
Loading