Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agentic_security/llm_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from agentic_security.llm_providers.openai_provider import OpenAIProvider
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
from agentic_security.llm_providers.litellm_provider import LiteLLMProvider
from agentic_security.llm_providers.factory import create_provider, get_provider_class

__all__ = [
Expand All @@ -17,6 +18,7 @@
"LLMRateLimitError",
"OpenAIProvider",
"AnthropicProvider",
"LiteLLMProvider",
"create_provider",
"get_provider_class",
]
2 changes: 2 additions & 0 deletions agentic_security/llm_providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ def _ensure_registered() -> None:
return
from agentic_security.llm_providers.openai_provider import OpenAIProvider
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
from agentic_security.llm_providers.litellm_provider import LiteLLMProvider

_PROVIDERS["openai"] = OpenAIProvider
_PROVIDERS["anthropic"] = AnthropicProvider
_PROVIDERS["litellm"] = LiteLLMProvider


def register_provider(name: str, provider_class: type[BaseLLMProvider]) -> None:
Expand Down
112 changes: 112 additions & 0 deletions agentic_security/llm_providers/litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""LiteLLM provider — unified access to 100+ LLM backends."""

from typing import Any, NoReturn

import litellm

from agentic_security.llm_providers.base import (
    BaseLLMProvider,
    LLMMessage,
    LLMProviderError,
    LLMRateLimitError,
    LLMResponse,
)


class LiteLLMProvider(BaseLLMProvider):
    """LLM provider using LiteLLM SDK for 100+ backends.

    Accepts any LiteLLM model string (e.g. ``openai/gpt-4o``,
    ``anthropic/claude-sonnet-4-6``, ``groq/llama-3.3-70b-versatile``).

    Any exception raised while calling LiteLLM or parsing its response is
    re-raised as ``LLMRateLimitError`` (rate limits) or ``LLMProviderError``
    (everything else), chained to the original exception.
    """

    DEFAULT_MODEL = "openai/gpt-4o-mini"

    # Fully-qualified name of LiteLLM's rate-limit exception class. Errors are
    # recognised by name, so an exception whose class (or any base class) has
    # this module/name is treated as a rate limit.
    _RATE_LIMIT_QUALNAME = "litellm.exceptions.RateLimitError"

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        api_key: str | None = None,
        api_base: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create a provider bound to one LiteLLM model string.

        Args:
            model: LiteLLM model identifier.
            api_key: Optional API key; forwarded on every call when truthy.
            api_base: Optional endpoint URL; forwarded on every call when truthy.
            **kwargs: Passed through to ``BaseLLMProvider.__init__``.
        """
        super().__init__(model, **kwargs)
        self._api_key = api_key
        self._api_base = api_base

    def _call_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments sent with every LiteLLM call.

        ``drop_params`` is always ``True`` (LiteLLM's switch for discarding
        request parameters the selected backend does not support — see the
        LiteLLM documentation). ``api_key`` / ``api_base`` are included only
        when they were supplied.
        """
        kwargs: dict[str, Any] = {"model": self.model, "drop_params": True}
        if self._api_key:
            kwargs["api_key"] = self._api_key
        if self._api_base:
            kwargs["api_base"] = self._api_base
        return kwargs

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """Return a fixed list of example LiteLLM model strings."""
        return [
            "openai/gpt-4o",
            "openai/gpt-4o-mini",
            "anthropic/claude-sonnet-4-6",
            "anthropic/claude-haiku-4-5",
            "groq/llama-3.3-70b-versatile",
            "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
        ]

    @staticmethod
    def _build_messages(prompt: str, system_prompt: str | None) -> list[LLMMessage]:
        """Wrap *prompt* as a user message, preceded by a system message if given.

        A falsy *system_prompt* (``None`` or ``""``) adds no system message.
        """
        messages = [LLMMessage(role="user", content=prompt)]
        if system_prompt:
            messages.insert(0, LLMMessage(role="system", content=system_prompt))
        return messages

    def _messages_to_dicts(self, messages: list[LLMMessage]) -> list[dict[str, str]]:
        """Convert messages to the ``{"role", "content"}`` dicts sent to LiteLLM."""
        return [{"role": m.role, "content": m.content} for m in messages]

    def _parse_response(self, response: Any) -> LLMResponse:
        """Build an ``LLMResponse`` from the first choice of a LiteLLM response.

        Missing content becomes ``""``; a missing or empty ``model`` falls back
        to the provider's configured model; ``usage`` is ``None`` when the
        response carries no usage information.
        """
        choice = response.choices[0]
        usage = None
        if response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
        return LLMResponse(
            content=choice.message.content or "",
            # ``or`` also covers a response whose ``model`` attribute exists but is None.
            model=getattr(response, "model", None) or self.model,
            finish_reason=choice.finish_reason,
            usage=usage,
        )

    def _handle_error(self, e: Exception) -> NoReturn:
        """Re-raise *e* as a provider exception; this method never returns.

        Raises:
            LLMRateLimitError: if the class of *e*, or any class in its MRO, is
                ``litellm.exceptions.RateLimitError`` (matched by qualified name).
            LLMProviderError: for every other exception.
        """
        # Walk the MRO so subclasses of the rate-limit error are recognised too.
        is_rate_limit = any(
            f"{klass.__module__}.{klass.__name__}" == self._RATE_LIMIT_QUALNAME
            for klass in type(e).__mro__
        )
        if is_rate_limit:
            raise LLMRateLimitError(str(e)) from e
        raise LLMProviderError(str(e)) from e

    async def generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Asynchronously complete a single prompt.

        Args:
            prompt: User message text.
            **kwargs: Optional ``system_prompt`` (consumed here); everything
                else is forwarded to :meth:`chat`.
        """
        system_prompt = kwargs.pop("system_prompt", None)
        return await self.chat(self._build_messages(prompt, system_prompt), **kwargs)

    async def chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Asynchronously send *messages* through ``litellm.acompletion``.

        Per-call ``kwargs`` take precedence over the provider defaults from
        :meth:`_call_kwargs`.

        Raises:
            LLMRateLimitError, LLMProviderError: see :meth:`_handle_error`.
        """
        try:
            response = await litellm.acompletion(
                messages=self._messages_to_dicts(messages),
                **{**self._call_kwargs(), **kwargs},
            )
            return self._parse_response(response)
        except Exception as e:
            self._handle_error(e)

    def sync_generate(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Synchronous counterpart of :meth:`generate`."""
        system_prompt = kwargs.pop("system_prompt", None)
        return self.sync_chat(self._build_messages(prompt, system_prompt), **kwargs)

    def sync_chat(self, messages: list[LLMMessage], **kwargs: Any) -> LLMResponse:
        """Synchronously send *messages* through ``litellm.completion``.

        Per-call ``kwargs`` take precedence over the provider defaults from
        :meth:`_call_kwargs`.

        Raises:
            LLMRateLimitError, LLMProviderError: see :meth:`_handle_error`.
        """
        try:
            response = litellm.completion(
                messages=self._messages_to_dicts(messages),
                **{**self._call_kwargs(), **kwargs},
            )
            return self._parse_response(response)
        except Exception as e:
            self._handle_error(e)
