Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,24 @@ async def list_actions(self, allowed_kinds: list[ActionKind] | None = None) -> l
if allowed_kinds and meta.kind not in allowed_kinds:
continue
metas.append(meta)

# Include actions registered directly in the registry
with self._lock:
for kind, kind_map in self._entries.items():
if allowed_kinds and kind not in allowed_kinds:
continue
for action in kind_map.values():
metas.append(
ActionMetadata(
kind=action.kind,
name=action.name,
description=action.description,
input_json_schema=action.input_schema,
output_json_schema=action.output_schema,
metadata=action.metadata,
)
)

return metas

def register_schema(self, name: str, schema: dict[str, Any]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion py/packages/genkit/tests/genkit/core/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ async def list_actions(self) -> list[ActionMetadata]:
ai = Genkit(plugins=[MyPlugin()])

metas = await ai.registry.list_actions()
assert metas == [ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo')]
# Filter for the specific plugin action we expect, ignoring system actions like 'generate'
target_meta = next((m for m in metas if m.name == 'myplugin/foo'), None)
assert target_meta == ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo')

action = await ai.registry.resolve_action(ActionKind.MODEL, 'myplugin/foo')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool:
params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={})

function = genai_types.FunctionDeclaration(
name=tool.name,
name=tool.name.replace('/', '__'),
description=tool.description,
parameters=params,
response=self._convert_schema_property(tool.output_schema) if tool.output_schema else None,
Expand Down
4 changes: 2 additions & 2 deletions py/plugins/mcp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
]
dependencies = ["genkit", "mcp"]
dependencies = ["genkit", "mcp", "structlog"]
description = "Genkit MCP Plugin"
license = "Apache-2.0"
name = "genkit-plugins-mcp"
Expand All @@ -45,4 +45,4 @@ build-backend = "hatchling.build"
requires = ["hatchling"]

[tool.hatch.build.targets.wheel]
packages = ["src"]
packages = ["src/genkit"]
153 changes: 73 additions & 80 deletions py/plugins/mcp/src/genkit/plugins/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any
import asyncio
from contextlib import AsyncExitStack
from typing import Any, cast

import structlog
from pydantic import BaseModel
from pydantic import AnyUrl, BaseModel

from genkit.ai import Genkit, Plugin
from genkit.core.action import Action, ActionMetadata
from genkit.ai import Genkit
from genkit.ai._registry import GenkitRegistry
from genkit.core.action.types import ActionKind
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import CallToolResult, Prompt, Resource, Tool
from mcp.types import CallToolResult, Prompt, Resource, TextContent, Tool

logger = structlog.get_logger(__name__)

Expand All @@ -38,57 +40,33 @@ class McpServerConfig(BaseModel):
disabled: bool = False


class McpClient(Plugin):
class McpClient:
"""Client for connecting to a single MCP server."""

def __init__(self, name: str, config: McpServerConfig, server_name: str | None = None):
self.name = name
self.config = config
self.server_name = server_name or name
self.session: ClientSession | None = None
self._exit_stack = None
self._session_context = None
self.ai: Genkit | None = None
self._exit_stack = AsyncExitStack()
self.ai: GenkitRegistry | None = None

def plugin_name(self) -> str:
return self.name

async def init(self) -> list[Action]:
"""Initialize MCP plugin.
def initialize(self, ai: GenkitRegistry) -> None:
self.ai = ai

MCP tools are registered dynamically upon connection, so this returns an empty list.

Returns:
Empty list (tools are registered dynamically).
"""
return []

async def resolve(self, action_type: ActionKind, name: str) -> Action | None:
"""Resolve an action by name.

MCP uses dynamic registration, so this returns None.

Args:
action_type: The kind of action to resolve.
name: The namespaced name of the action to resolve.

Returns:
None (MCP uses dynamic registration).
"""
return None

async def list_actions(self) -> list[ActionMetadata]:
"""List available MCP actions.

MCP tools are discovered at runtime, so this returns an empty list.

Returns:
Empty list (tools are discovered at runtime).
"""
return []
def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str) -> None:
# MCP tools are dynamic and currently registered upon connection/Discovery.
# This hook allows lazy resolution if we implement it.
pass

async def connect(self):
"""Connects to the MCP server."""
if self.session:
return

if self.config.disabled:
logger.info(f'MCP server {self.server_name} is disabled.')
return
Expand All @@ -100,25 +78,24 @@ async def connect(self):
)
# stdio_client returns (read, write) streams
stdio_context = stdio_client(server_params)
read, write = await stdio_context.__aenter__()
self._exit_stack = stdio_context
read, write = await self._exit_stack.enter_async_context(stdio_context)

# Create and initialize session
session_context = ClientSession(read, write)
self.session = await session_context.__aenter__()
self._session_context = session_context
self.session = await self._exit_stack.enter_async_context(session_context)

elif self.config.url:
# TODO: Verify SSE client usage in mcp python SDK
sse_context = sse_client(self.config.url)
read, write = await sse_context.__aenter__()
self._exit_stack = sse_context
read, write = await self._exit_stack.enter_async_context(sse_context)

session_context = ClientSession(read, write)
self.session = await session_context.__aenter__()
self._session_context = session_context
self.session = await self._exit_stack.enter_async_context(session_context)
else:
raise ValueError(f"MCP client {self.name} configuration requires either 'command' or 'url'.")

await self.session.initialize()
if self.session:
await self.session.initialize()
logger.info(f'Connected to MCP server: {self.server_name}')

except Exception as e:
Expand All @@ -130,16 +107,16 @@ async def connect(self):

