Skip to content

Commit 800676b

Browse files
committed
feat(mcp): OAuth 2.1 support for outbound MCP servers
Adds full MCP OAuth 2.1 + PKCE + Dynamic Client Registration (RFC 7591) support for outbound MCP servers via the SDK's `authProvider` interface. - `mcp_server_oauth` table holds per-server SDK OAuth artifacts (client info, encrypted tokens, PKCE verifier, state) workspace-scoped and shared across workspace members. - `mcp_servers.{auth_type, oauth_client_id, oauth_client_secret}` columns capture probe result and optional pre-registered credentials for ASes that don't support DCR. - `SimMcpOauthProvider` implements the SDK's `OAuthClientProvider` with a storage-backed redirect-sentinel pattern; the popup flow runs through `/api/mcp/oauth/{start,callback}` and posts back to the opener. - Unauthorized errors during tool execution surface as `reauth_required` so the UI can re-prompt without a stale-server flicker. - Tests, audit script baseline, and turbo bump included.
1 parent 22b5a1e commit 800676b

39 files changed

Lines changed: 19279 additions & 956 deletions

File tree

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
2+
import { db } from '@sim/db'
3+
import { mcpServers } from '@sim/db/schema'
4+
import { createLogger } from '@sim/logger'
5+
import { toError } from '@sim/utils/errors'
6+
import { and, eq, isNull } from 'drizzle-orm'
7+
import type { NextRequest } from 'next/server'
8+
import { NextResponse } from 'next/server'
9+
import { getSession } from '@/lib/auth'
10+
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
11+
import {
12+
clearState,
13+
clearVerifier,
14+
loadOauthRowByState,
15+
loadPreregisteredClient,
16+
SimMcpOauthProvider,
17+
} from '@/lib/mcp/oauth'
18+
import { mcpService } from '@/lib/mcp/service'
19+
20+
const logger = createLogger('McpOauthCallbackAPI')
21+
22+
export const dynamic = 'force-dynamic'
23+
24+
function escapeHtml(value: string): string {
25+
return value
26+
.replace(/&/g, '&')
27+
.replace(/</g, '&lt;')
28+
.replace(/>/g, '&gt;')
29+
.replace(/"/g, '&quot;')
30+
.replace(/'/g, '&#39;')
31+
}
32+
33+
function htmlClose(message: string, ok: boolean, serverId?: string): NextResponse {
34+
const safeMessage = escapeHtml(message)
35+
const title = ok ? 'Connected' : 'Connection failed'
36+
const serverIdLiteral = serverId
37+
? JSON.stringify(serverId).replace(/</g, '\\u003c').replace(/>/g, '\\u003e')
38+
: 'undefined'
39+
const body = `<!doctype html><html><head><meta charset="utf-8"><title>${title}</title></head><body style="font-family: system-ui; padding: 24px"><p>${safeMessage}</p><script>
40+
try { window.opener && window.opener.postMessage({ type: 'mcp-oauth', ok: ${ok ? 'true' : 'false'}, serverId: ${serverIdLiteral} }, window.location.origin) } catch (e) {}
41+
setTimeout(function () { window.close() }, 800)
42+
</script></body></html>`
43+
return new NextResponse(body, {
44+
headers: { 'Content-Type': 'text/html; charset=utf-8' },
45+
})
46+
}
47+
48+
export const GET = withRouteHandler(async (request: NextRequest) => {
49+
const url = new URL(request.url)
50+
const state = url.searchParams.get('state')
51+
const code = url.searchParams.get('code')
52+
const errorParam = url.searchParams.get('error')
53+
54+
if (errorParam) {
55+
logger.warn(`MCP OAuth callback received error: ${errorParam}`)
56+
return htmlClose(`Authorization failed: ${errorParam}`, false)
57+
}
58+
if (!state || !code) {
59+
return htmlClose('Missing state or code in callback URL.', false)
60+
}
61+
62+
let serverId: string | undefined
63+
try {
64+
const session = await getSession()
65+
if (!session?.user?.id) {
66+
return htmlClose('You must be signed in to complete authorization.', false)
67+
}
68+
69+
const row = await loadOauthRowByState(state)
70+
if (!row) {
71+
return htmlClose('Invalid or expired authorization state.', false)
72+
}
73+
serverId = row.mcpServerId
74+
75+
if (session.user.id !== row.userId) {
76+
return htmlClose(
77+
'You must be signed in as the same user that initiated the flow.',
78+
false,
79+
serverId
80+
)
81+
}
82+
83+
const [server] = await db
84+
.select({ id: mcpServers.id, url: mcpServers.url, workspaceId: mcpServers.workspaceId })
85+
.from(mcpServers)
86+
.where(and(eq(mcpServers.id, row.mcpServerId), isNull(mcpServers.deletedAt)))
87+
.limit(1)
88+
if (!server || !server.url) {
89+
return htmlClose('Server no longer exists.', false, serverId)
90+
}
91+
92+
// Burn state before token exchange so a replayed callback cannot reuse it.
93+
await clearState(row.id)
94+
95+
const preregistered = await loadPreregisteredClient(server.id)
96+
const provider = new SimMcpOauthProvider({ row, preregistered })
97+
let result: Awaited<ReturnType<typeof mcpAuth>>
98+
try {
99+
result = await mcpAuth(provider, {
100+
serverUrl: server.url,
101+
authorizationCode: code,
102+
})
103+
} finally {
104+
await clearVerifier(row.id)
105+
}
106+
107+
if (result !== 'AUTHORIZED') {
108+
return htmlClose('Authorization did not complete.', false, server.id)
109+
}
110+
111+
try {
112+
await mcpService.clearCache(server.workspaceId)
113+
await mcpService.discoverServerTools(session.user.id, server.id, server.workspaceId)
114+
} catch (e) {
115+
logger.warn('Post-auth tools refresh failed', toError(e).message)
116+
}
117+
118+
return htmlClose('Connected. You can close this window.', true, server.id)
119+
} catch (error) {
120+
logger.error('MCP OAuth callback failed', error)
121+
return htmlClose('Authorization failed. Please try again.', false, serverId)
122+
}
123+
})
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/**
2+
* @vitest-environment node
3+
*/
4+
import {
5+
dbChainMock,
6+
dbChainMockFns,
7+
hybridAuthMock,
8+
hybridAuthMockFns,
9+
permissionsMock,
10+
permissionsMockFns,
11+
resetDbChainMock,
12+
schemaMock,
13+
} from '@sim/testing'
14+
import { NextRequest } from 'next/server'
15+
import { beforeEach, describe, expect, it, vi } from 'vitest'
16+
17+
const {
18+
mockMcpAuth,
19+
mockGetOrCreateOauthRow,
20+
mockLoadPreregisteredClient,
21+
mockSetOauthRowUser,
22+
MockMcpOauthRedirectRequired,
23+
} = vi.hoisted(() => ({
24+
mockMcpAuth: vi.fn(),
25+
mockGetOrCreateOauthRow: vi.fn(),
26+
mockLoadPreregisteredClient: vi.fn(),
27+
mockSetOauthRowUser: vi.fn(),
28+
MockMcpOauthRedirectRequired: class MockMcpOauthRedirectRequired extends Error {
29+
constructor(public readonly authorizationUrl: string) {
30+
super('redirect required')
31+
}
32+
},
33+
}))
34+
35+
vi.mock('@sim/db', () => dbChainMock)
36+
vi.mock('@sim/db/schema', () => schemaMock)
37+
vi.mock('drizzle-orm', () => ({
38+
and: vi.fn(),
39+
eq: vi.fn(),
40+
isNull: vi.fn(),
41+
}))
42+
vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
43+
auth: mockMcpAuth,
44+
}))
45+
vi.mock('@/lib/auth/hybrid', () => hybridAuthMock)
46+
vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock)
47+
vi.mock('@/lib/mcp/oauth', () => ({
48+
getOrCreateOauthRow: mockGetOrCreateOauthRow,
49+
loadPreregisteredClient: mockLoadPreregisteredClient,
50+
McpOauthRedirectRequired: MockMcpOauthRedirectRequired,
51+
setOauthRowUser: mockSetOauthRowUser,
52+
SimMcpOauthProvider: vi.fn().mockImplementation((value) => value),
53+
}))
54+
55+
import { GET } from './route'
56+
57+
describe('MCP OAuth start route', () => {
58+
beforeEach(() => {
59+
vi.clearAllMocks()
60+
resetDbChainMock()
61+
hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({
62+
success: true,
63+
userId: 'user-2',
64+
userName: 'User Two',
65+
userEmail: 'user2@example.com',
66+
authType: 'session',
67+
})
68+
permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValue('write')
69+
dbChainMockFns.limit.mockResolvedValue([
70+
{
71+
id: 'server-1',
72+
name: 'Exa',
73+
url: 'https://mcp.exa.ai/mcp',
74+
workspaceId: 'workspace-1',
75+
authType: 'oauth',
76+
deletedAt: null,
77+
},
78+
])
79+
mockGetOrCreateOauthRow.mockResolvedValue({
80+
id: 'oauth-row-1',
81+
mcpServerId: 'server-1',
82+
userId: 'user-1',
83+
workspaceId: 'workspace-1',
84+
clientInformation: null,
85+
tokens: null,
86+
codeVerifier: null,
87+
state: null,
88+
updatedAt: new Date(),
89+
})
90+
mockLoadPreregisteredClient.mockResolvedValue(undefined)
91+
mockMcpAuth.mockRejectedValue(new MockMcpOauthRedirectRequired('https://mcp.exa.ai/authorize'))
92+
})
93+
94+
it('requires workspace write permission via MCP auth middleware', async () => {
95+
const request = new NextRequest(
96+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
97+
)
98+
99+
await GET(request)
100+
101+
expect(permissionsMockFns.mockGetUserEntityPermissions).toHaveBeenCalledWith(
102+
'user-2',
103+
'workspace',
104+
'workspace-1'
105+
)
106+
})
107+
108+
it('uses a workspace-scoped OAuth row and stamps the latest authorizing user', async () => {
109+
const request = new NextRequest(
110+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
111+
)
112+
113+
const response = await GET(request)
114+
const body = await response.json()
115+
116+
expect(response.status).toBe(200)
117+
expect(body).toEqual({
118+
status: 'redirect',
119+
authorizationUrl: 'https://mcp.exa.ai/authorize',
120+
})
121+
expect(mockGetOrCreateOauthRow).toHaveBeenCalledWith({
122+
mcpServerId: 'server-1',
123+
userId: 'user-2',
124+
workspaceId: 'workspace-1',
125+
})
126+
expect(mockSetOauthRowUser).toHaveBeenCalledWith('oauth-row-1', 'user-2')
127+
})
128+
129+
it('rejects a second user starting OAuth while another authorization is active', async () => {
130+
mockGetOrCreateOauthRow.mockResolvedValueOnce({
131+
id: 'oauth-row-1',
132+
mcpServerId: 'server-1',
133+
userId: 'user-1',
134+
workspaceId: 'workspace-1',
135+
clientInformation: null,
136+
tokens: null,
137+
codeVerifier: null,
138+
state: 'hashed-active-state',
139+
updatedAt: new Date(),
140+
})
141+
const request = new NextRequest(
142+
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
143+
)
144+
145+
const response = await GET(request)
146+
const body = await response.json()
147+
148+
expect(response.status).toBe(409)
149+
expect(body.error).toBe('OAuth authorization already in progress for this server')
150+
expect(mockMcpAuth).not.toHaveBeenCalled()
151+
})
152+
})
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
2+
import { db } from '@sim/db'
3+
import { mcpServers } from '@sim/db/schema'
4+
import { createLogger } from '@sim/logger'
5+
import { toError } from '@sim/utils/errors'
6+
import { and, eq, isNull } from 'drizzle-orm'
7+
import type { NextRequest } from 'next/server'
8+
import { NextResponse } from 'next/server'
9+
import { startMcpOauthQuerySchema } from '@/lib/api/contracts/mcp'
10+
import { validationErrorResponse } from '@/lib/api/server'
11+
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
12+
import { withMcpAuth } from '@/lib/mcp/middleware'
13+
import {
14+
getOrCreateOauthRow,
15+
loadPreregisteredClient,
16+
McpOauthRedirectRequired,
17+
SimMcpOauthProvider,
18+
setOauthRowUser,
19+
} from '@/lib/mcp/oauth'
20+
import { createMcpErrorResponse } from '@/lib/mcp/utils'
21+
22+
const logger = createLogger('McpOauthStartAPI')
23+
const OAUTH_START_TTL_MS = 10 * 60 * 1000
24+
25+
export const dynamic = 'force-dynamic'
26+
27+
export const GET = withRouteHandler(
28+
withMcpAuth('write')(async (request: NextRequest, { userId, workspaceId, requestId }) => {
29+
try {
30+
const queryResult = startMcpOauthQuerySchema.safeParse(
31+
Object.fromEntries(new URL(request.url).searchParams)
32+
)
33+
if (!queryResult.success) {
34+
return validationErrorResponse(queryResult.error)
35+
}
36+
const { serverId } = queryResult.data
37+
38+
const [server] = await db
39+
.select()
40+
.from(mcpServers)
41+
.where(
42+
and(
43+
eq(mcpServers.id, serverId),
44+
eq(mcpServers.workspaceId, workspaceId),
45+
isNull(mcpServers.deletedAt)
46+
)
47+
)
48+
.limit(1)
49+
50+
if (!server) {
51+
return createMcpErrorResponse(new Error('Server not found'), 'Server not found', 404)
52+
}
53+
if (server.authType !== 'oauth') {
54+
return createMcpErrorResponse(
55+
new Error(`Server authType is "${server.authType}", not oauth`),
56+
'Server is not configured for OAuth',
57+
400
58+
)
59+
}
60+
if (!server.url) {
61+
return createMcpErrorResponse(new Error('Server has no URL'), 'Missing server URL', 400)
62+
}
63+
64+
const row = await getOrCreateOauthRow({
65+
mcpServerId: server.id,
66+
userId,
67+
workspaceId,
68+
})
69+
const hasActiveFlow = !!row.state && row.updatedAt.getTime() > Date.now() - OAUTH_START_TTL_MS
70+
if (hasActiveFlow && row.userId && row.userId !== userId) {
71+
return createMcpErrorResponse(
72+
new Error('OAuth authorization already in progress'),
73+
'OAuth authorization already in progress for this server',
74+
409
75+
)
76+
}
77+
if (row.userId !== userId) {
78+
await setOauthRowUser(row.id, userId)
79+
row.userId = userId
80+
}
81+
const preregistered = await loadPreregisteredClient(server.id)
82+
const provider = new SimMcpOauthProvider({ row, preregistered })
83+
84+
try {
85+
const result = await mcpAuth(provider, { serverUrl: server.url })
86+
if (result === 'AUTHORIZED') {
87+
return NextResponse.json({ status: 'already_authorized' })
88+
}
89+
return createMcpErrorResponse(
90+
new Error('Provider did not capture redirect URL'),
91+
'Failed to start OAuth flow',
92+
500
93+
)
94+
} catch (e) {
95+
if (e instanceof McpOauthRedirectRequired) {
96+
logger.info(`[${requestId}] OAuth redirect for server ${serverId}`)
97+
return NextResponse.json({
98+
status: 'redirect',
99+
authorizationUrl: e.authorizationUrl,
100+
})
101+
}
102+
throw e
103+
}
104+
} catch (error) {
105+
logger.error(`[${requestId}] Error starting MCP OAuth flow:`, error)
106+
return createMcpErrorResponse(toError(error), 'Failed to start OAuth flow', 500)
107+
}
108+
})
109+
)

0 commit comments

Comments
 (0)