233 changes: 233 additions & 0 deletions tests/unit/llm_providers/test_litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""Tests for LiteLLM provider."""

import pytest
from inline_snapshot import snapshot
from unittest.mock import MagicMock, AsyncMock, patch

from agentic_security.llm_providers.litellm_provider import LiteLLMProvider
from agentic_security.llm_providers.base import (
LLMMessage,
LLMProviderError,
LLMRateLimitError,
)


def _mock_response(
    content="Hello",
    model="openai/gpt-4o-mini",
    finish_reason="stop",
    prompt_tokens=10,
    completion_tokens=5,
    total_tokens=15,
):
    """Return a ``MagicMock`` shaped like a single-choice completion response.

    The mock exposes ``choices[0].message.content``, ``choices[0].finish_reason``,
    ``model`` and the three ``usage`` token counters.
    """
    only_choice = MagicMock()
    only_choice.message.content = content
    only_choice.finish_reason = finish_reason

    fake = MagicMock()
    fake.choices = [only_choice]
    fake.model = model

    token_counts = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    for field_name, count in token_counts.items():
        setattr(fake.usage, field_name, count)
    return fake


class TestLiteLLMProviderInit:
    """Constructor behaviour: model selection and credential storage."""

    def test_default_model(self):
        """Without arguments the provider uses the default model string."""
        default_provider = LiteLLMProvider()
        assert default_provider.model == snapshot("openai/gpt-4o-mini")

    def test_custom_model(self):
        """An explicit model string is kept as given."""
        chosen = LiteLLMProvider(model="anthropic/claude-sonnet-4-6")
        assert chosen.model == snapshot("anthropic/claude-sonnet-4-6")

    def test_no_api_key_required(self):
        """Construction succeeds without a key and stores ``None``."""
        keyless = LiteLLMProvider()
        assert keyless._api_key is None

    def test_api_key_stored(self):
        """A supplied key is stored on the instance."""
        keyed = LiteLLMProvider(api_key="sk-test")
        assert keyed._api_key == snapshot("sk-test")

    def test_api_base_stored(self):
        """A supplied base URL is stored on the instance."""
        proxied = LiteLLMProvider(api_base="http://localhost:4000")
        assert proxied._api_base == snapshot("http://localhost:4000")


class TestLiteLLMProviderCallKwargs:
    """Contents of the keyword arguments built by ``_call_kwargs``."""

    def test_drop_params_always_true(self):
        """``drop_params`` is set even for a default provider."""
        built = LiteLLMProvider()._call_kwargs()
        assert built["drop_params"] is True

    def test_api_key_forwarded_when_set(self):
        """A configured key appears in the call kwargs."""
        built = LiteLLMProvider(api_key="sk-test")._call_kwargs()
        assert built["api_key"] == snapshot("sk-test")

    def test_api_key_omitted_when_none(self):
        """No ``api_key`` entry is emitted when no key was configured."""
        built = LiteLLMProvider()._call_kwargs()
        assert "api_key" not in built

    def test_api_base_forwarded_when_set(self):
        """A configured base URL appears in the call kwargs."""
        built = LiteLLMProvider(api_base="http://localhost:4000")._call_kwargs()
        assert built["api_base"] == snapshot("http://localhost:4000")

    def test_api_base_omitted_when_none(self):
        """No ``api_base`` entry is emitted when no base URL was configured."""
        built = LiteLLMProvider()._call_kwargs()
        assert "api_base" not in built

    def test_model_in_kwargs(self):
        """The provider's model string is passed as ``model``."""
        built = LiteLLMProvider(model="groq/llama-3.3-70b-versatile")._call_kwargs()
        assert built["model"] == snapshot("groq/llama-3.3-70b-versatile")