async def close(self):
"""Closes the connection."""
if hasattr(self, '_session_context') and self._session_context:
try:
await self._session_context.__aexit__(None, None, None)
except Exception as e:
logger.debug(f'Error closing session: {e}')
if self._exit_stack:
try:
await self._exit_stack.__aexit__(None, None, None)
except Exception as e:
logger.debug(f'Error closing transport: {e}')
await self._exit_stack.aclose()
except (Exception, asyncio.CancelledError):
# Ignore errors during cleanup, especially cancellation from anyio
pass

# Reset exit stack for potential reuse (reconnect)
self._exit_stack = AsyncExitStack()
self.session = None

async def list_tools(self) -> list[Tool]:
if not self.session:
Expand All @@ -150,14 +127,21 @@ async def list_tools(self) -> list[Tool]:
async def call_tool(self, tool_name: str, arguments: dict) -> Any:
if not self.session:
raise RuntimeError('MCP client is not connected')
result: CallToolResult = await self.session.call_tool(tool_name, arguments)
# Process result similarly to JS SDK
if result.isError:
raise RuntimeError(f'Tool execution failed: {result.content}')
logger.debug(f'MCP {self.server_name}: calling tool {tool_name}', arguments=arguments)
try:
result: CallToolResult = await self.session.call_tool(tool_name, arguments)
logger.debug(f'MCP {self.server_name}: tool {tool_name} returned')

# Simple text extraction for now
texts = [c.text for c in result.content if c.type == 'text']
return ''.join(texts)
# Process result similarly to JS SDK
if result.isError:
raise RuntimeError(f'Tool execution failed: {result.content}')

# Simple text extraction for now
texts = [c.text for c in result.content if c.type == 'text' and isinstance(c, TextContent)]
return {'content': ''.join(texts)}
except Exception as e:
logger.error(f'MCP {self.server_name}: tool {tool_name} failed', error=str(e))
raise

async def list_prompts(self) -> list[Prompt]:
if not self.session:
Expand All @@ -179,7 +163,7 @@ async def list_resources(self) -> list[Resource]:
async def read_resource(self, uri: str) -> Any:
if not self.session:
raise RuntimeError('MCP client is not connected')
return await self.session.read_resource(uri)
return await self.session.read_resource(cast(AnyUrl, uri))

async def register_tools(self, ai: Genkit | None = None):
"""Registers all tools from connected client to Genkit."""
Expand All @@ -194,29 +178,38 @@ async def register_tools(self, ai: Genkit | None = None):
try:
tools = await self.list_tools()
for tool in tools:
# Create a wrapper function for the tool
# We need to capture tool and client in closure
async def tool_wrapper(args: Any = None, _tool_name=tool.name):
# args might be Pydantic model or dict. Genkit passes dict usually?
# TODO: Validate args against schema if needed
arguments = args
if hasattr(args, 'model_dump'):
arguments = args.model_dump()
return await self.call_tool(_tool_name, arguments or {})
# Create a wrapper function for the tool using a factory to capture tool name
def create_wrapper(tool_name: str):
async def tool_wrapper(args: Any = None):
# args might be Pydantic model or dict. Genkit passes dict usually?
# TODO: Validate args against schema if needed
arguments = args
if hasattr(args, 'model_dump'):
arguments = args.model_dump()
return await self.call_tool(tool_name, arguments or {})

return tool_wrapper

tool_wrapper = create_wrapper(tool.name)

# Use metadata to store MCP specific info
metadata = {'mcp': {'_meta': tool._meta}} if hasattr(tool, '_meta') else {}

# Define the tool in Genkit registry
registry.register_action(
kind=ActionKind.TOOL,
name=f'{self.server_name}/{tool.name}',
action = registry.register_action(
kind=cast(ActionKind, ActionKind.TOOL),
name=f'{self.server_name}_{tool.name}',
fn=tool_wrapper,
description=tool.description,
metadata=metadata,
# TODO: json_schema conversion from tool.inputSchema
)
logger.debug(f'Registered MCP tool: {self.server_name}/{tool.name}')

# Patch input schema from MCP tool definition
if tool.inputSchema:
action._input_schema = tool.inputSchema
action._metadata['inputSchema'] = tool.inputSchema

logger.debug(f'Registered MCP tool: {self.server_name}_{tool.name}')
except Exception as e:
logger.error(f'Error registering tools for {self.server_name}: {e}')

Expand Down
38 changes: 38 additions & 0 deletions py/plugins/mcp/src/genkit/plugins/mcp/client/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

from genkit.ai import Genkit

Expand Down Expand Up @@ -58,6 +59,43 @@ async def disable(self, name: str):
client.config.disabled = True
await client.close()

async def reconnect(self, name: str):
"""Reconnects a specific MCP client."""
if name in self.clients:
client_to_reconnect = self.clients[name]
await client_to_reconnect.close()
await client_to_reconnect.connect()

async def get_active_tools(self, ai: Genkit) -> list[str]:
"""Returns a list of all active tool names from all clients."""
active_tools = []
for client in self.clients.values():
if client.session:
try:
tools = await client.get_active_tools()
# Determine tool names as registered: server_tool
for tool in tools:
active_tools.append(f'{client.server_name}_{tool.name}')
except Exception as e:
# Log error but continue with other clients
# Use print or logger if available. Ideally structlog.
pass
return active_tools

async def get_active_resources(self, ai: Genkit) -> list[str]:
"""Returns a list of all active resource URIs from all clients."""
active_resources = []
for client in self.clients.values():
if client.session:
try:
resources = await client.list_resources()
for resource in resources:
active_resources.append(resource.uri)
except Exception:
# Log error but continue with other clients
pass
return active_resources


def create_mcp_host(configs: dict[str, McpServerConfig]) -> McpHost:
return McpHost(configs)
Loading
Loading