-
Notifications
You must be signed in to change notification settings - Fork 378
Feat/expose kagent agents in mcp #1201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8141aef
96b50cd
9d18dfd
204ee08
430ae63
88ae8c2
f61b306
4dcdb6a
4f0fc35
b30446b
0131074
756f32d
1b8f30d
7e4e773
2400dff
b03a12c
d3d2f33
70b11cb
652147f
d3863f2
8b5ce96
88007b9
719bef3
cbb1751
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| package mcp | ||
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "net/http" | ||
| "os" | ||
| "strings" | ||
| "sync" | ||
| "sync/atomic" | ||
| "time" | ||
|
|
||
| "github.com/kagent-dev/kagent/go/cli/internal/config" | ||
| "github.com/kagent-dev/kagent/go/internal/a2a" | ||
| "github.com/kagent-dev/kagent/go/internal/version" | ||
| "github.com/mark3labs/mcp-go/mcp" | ||
| mcpserver "github.com/mark3labs/mcp-go/server" | ||
| "github.com/spf13/cobra" | ||
| a2aclient "trpc.group/trpc-go/trpc-a2a-go/client" | ||
| "trpc.group/trpc-go/trpc-a2a-go/protocol" | ||
| ) | ||
|
|
||
| var ( | ||
| serveAgentsTransport string | ||
| serveAgentsHost string | ||
| serveAgentsPort int | ||
| ) | ||
|
|
||
| var a2aContextBySessionAndAgent sync.Map | ||
|
|
||
| var fallbackInvocationCounter uint64 | ||
|
|
||
| var ServeAgentsCmd = &cobra.Command{ | ||
| Use: "serve-mcp", | ||
| Short: "Serve kagent agents via MCP", | ||
| RunE: func(cmd *cobra.Command, args []string) error { | ||
| cfg, err := config.Get() | ||
| if err != nil { | ||
| return fmt.Errorf("config: %w", err) | ||
| } | ||
| hooks := &mcpserver.Hooks{} | ||
| hooks.AddOnUnregisterSession(func(ctx context.Context, session mcpserver.ClientSession) { | ||
| sessionID := session.SessionID() | ||
| a2aContextBySessionAndAgent.Range(func(key, _ any) bool { | ||
| keyStr, ok := key.(string) | ||
| if !ok { | ||
| return true | ||
| } | ||
| if strings.HasPrefix(keyStr, sessionID+"|") { | ||
| a2aContextBySessionAndAgent.Delete(key) | ||
| } | ||
| return true | ||
| }) | ||
| }) | ||
| s := mcpserver.NewMCPServer( | ||
| "kagent-agents", | ||
| version.Version, | ||
| mcpserver.WithToolCapabilities(false), | ||
| mcpserver.WithHooks(hooks), | ||
| ) | ||
|
|
||
| s.AddTool(mcp.NewTool("list_agents", | ||
| mcp.WithDescription("List invokable kagent agents (accepted + deploymentReady)"), | ||
| ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | ||
| resp, err := cfg.Client().Agent.ListAgents(ctx) | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("list agents", err), nil | ||
| } | ||
| type agentSummary struct { | ||
| Ref string `json:"ref"` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Id doesn't seem to be useful, maybe just ref (name) and description? |
||
| Description string `json:"description,omitempty"` | ||
| } | ||
| agents := make([]agentSummary, 0) | ||
| for _, agent := range resp.Data { | ||
| if !agent.Accepted || !agent.DeploymentReady || agent.Agent == nil { | ||
| continue | ||
| } | ||
| ref := agent.Agent.Namespace + "/" + agent.Agent.Name | ||
| agents = append(agents, agentSummary{Ref: ref, Description: agent.Agent.Spec.Description}) | ||
| } | ||
| if len(agents) == 0 { | ||
| return mcp.NewToolResultStructured(agents, "No invokable agents found."), nil | ||
| } | ||
|
|
||
| var fallbackText strings.Builder | ||
| for i, agent := range agents { | ||
| if i > 0 { | ||
| fallbackText.WriteByte('\n') | ||
| } | ||
| fallbackText.WriteString(agent.Ref) | ||
| if agent.Description != "" { | ||
| fallbackText.WriteString(" - ") | ||
| fallbackText.WriteString(agent.Description) | ||
| } | ||
| } | ||
|
|
||
| return mcp.NewToolResultStructured(agents, fallbackText.String()), nil | ||
| }) | ||
|
|
||
| s.AddTool(mcp.NewTool("invoke_agent", | ||
| mcp.WithDescription("Invoke a kagent agent via A2A"), | ||
| mcp.WithString("agent", mcp.Description("Agent name (or namespace/name)"), mcp.Required()), | ||
| mcp.WithString("task", mcp.Description("Task to run"), mcp.Required()), | ||
| ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | ||
| agentRef, err := request.RequireString("agent") | ||
| if err != nil { | ||
| return mcp.NewToolResultError(err.Error()), nil | ||
| } | ||
| task, err := request.RequireString("task") | ||
| if err != nil { | ||
| return mcp.NewToolResultError(err.Error()), nil | ||
| } | ||
| agentNS, agentName, ok := strings.Cut(agentRef, "/") | ||
| if !ok { | ||
| agentNS, agentName = cfg.Namespace, agentRef | ||
| } | ||
| agentRef = agentNS + "/" + agentName | ||
|
|
||
| sessionID := "" | ||
| if session := mcpserver.ClientSessionFromContext(ctx); session != nil { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of falling back to unknown, use a unique invocation ID per session if none is available to keep the context separate. When callers with proper session support use this they will get unknown as session and it will cause unexpected behaviour with multiple concurrent users like potentially wrong context history. |
||
| sessionID = session.SessionID() | ||
| } else if headerSessionID := request.Header.Get(mcpserver.HeaderKeySessionID); headerSessionID != "" { | ||
| sessionID = headerSessionID | ||
| } | ||
| if sessionID == "" { | ||
| sessionID = fmt.Sprintf("invocation-%d", atomic.AddUint64(&fallbackInvocationCounter, 1)) | ||
| } | ||
| contextKey := sessionID + "|" + agentRef | ||
| var contextIDPtr *string | ||
| if prior, ok := a2aContextBySessionAndAgent.Load(contextKey); ok { | ||
| if priorStr, ok := prior.(string); ok && priorStr != "" { | ||
| contextIDPtr = &priorStr | ||
| } | ||
| } | ||
|
|
||
| a2aURL := fmt.Sprintf("%s/api/a2a/%s/%s", cfg.KAgentURL, agentNS, agentName) | ||
| client, err := a2aclient.NewA2AClient(a2aURL, a2aclient.WithTimeout(cfg.Timeout)) | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("a2a client", err), nil | ||
| } | ||
| result, err := client.SendMessage(ctx, protocol.SendMessageParams{Message: protocol.Message{ | ||
| Kind: protocol.KindMessage, Role: protocol.MessageRoleUser, ContextID: contextIDPtr, Parts: []protocol.Part{protocol.NewTextPart(task)}, | ||
| }}) | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("a2a send", err), nil | ||
| } | ||
|
|
||
| var responseText, newContextID string | ||
| switch a2aResult := result.Result.(type) { | ||
| case *protocol.Message: | ||
| responseText = a2a.ExtractText(*a2aResult) | ||
| if a2aResult.ContextID != nil { | ||
| newContextID = *a2aResult.ContextID | ||
| } | ||
| case *protocol.Task: | ||
| newContextID = a2aResult.ContextID | ||
| if a2aResult.Status.Message != nil { | ||
| responseText = a2a.ExtractText(*a2aResult.Status.Message) | ||
| } | ||
| for _, artifact := range a2aResult.Artifacts { | ||
| responseText += a2a.ExtractText(protocol.Message{Parts: artifact.Parts}) | ||
| } | ||
| } | ||
| if responseText == "" { | ||
| raw, err := result.MarshalJSON() | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("marshal result", err), nil | ||
| } | ||
| responseText = string(raw) | ||
| } | ||
| if newContextID != "" { | ||
| a2aContextBySessionAndAgent.Store(contextKey, newContextID) | ||
| } | ||
| return mcp.NewToolResultStructured(map[string]any{ | ||
| "agent": agentRef, | ||
| "context_id": newContextID, | ||
| "text": responseText, | ||
| }, responseText), nil | ||
| }) | ||
|
|
||
| switch strings.ToLower(serveAgentsTransport) { | ||
| case "stdio": | ||
| stdioServer := mcpserver.NewStdioServer(s) | ||
| return stdioServer.Listen(cmd.Context(), os.Stdin, os.Stdout) | ||
| case "http": | ||
| addr := fmt.Sprintf("%s:%d", serveAgentsHost, serveAgentsPort) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps some logging to indicate the server is running successfully like "MCP server listening on xxx" |
||
| cmd.PrintErrf("MCP server listening on http://%s/mcp\n", addr) | ||
| httpServer := mcpserver.NewStreamableHTTPServer(s) | ||
| go func() { | ||
| <-cmd.Context().Done() | ||
| shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
| defer cancel() | ||
| _ = httpServer.Shutdown(shutdownCtx) | ||
| }() | ||
| if err := httpServer.Start(addr); err != nil && err != http.ErrServerClosed { | ||
| return err | ||
| } | ||
| return nil | ||
| default: | ||
| return fmt.Errorf("invalid transport %q (expected stdio or http)", serveAgentsTransport) | ||
| } | ||
| }, | ||
| } | ||
|
|
||
| func init() { | ||
| ServeAgentsCmd.Flags().StringVar(&serveAgentsTransport, "transport", "stdio", "Transport mode (stdio or http)") | ||
| ServeAgentsCmd.Flags().StringVar(&serveAgentsHost, "host", "127.0.0.1", "HTTP host to bind (when --transport http)") | ||
| ServeAgentsCmd.Flags().IntVar(&serveAgentsPort, "port", 3000, "HTTP port to bind (when --transport http)") | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| package e2e_test | ||
|
|
||
| import ( | ||
| "bufio" | ||
| "bytes" | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "os" | ||
| "os/exec" | ||
| "path/filepath" | ||
| "runtime" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestE2EInvokeAgentThroughMCPServeAgents(t *testing.T) { | ||
| // Setup mock server (so agent responses are deterministic and don't hit real LLMs) | ||
| baseURL, stopServer := setupMockServer(t, "mocks/invoke_mcp_serve_agents.json") | ||
| defer stopServer() | ||
|
|
||
| // Setup Kubernetes resources for a known-good agent | ||
| cli := setupK8sClient(t, false) | ||
| modelCfg := setupModelConfig(t, cli, baseURL) | ||
| agent := setupAgentWithOptions(t, cli, modelCfg.Name, nil, AgentOptions{ | ||
| Name: "kebab-agent", | ||
| }) | ||
|
|
||
| kagentURL := os.Getenv("KAGENT_URL") | ||
| if kagentURL == "" { | ||
| kagentURL = "http://localhost:8083" | ||
| } | ||
|
|
||
| _, testFile, _, ok := runtime.Caller(0) | ||
| require.True(t, ok) | ||
| goModuleRoot := filepath.Clean(filepath.Join(filepath.Dir(testFile), "../..")) | ||
|
|
||
| kagentBin := filepath.Join(t.TempDir(), "kagent") | ||
| build := exec.Command("go", "build", "-o", kagentBin, "./cli/cmd/kagent") | ||
| build.Dir = goModuleRoot | ||
| buildOutput, err := build.CombinedOutput() | ||
| require.NoError(t, err, string(buildOutput)) | ||
|
|
||
| homeDir := t.TempDir() | ||
| cfgDir := filepath.Join(homeDir, ".kagent") | ||
| require.NoError(t, os.MkdirAll(cfgDir, 0755)) | ||
| cfgPath := filepath.Join(cfgDir, "config.yaml") | ||
| require.NoError(t, os.WriteFile(cfgPath, []byte(fmt.Sprintf("kagent_url: %s\nnamespace: kagent\ntimeout: 300s\n", kagentURL)), 0644)) | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) | ||
| defer cancel() | ||
|
|
||
| cmd := exec.CommandContext(ctx, kagentBin, "mcp", "serve-agents") | ||
| cmd.Env = append(os.Environ(), "HOME="+homeDir) | ||
| stdout, err := cmd.StdoutPipe() | ||
| require.NoError(t, err) | ||
| stdin, err := cmd.StdinPipe() | ||
| require.NoError(t, err) | ||
| var stderr bytes.Buffer | ||
| cmd.Stderr = &stderr | ||
| require.NoError(t, cmd.Start()) | ||
| t.Cleanup(func() { | ||
| _ = stdin.Close() | ||
| _ = cmd.Process.Kill() | ||
| _ = cmd.Wait() | ||
| }) | ||
|
|
||
| lines := make(chan string, 32) | ||
| go func() { | ||
| scanner := bufio.NewScanner(stdout) | ||
| for scanner.Scan() { | ||
| lines <- scanner.Text() | ||
| } | ||
| close(lines) | ||
| }() | ||
|
|
||
| writeLine := func(line string) { | ||
| _, _ = fmt.Fprintln(stdin, line) | ||
| } | ||
|
|
||
| readResponse := func(wantID int) json.RawMessage { | ||
| deadline := time.NewTimer(15 * time.Second) | ||
| defer deadline.Stop() | ||
| for { | ||
| select { | ||
| case line, ok := <-lines: | ||
| require.True(t, ok, stderr.String()) | ||
| var msg struct { | ||
| ID int `json:"id"` | ||
| Result json.RawMessage `json:"result,omitempty"` | ||
| Error json.RawMessage `json:"error,omitempty"` | ||
| } | ||
| require.NoError(t, json.Unmarshal([]byte(line), &msg), line) | ||
| if msg.ID != wantID { | ||
| continue | ||
| } | ||
| require.Nil(t, msg.Error, line) | ||
| return msg.Result | ||
| case <-deadline.C: | ||
| t.Fatalf("timed out waiting for id=%d; stderr=%s", wantID, stderr.String()) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| writeLine(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"e2e","version":"0.0.0"}}}`) | ||
| _ = readResponse(1) | ||
| writeLine(`{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}`) | ||
|
|
||
| writeLine(`{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`) | ||
| toolsList := readResponse(2) | ||
| var listResult struct { | ||
| Tools []struct { | ||
| Name string `json:"name"` | ||
| } `json:"tools"` | ||
| } | ||
| require.NoError(t, json.Unmarshal(toolsList, &listResult), string(toolsList)) | ||
| require.GreaterOrEqual(t, len(listResult.Tools), 2) | ||
| toolNames := make([]string, 0, len(listResult.Tools)) | ||
| for _, tool := range listResult.Tools { | ||
| toolNames = append(toolNames, tool.Name) | ||
| } | ||
| require.Contains(t, toolNames, "list_agents") | ||
| require.Contains(t, toolNames, "invoke_agent") | ||
|
|
||
| writeLine(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"list_agents"}}`) | ||
| agentsResult := readResponse(3) | ||
| var callResult struct { | ||
| Content []struct { | ||
| Type string `json:"type"` | ||
| Text string `json:"text"` | ||
| } `json:"content"` | ||
| } | ||
| require.NoError(t, json.Unmarshal(agentsResult, &callResult), string(agentsResult)) | ||
| require.NotEmpty(t, callResult.Content) | ||
| require.Contains(t, callResult.Content[0].Text, agent.Namespace+"/"+agent.Name) | ||
|
|
||
| writeLine(fmt.Sprintf(`{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"invoke_agent","arguments":{"agent":%q,"task":"What can you do?"}}}`, agent.Name)) | ||
| invokeResult := readResponse(4) | ||
| require.NoError(t, json.Unmarshal(invokeResult, &callResult), string(invokeResult)) | ||
| require.NotEmpty(t, callResult.Content) | ||
| require.Contains(t, callResult.Content[0].Text, "kebab") | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This map never cleans up old session contexts. This might be an issue for HTTP server
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I quickly looked at the docs for mcp-go and seems like this hook will help:
hope it helps