class TestLiteLLMProviderMethods:
    """Pure helpers: model listing, message conversion, response parsing."""

    def test_get_supported_models(self):
        """The advertised model list contains the expected sample entries."""
        advertised = LiteLLMProvider.get_supported_models()
        assert "openai/gpt-4o" in advertised
        assert "anthropic/claude-sonnet-4-6" in advertised

    def test_messages_to_dicts(self):
        """Messages are converted to role/content dicts in order."""
        conversation = [
            LLMMessage(role="system", content="Be helpful"),
            LLMMessage(role="user", content="Hello"),
        ]
        converted = LiteLLMProvider()._messages_to_dicts(conversation)
        assert converted == snapshot(
            [
                {"role": "system", "content": "Be helpful"},
                {"role": "user", "content": "Hello"},
            ]
        )

    def test_parse_response(self):
        """Content, model, finish reason and usage are copied from the response."""
        raw = _mock_response(content="Hi!", model="openai/gpt-4o")
        parsed = LiteLLMProvider()._parse_response(raw)
        assert parsed.content == snapshot("Hi!")
        assert parsed.model == snapshot("openai/gpt-4o")
        assert parsed.finish_reason == snapshot("stop")
        assert parsed.usage == snapshot(
            {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        )

    def test_parse_response_null_content(self):
        """``None`` content is normalised to an empty string."""
        raw = _mock_response(content=None)
        parsed = LiteLLMProvider()._parse_response(raw)
        assert parsed.content == snapshot("")

    def test_parse_response_no_usage(self):
        """A response without usage yields ``usage is None``."""
        raw = _mock_response()
        raw.usage = None
        parsed = LiteLLMProvider()._parse_response(raw)
        assert parsed.usage is None


class TestLiteLLMProviderSync:
    """Synchronous entry points, with ``litellm.completion`` patched out."""

    # Dotted path of the function replaced by every test in this class.
    _COMPLETION_TARGET = "agentic_security.llm_providers.litellm_provider.litellm.completion"

    @pytest.fixture
    def provider(self):
        """Provider instance shared by the tests below."""
        return LiteLLMProvider(model="openai/gpt-4o-mini")

    def test_sync_generate(self, provider):
        """``sync_generate`` returns parsed content and forwards ``drop_params``."""
        canned = _mock_response(content="Sync response")
        with patch(self._COMPLETION_TARGET, return_value=canned) as completion:
            outcome = provider.sync_generate("Hello")
        assert outcome.content == snapshot("Sync response")
        sent = completion.call_args.kwargs
        assert sent["drop_params"] is True

    def test_sync_chat(self, provider):
        """``sync_chat`` returns the parsed content of the mocked response."""
        canned = _mock_response(content="Chat response")
        conversation = [LLMMessage(role="user", content="Hi")]
        with patch(self._COMPLETION_TARGET, return_value=canned):
            outcome = provider.sync_chat(conversation)
        assert outcome.content == snapshot("Chat response")

    def test_sync_generate_with_system_prompt(self, provider):
        """A ``system_prompt`` becomes the first message sent to LiteLLM."""
        canned = _mock_response(content="With system")
        with patch(self._COMPLETION_TARGET, return_value=canned) as completion:
            outcome = provider.sync_generate("Hello", system_prompt="Be brief")
        assert outcome.content == snapshot("With system")
        leading = completion.call_args.kwargs["messages"][0]
        assert leading["role"] == "system"
        assert leading["content"] == "Be brief"


class TestLiteLLMProviderAsync:
    """Asynchronous entry points, with ``litellm.acompletion`` patched out."""

    # Dotted path of the coroutine function replaced by every test in this class.
    _ACOMPLETION_TARGET = "agentic_security.llm_providers.litellm_provider.litellm.acompletion"

    @pytest.fixture
    def provider(self):
        """Provider instance shared by the tests below."""
        return LiteLLMProvider(model="anthropic/claude-sonnet-4-6")

    @pytest.mark.asyncio
    async def test_generate(self, provider):
        """``generate`` returns the parsed content of the mocked response."""
        canned = _mock_response(content="Async response")
        with patch(self._ACOMPLETION_TARGET, new_callable=AsyncMock, return_value=canned):
            outcome = await provider.generate("Hello")
        assert outcome.content == snapshot("Async response")

    @pytest.mark.asyncio
    async def test_chat(self, provider):
        """``chat`` forwards the model and ``drop_params`` to LiteLLM."""
        canned = _mock_response(content="Async chat")
        conversation = [LLMMessage(role="user", content="Hi")]
        with patch(
            self._ACOMPLETION_TARGET, new_callable=AsyncMock, return_value=canned
        ) as acompletion:
            outcome = await provider.chat(conversation)
        assert outcome.content == snapshot("Async chat")
        sent = acompletion.call_args.kwargs
        assert sent["model"] == "anthropic/claude-sonnet-4-6"
        assert sent["drop_params"] is True


class TestLiteLLMProviderErrors:
    """Mapping of LiteLLM failures onto the provider exception types."""

    # Dotted paths of the LiteLLM entry points replaced in these tests.
    _COMPLETION_TARGET = "agentic_security.llm_providers.litellm_provider.litellm.completion"
    _ACOMPLETION_TARGET = "agentic_security.llm_providers.litellm_provider.litellm.acompletion"

    @pytest.fixture
    def provider(self):
        """Default provider used by tests that do not need a specific model."""
        return LiteLLMProvider()

    def test_rate_limit_maps_to_llm_rate_limit_error(self, provider):
        """An exception named ``litellm.exceptions.RateLimitError`` maps to ``LLMRateLimitError``."""
        # Stand-in class carrying the module/name the provider matches on, so the
        # test does not depend on constructing LiteLLM's real exception.
        fake_cls = type(
            "RateLimitError", (Exception,), {"__module__": "litellm.exceptions"}
        )
        fake_exc = fake_cls()
        with pytest.raises(LLMRateLimitError) as caught:
            provider._handle_error(fake_exc)
        # The original exception must stay reachable for debugging.
        assert caught.value.__cause__ is fake_exc

    def test_generic_error_maps_to_llm_provider_error(self, provider):
        """Any other exception maps to ``LLMProviderError`` and is chained."""
        original = Exception("something went wrong")
        with pytest.raises(LLMProviderError) as caught:
            provider._handle_error(original)
        assert caught.value.__cause__ is original

    def test_sync_chat_auth_error_raises_provider_error(self, provider):
        """A failure inside ``litellm.completion`` surfaces as ``LLMProviderError``."""
        failure = Exception("AuthenticationError: Invalid API key")
        with patch(self._COMPLETION_TARGET, side_effect=failure):
            with pytest.raises(LLMProviderError, match="Invalid API key"):
                provider.sync_generate("test")

    @pytest.mark.asyncio
    async def test_async_chat_timeout_raises_provider_error(self, provider):
        """A failure inside ``litellm.acompletion`` surfaces as ``LLMProviderError``."""
        failure = Exception("Timeout: Request timed out")
        with patch(self._ACOMPLETION_TARGET, new_callable=AsyncMock, side_effect=failure):
            with pytest.raises(LLMProviderError, match="timed out"):
                await provider.generate("test")

    @pytest.mark.asyncio
    async def test_async_chat_model_not_found_raises_provider_error(self):
        """An unknown model's failure surfaces as ``LLMProviderError``."""
        # Built locally instead of requesting the shared fixture, which this test
        # previously received and immediately shadowed.
        bad_provider = LiteLLMProvider(model="bad/nonexistent-model")
        failure = Exception("NotFoundError: Model not found")
        with patch(self._ACOMPLETION_TARGET, new_callable=AsyncMock, side_effect=failure):
            with pytest.raises(LLMProviderError, match="not found"):
                await bad_provider.generate("test")


class TestLiteLLMProviderFactory:
    """Registration of the LiteLLM provider with the provider factory."""

    def test_factory_creates_litellm_provider(self):
        """``create_provider("litellm")`` yields a default ``LiteLLMProvider``."""
        from agentic_security.llm_providers.factory import create_provider

        built = create_provider("litellm")
        assert isinstance(built, LiteLLMProvider)
        assert built.model == snapshot("openai/gpt-4o-mini")

    def test_factory_creates_with_custom_model(self):
        """The ``model`` keyword is passed through to the provider."""
        from agentic_security.llm_providers.factory import create_provider

        built = create_provider("litellm", model="groq/llama-3.3-70b-versatile")
        assert built.model == snapshot("groq/llama-3.3-70b-versatile")

    def test_factory_lists_litellm(self):
        """``list_providers`` includes the ``litellm`` name."""
        from agentic_security.llm_providers.factory import list_providers

        registered = list_providers()
        assert "litellm" in registered
Loading