Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/ai-providers/server-ai-langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ packages = [{ include = "ldai_langchain", from = "src" }]

[tool.poetry.dependencies]
python = ">=3.9,<4"
launchdarkly-server-sdk-ai = ">=0.11.0"
launchdarkly-server-sdk-ai = ">=0.12.0"
langchain-core = ">=0.2.0"
langchain = ">=0.2.0"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from ldai import LDMessage
from ldai import LDMessage, log
from ldai.models import AIConfigKind
from ldai.providers import AIProvider
from ldai.providers.types import ChatResponse, LDAIMetrics, StructuredResponse
Expand All @@ -18,27 +18,24 @@ class LangChainProvider(AIProvider):
This provider integrates LangChain models with LaunchDarkly's tracking capabilities.
"""

def __init__(self, llm: BaseChatModel, logger: Optional[Any] = None):
def __init__(self, llm: BaseChatModel):
"""
Initialize the LangChain provider.

:param llm: A LangChain BaseChatModel instance
:param logger: Optional logger for logging provider operations
"""
super().__init__(logger)
self._llm = llm

@staticmethod
async def create(ai_config: AIConfigKind, logger: Optional[Any] = None) -> 'LangChainProvider':
async def create(ai_config: AIConfigKind) -> 'LangChainProvider':
"""
Static factory method to create a LangChain AIProvider from an AI configuration.

:param ai_config: The LaunchDarkly AI configuration
:param logger: Optional logger for the provider
:return: Configured LangChainProvider instance
"""
llm = LangChainProvider.create_langchain_model(ai_config)
return LangChainProvider(llm, logger)
return LangChainProvider(llm)

async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
"""
Expand All @@ -56,20 +53,18 @@ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
if isinstance(response.content, str):
content = response.content
else:
if self.logger:
self.logger.warn(
f'Multimodal response not supported, expecting a string. '
f'Content type: {type(response.content)}, Content: {response.content}'
)
log.warning(
f'Multimodal response not supported, expecting a string. '
f'Content type: {type(response.content)}, Content: {response.content}'
)
metrics = LDAIMetrics(success=False, usage=metrics.usage)

return ChatResponse(
message=LDMessage(role='assistant', content=content),
metrics=metrics,
)
except Exception as error:
if self.logger:
self.logger.warn(f'LangChain model invocation failed: {error}')
log.warning(f'LangChain model invocation failed: {error}')

return ChatResponse(
message=LDMessage(role='assistant', content=''),
Expand All @@ -94,11 +89,10 @@ async def invoke_structured_model(
response = await structured_llm.ainvoke(langchain_messages)

if not isinstance(response, dict):
if self.logger:
self.logger.warn(
f'Structured output did not return a dict. '
f'Got: {type(response)}'
)
log.warning(
f'Structured output did not return a dict. '
f'Got: {type(response)}'
)
return StructuredResponse(
data={},
raw_response='',
Expand All @@ -117,8 +111,7 @@ async def invoke_structured_model(
),
)
except Exception as error:
if self.logger:
self.logger.warn(f'LangChain structured model invocation failed: {error}')
log.warning(f'LangChain structured model invocation failed: {error}')

return StructuredResponse(
data={},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,53 +145,45 @@ def mock_llm(self):
"""Create a mock LLM."""
return MagicMock()

@pytest.fixture
def mock_logger(self):
"""Create a mock logger."""
return MagicMock()

@pytest.mark.asyncio
async def test_returns_success_true_for_string_content(self, mock_llm, mock_logger):
async def test_returns_success_true_for_string_content(self, mock_llm):
"""Should return success=True for string content."""
mock_response = AIMessage(content='Test response')
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
provider = LangChainProvider(mock_llm, mock_logger)
provider = LangChainProvider(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
result = await provider.invoke_model(messages)

assert result.metrics.success is True
assert result.message.content == 'Test response'
mock_logger.warn.assert_not_called()

@pytest.mark.asyncio
async def test_returns_success_false_for_non_string_content_and_logs_warning(self, mock_llm, mock_logger):
async def test_returns_success_false_for_non_string_content_and_logs_warning(self, mock_llm):
"""Should return success=False for non-string content and log warning."""
mock_response = AIMessage(content=[{'type': 'image', 'data': 'base64data'}])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
provider = LangChainProvider(mock_llm, mock_logger)
provider = LangChainProvider(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
result = await provider.invoke_model(messages)

assert result.metrics.success is False
assert result.message.content == ''
mock_logger.warn.assert_called_once()

@pytest.mark.asyncio
async def test_returns_success_false_when_model_invocation_throws_error(self, mock_llm, mock_logger):
async def test_returns_success_false_when_model_invocation_throws_error(self, mock_llm):
"""Should return success=False when model invocation throws an error."""
error = Exception('Model invocation failed')
mock_llm.ainvoke = AsyncMock(side_effect=error)
provider = LangChainProvider(mock_llm, mock_logger)
provider = LangChainProvider(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
result = await provider.invoke_model(messages)

assert result.metrics.success is False
assert result.message.content == ''
assert result.message.role == 'assistant'
mock_logger.warn.assert_called()


class TestInvokeStructuredModel:
Expand All @@ -202,36 +194,30 @@ def mock_llm(self):
"""Create a mock LLM."""
return MagicMock()

@pytest.fixture
def mock_logger(self):
"""Create a mock logger."""
return MagicMock()

@pytest.mark.asyncio
async def test_returns_success_true_for_successful_invocation(self, mock_llm, mock_logger):
async def test_returns_success_true_for_successful_invocation(self, mock_llm):
"""Should return success=True for successful invocation."""
mock_response = {'result': 'structured data'}
mock_structured_llm = MagicMock()
mock_structured_llm.ainvoke = AsyncMock(return_value=mock_response)
mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm)
provider = LangChainProvider(mock_llm, mock_logger)
provider = LangChainProvider(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
response_structure = {'type': 'object', 'properties': {}}
result = await provider.invoke_structured_model(messages, response_structure)

assert result.metrics.success is True
assert result.data == mock_response
mock_logger.warn.assert_not_called()

@pytest.mark.asyncio
async def test_returns_success_false_when_structured_model_invocation_throws_error(self, mock_llm, mock_logger):
async def test_returns_success_false_when_structured_model_invocation_throws_error(self, mock_llm):
"""Should return success=False when structured model invocation throws an error."""
error = Exception('Structured invocation failed')
mock_structured_llm = MagicMock()
mock_structured_llm.ainvoke = AsyncMock(side_effect=error)
mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm)
provider = LangChainProvider(mock_llm, mock_logger)
provider = LangChainProvider(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
response_structure = {'type': 'object', 'properties': {}}
Expand All @@ -242,7 +228,6 @@ async def test_returns_success_false_when_structured_model_invocation_throws_err
assert result.raw_response == ''
assert result.metrics.usage is not None
assert result.metrics.usage.total == 0
mock_logger.warn.assert_called()


class TestGetChatModel:
Expand Down