Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/keen-rare-slim.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": patch
---

Fix: MCP OAuth callback errors are now returned as structured results instead of throwing unhandled exceptions. Errors with an active connection properly transition to "failed" state and are surfaced to clients via WebSocket broadcast.
11 changes: 11 additions & 0 deletions examples/mcp-client/src/client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,17 @@ function App() {
/>
{server.state} (id: {id})
</div>
{server.state === "failed" && server.error && (
Comment thread
Muhammad-Bin-Ali marked this conversation as resolved.
<div
style={{
color: "#c00",
fontSize: "0.85em",
marginTop: "4px"
}}
>
Error: {server.error}
</div>
)}
</div>
<div style={{ display: "flex", gap: "8px" }}>
{server.state === "authenticating" && server.auth_url && (
Expand Down
16 changes: 12 additions & 4 deletions examples/mcp-client/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@ export class MyAgent extends Agent {
onStart() {
// Optionally configure OAuth callback. Here we use popup-closing behavior since we're opening a window on the client
this.mcp.configureOAuthCallback({
customHandler: () => {
return new Response("<script>window.close();</script>", {
headers: { "content-type": "text/html" },
status: 200
customHandler: (result) => {
if (result.authSuccess) {
return new Response("<script>window.close();</script>", {
headers: { "content-type": "text/html" },
status: 200
});
}
// Show error briefly, then close the popup
const error = result.authError || "Unknown error";
return new Response(`Authentication Failed: ${error}`, {
headers: { "content-type": "text/plain" },
status: 400
});
}
});
Expand Down
9 changes: 0 additions & 9 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions packages/agents/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"@cfworker/json-schema": "^4.1.1",
"@modelcontextprotocol/sdk": "1.26.0",
"cron-schedule": "^6.0.0",
"escape-html": "^1.0.3",
"json-schema": "^0.4.0",
"json-schema-to-typescript": "^15.0.4",
"mimetext": "^3.0.28",
Expand All @@ -40,7 +39,6 @@
"@ai-sdk/openai": "^3.0.29",
"@ai-sdk/react": "^3.0.90",
"@cloudflare/workers-oauth-provider": "^0.2.3",
"@types/escape-html": "^1.0.4",
"@types/react": "^19.2.14",
"@types/yargs": "^17.0.35",
"@x402/core": "^2.3.1",
Expand Down
2 changes: 1 addition & 1 deletion packages/agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ export type MCPServer = {
// Scope outside of that can't be relied upon because when the DO sleeps, there's no way
// to communicate a change to a non-ready state.
state: MCPConnectionState;
/** May contain untrusted content from external OAuth providers. Escape appropriately for your output context. */
error: string | null;
instructions: string | null;
capabilities: ServerCapabilities | null;
Expand Down Expand Up @@ -4053,7 +4054,6 @@ export class Agent<
}
}

// Default: redirect to base URL
return Response.redirect(baseOrigin);
}
}
Expand Down
103 changes: 75 additions & 28 deletions packages/agents/src/mcp/client.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type { Client } from "@modelcontextprotocol/sdk/client/index.js";
import escapeHtml from "escape-html";
import type { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
import type {
CallToolRequest,
Expand Down Expand Up @@ -53,7 +52,7 @@ export type MCPServerOptions = {
*/
export type MCPOAuthCallbackResult =
| { serverId: string; authSuccess: true; authError?: undefined }
| { serverId: string; authSuccess: false; authError: string };
| { serverId?: string; authSuccess: false; authError: string };

/**
* Options for registering an MCP server
Expand Down Expand Up @@ -106,11 +105,14 @@ export type MCPClientOAuthCallbackConfig = {
customHandler?: (result: MCPClientOAuthResult) => Response;
};

export type MCPClientOAuthResult = {
serverId: string;
authSuccess: boolean;
authError?: string;
};
export type MCPClientOAuthResult =
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love this

| { serverId: string; authSuccess: true; authError?: undefined }
| {
serverId?: string;
authSuccess: false;
/** May contain untrusted content from external OAuth providers. Escape appropriately for your output context. */
authError: string;
};

export type MCPClientManagerOptions = {
storage: DurableObjectStorage;
Expand Down Expand Up @@ -723,39 +725,93 @@ export class MCPClientManager {
});
}

async handleCallbackRequest(req: Request): Promise<MCPOAuthCallbackResult> {
private validateCallbackRequest(
req: Request
):
| { valid: true; serverId: string; code: string; state: string }
| { valid: false; serverId?: string; error: string } {
const url = new URL(req.url);
const code = url.searchParams.get("code");
const state = url.searchParams.get("state");
const error = url.searchParams.get("error");
const errorDescription = url.searchParams.get("error_description");

// Early validation - these throw because we can't identify the connection
// Early validation - return errors because we can't identify the connection
if (!state) {
throw new Error("Unauthorized: no state provided");
return {
valid: false,
error: "Unauthorized: no state provided"
};
}

const serverId = this.extractServerIdFromState(state);
if (!serverId) {
throw new Error(
"No serverId found in state parameter. Expected format: {nonce}.{serverId}"
);
return {
valid: false,
error:
"No serverId found in state parameter. Expected format: {nonce}.{serverId}"
};
}

if (error) {
return {
serverId: serverId,
valid: false,
error: errorDescription || error
};
}

if (!code) {
return {
serverId: serverId,
valid: false,
error: "Unauthorized: no code provided"
};
}

const servers = this.getServersFromStorage();
const serverExists = servers.some((server) => server.id === serverId);
if (!serverExists) {
throw new Error(
`No server found with id "${serverId}". Was the request matched with \`isCallbackRequest()\`?`
);
return {
serverId: serverId,
valid: false,
error: `No server found with id "${serverId}". Was the request matched with \`isCallbackRequest()\`?`
};
}

if (this.mcpConnections[serverId] === undefined) {
throw new Error(`Could not find serverId: ${serverId}`);
return {
serverId: serverId,
valid: false,
error: `No connection found for serverId "${serverId}".`
};
}

// We have a valid connection - all errors from here should fail the connection
const conn = this.mcpConnections[serverId];
return {
valid: true,
serverId,
code: code,
state: state
};
}

async handleCallbackRequest(req: Request): Promise<MCPOAuthCallbackResult> {
const validation = this.validateCallbackRequest(req);

if (!validation.valid) {
if (validation.serverId && this.mcpConnections[validation.serverId]) {
return this.failConnection(validation.serverId, validation.error);
}

return {
serverId: validation.serverId,
authSuccess: false,
authError: validation.error
};
}

const { serverId, code, state } = validation;
const conn = this.mcpConnections[serverId]; // We have a valid connection - all errors from here should fail the connection

try {
if (!conn.options.transport.authProvider) {
Expand All @@ -774,15 +830,6 @@ export class MCPClientManager {
throw new Error(stateValidation.error || "Invalid state");
}

if (error) {
// Escape external OAuth error params to prevent XSS
throw new Error(escapeHtml(errorDescription || error));
}

if (!code) {
throw new Error("Unauthorized: no code provided");
}

// Already authenticated - just return success
if (
conn.connectionState === MCPConnectionState.READY ||
Expand Down
7 changes: 2 additions & 5 deletions packages/agents/src/tests/agents/oauth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Agent } from "../../index.ts";
import { DurableObjectOAuthClientProvider } from "../../mcp/do-oauth-client-provider";
import type { AgentMcpOAuthProvider } from "../../mcp/do-oauth-client-provider";
import type { MCPClientConnection } from "../../mcp/client-connection";
import type { MCPClientOAuthResult } from "../../mcp/client.ts";

// Test Agent for OAuth client side flows
export class TestOAuthAgent extends Agent<Record<string, unknown>> {
Expand All @@ -19,11 +20,7 @@ export class TestOAuthAgent extends Agent<Record<string, unknown>> {
}): void {
if (config.useJsonHandler) {
this.mcp.configureOAuthCallback({
customHandler: (result: {
serverId: string;
authSuccess: boolean;
authError?: string;
}) => {
customHandler: (result: MCPClientOAuthResult) => {
return new Response(
JSON.stringify({
custom: true,
Expand Down
48 changes: 18 additions & 30 deletions packages/agents/src/tests/mcp/client-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,14 @@ describe("MCPClientManager OAuth Integration", () => {
);
});

it("should throw error for callback without matching URL", async () => {
it("should return auth error for callback without matching URL", async () => {
const callbackRequest = new Request(
"http://localhost:3000/unknown?code=test&state=invalid.format"
);

await expect(
manager.handleCallbackRequest(callbackRequest)
).rejects.toThrow("No server found with id");
const result = await manager.handleCallbackRequest(callbackRequest);
expect(result.authSuccess).toBe(false);
expect(result.authError).toContain("No server found with id");
});

it("should handle OAuth error response from provider", async () => {
Expand Down Expand Up @@ -434,25 +434,25 @@ describe("MCPClientManager OAuth Integration", () => {
expect(result.authError).toBe("Unauthorized: no code provided");
});

it("should throw error for callback without state", async () => {
it("should return auth error for callback without state", async () => {
const callbackUrl = "http://localhost:3000/callback";
const callbackRequest = new Request(`${callbackUrl}?code=test`);

await expect(
manager.handleCallbackRequest(callbackRequest)
).rejects.toThrow("Unauthorized: no state provided");
const result = await manager.handleCallbackRequest(callbackRequest);
expect(result.authSuccess).toBe(false);
expect(result.authError).toBe("Unauthorized: no state provided");
});

it("should throw error for callback with non-existent server", async () => {
it("should return auth error for callback with non-existent server", async () => {
const stateStorage = createMockStateStorage();
const state = stateStorage.createState("non-existent");
const callbackRequest = new Request(
`http://localhost:3000/callback?code=test&state=${state}`
);

await expect(
manager.handleCallbackRequest(callbackRequest)
).rejects.toThrow("No server found with id");
const result = await manager.handleCallbackRequest(callbackRequest);
expect(result.authSuccess).toBe(false);
expect(result.authError).toContain("No server found with id");
});

it("should handle duplicate callback when already in ready state", async () => {
Expand Down Expand Up @@ -935,7 +935,7 @@ describe("MCPClientManager OAuth Integration", () => {
expect(result.authError).toBe("server_error");
});

it("should escape XSS payloads in error_description", async () => {
it("should pass through raw error_description without escaping", async () => {
const serverId = "test-server";
const callbackUrl = "http://localhost:3000/callback";
const stateStorage = createMockStateStorage();
Expand Down Expand Up @@ -970,19 +970,11 @@ describe("MCPClientManager OAuth Integration", () => {
const result = await manager.handleCallbackRequest(callbackRequest);

expect(result.authSuccess).toBe(false);
// Verify XSS payload is escaped
expect(result.authError).toBe(
"&lt;/script&gt;&lt;img src=x onerror=alert(1)&gt;"
);
expect(connection.connectionError).toBe(
"&lt;/script&gt;&lt;img src=x onerror=alert(1)&gt;"
);
// Should not contain raw script tag
expect(result.authError).not.toContain("<script>");
expect(result.authError).not.toContain("</script>");
expect(result.authError).toBe(xssPayload);
expect(connection.connectionError).toBe(xssPayload);
});

it("should escape XSS payloads in error parameter when description is absent", async () => {
it("should pass through raw error parameter without escaping when description is absent", async () => {
const serverId = "test-server";
const callbackUrl = "http://localhost:3000/callback";
const stateStorage = createMockStateStorage();
Expand Down Expand Up @@ -1017,12 +1009,8 @@ describe("MCPClientManager OAuth Integration", () => {
const result = await manager.handleCallbackRequest(callbackRequest);

expect(result.authSuccess).toBe(false);
expect(result.authError).toBe(
"&lt;script&gt;alert(&#39;xss&#39;)&lt;/script&gt;"
);
expect(connection.connectionError).toBe(
"&lt;script&gt;alert(&#39;xss&#39;)&lt;/script&gt;"
);
expect(result.authError).toBe(xssPayload);
expect(connection.connectionError).toBe(xssPayload);
});

it("should handle token exchange failure", async () => {
Expand Down
Loading
Loading