Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/playwright/src/mcp/sdk/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export type ProgressCallback = (params: ProgressParams) => void;
export interface ServerBackend {
initialize?(clientInfo: ClientInfo): Promise<void>;
listTools(): Promise<Tool[]>;
callTool(name: string, args: CallToolRequest['params']['arguments'], progress: ProgressCallback): Promise<CallToolResult>;
callTool(name: string, args: CallToolRequest['params']['arguments'], progress: ProgressCallback, signal?: AbortSignal): Promise<CallToolResult>;
serverClosed?(server: Server): void;
}

Expand Down Expand Up @@ -111,7 +111,7 @@ export function createServer(name: string, version: string, backend: ServerBacke
if (!initializePromise)
initializePromise = initializeServer(server, backend, runHeartbeat);
await initializePromise;
const toolResult = await backend.callTool(request.params.name, request.params.arguments || {}, progress);
const toolResult = await backend.callTool(request.params.name, request.params.arguments || {}, progress, extra.signal);
const mergedResult = mergeTextParts(toolResult);
serverDebugResponse('callResult', mergedResult);
return mergedResult;
Expand Down
4 changes: 2 additions & 2 deletions packages/playwright/src/mcp/test/generatorTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ export const setupPage = defineTestTool({
type: 'readOnly',
},

handle: async (context, params) => {
handle: async (context, params, signal) => {
const seed = await context.getOrCreateSeedFile(params.seedFile, params.project);
context.generatorJournal = new GeneratorJournal(context.rootPath, params.plan, seed);
const { output, status } = await context.runSeedTest(seed.file, seed.projectName);
const { output, status } = await context.runSeedTest(seed.file, seed.projectName, signal);
return { content: [{ type: 'text', text: output }], isError: status !== 'paused' };
},
});
Expand Down
4 changes: 2 additions & 2 deletions packages/playwright/src/mcp/test/plannerTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ export const setupPage = defineTestTool({
type: 'readOnly',
},

handle: async (context, params) => {
handle: async (context, params, signal) => {
const seed = await context.getOrCreateSeedFile(params.seedFile, params.project);
const { output, status } = await context.runSeedTest(seed.file, seed.projectName);
const { output, status } = await context.runSeedTest(seed.file, seed.projectName, signal);
return { content: [{ type: 'text', text: output }], isError: status !== 'paused' };
},
});
Expand Down
4 changes: 2 additions & 2 deletions packages/playwright/src/mcp/test/testBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ export class TestServerBackend implements mcp.ServerBackend {
return this._tools.map(tool => mcp.toMcpTool(tool.schema));
}

async callTool(name: string, args: mcp.CallToolRequest['params']['arguments']): Promise<mcp.CallToolResult> {
async callTool(name: string, args: mcp.CallToolRequest['params']['arguments'], _progress: mcp.ProgressCallback, signal: AbortSignal): Promise<mcp.CallToolResult> {
const tool = this._tools.find(tool => tool.schema.name === name);
if (!tool)
throw new Error(`Tool not found: ${name}. Available tools: ${this._tools.map(tool => tool.schema.name).join(', ')}`);
try {
return await tool.handle(this._context!, tool.schema.inputSchema.parse(args || {}));
return await tool.handle(this._context!, tool.schema.inputSchema.parse(args || {}), signal);
} catch (e) {
return { content: [{ type: 'text', text: String(e) }], isError: true };
}
Expand Down
11 changes: 8 additions & 3 deletions packages/playwright/src/mcp/test/testContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ export class TestContext {
};
}

async runSeedTest(seedFile: string, projectName: string): Promise<{ output: string, status: FullResultStatus | 'paused' }> {
async runSeedTest(seedFile: string, projectName: string, signal?: AbortSignal): Promise<{ output: string, status: FullResultStatus | 'paused' }> {
const result = await this.runTestsWithGlobalSetupAndPossiblePause({
headed: this.computedHeaded,
locations: ['/' + escapeRegExp(seedFile) + '/'],
Expand All @@ -190,19 +190,23 @@ export class TestContext {
pauseAtEnd: true,
disableConfigReporters: true,
failOnLoadErrors: true,
});
}, signal);
if (result.status === 'passed')
result.output += '\nError: seed test not found.';
else if (result.status !== 'paused')
result.output += '\nError while running the seed test.';
return result;
}

async runTestsWithGlobalSetupAndPossiblePause(params: RunTestsParams): Promise<{ output: string, status: FullResultStatus | 'paused' }> {
async runTestsWithGlobalSetupAndPossiblePause(params: RunTestsParams, signal?: AbortSignal): Promise<{ output: string, status: FullResultStatus | 'paused' }> {
const configDir = this._configLocation.configDir;
const testRunnerAndScreen = await this.createTestRunner();
const { testRunner, screen, claimStdio, releaseStdio } = testRunnerAndScreen;

// Stop tests when the signal is aborted
const abortHandler = () => testRunner.stopTests();
signal?.addEventListener('abort', abortHandler);

claimStdio();
try {
const setupReporter = new MCPListReporter({ configDir, screen, includeTestId: true });
Expand All @@ -216,6 +220,7 @@ export class TestContext {
let status: FullResultStatus | 'paused' = 'passed';

const cleanup = async () => {
signal?.removeEventListener('abort', abortHandler);
claimStdio();
try {
const result = await testRunner.runGlobalTeardown();
Expand Down
2 changes: 1 addition & 1 deletion packages/playwright/src/mcp/test/testTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import type { ToolSchema } from '../sdk/tool.js';

export type TestTool<Input extends z.Schema = z.Schema> = {
schema: ToolSchema<Input>;
handle: (context: TestContext, params: z.output<Input>) => Promise<CallToolResult>;
handle: (context: TestContext, params: z.output<Input>, signal: AbortSignal) => Promise<CallToolResult>;
};

export function defineTestTool<Input extends z.Schema>(tool: TestTool<Input>): TestTool<Input> {
Expand Down
8 changes: 4 additions & 4 deletions packages/playwright/src/mcp/test/testTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ export const runTests = defineTestTool({
type: 'readOnly',
},

handle: async (context, params) => {
handle: async (context, params, signal) => {
const { output } = await context.runTestsWithGlobalSetupAndPossiblePause({
locations: params.locations ?? [],
projects: params.projects,
disableConfigReporters: true,
});
}, signal);
return { content: [{ type: 'text', text: output }] };
},
});
Expand All @@ -71,7 +71,7 @@ export const debugTest = defineTestTool({
type: 'readOnly',
},

handle: async (context, params) => {
handle: async (context, params, signal) => {
const { output, status } = await context.runTestsWithGlobalSetupAndPossiblePause({
headed: context.computedHeaded,
locations: [], // we can make this faster by passing the test's location, so we don't need to scan all tests to find the ID
Expand All @@ -82,7 +82,7 @@ export const debugTest = defineTestTool({
pauseOnError: true,
disableConfigReporters: true,
actionTimeout: 5000,
});
}, signal);
return { content: [{ type: 'text', text: output }], isError: status !== 'paused' && status !== 'passed' };
},
});
32 changes: 32 additions & 0 deletions tests/mcp/test-run.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,35 @@ Running 3 tests using 1 worker
ok 3 [id=<ID>] [project=chromium] › example.test.ts:6:11 › example2 (XXms)
3 passed (XXms)`);
});

test('test_run should stop when aborted', async ({ startClient }) => {
await writeFiles({
'slow.test.ts': `
import { test, expect } from '@playwright/test';
test('slow test', async () => {
await new Promise(resolve => setTimeout(resolve, 30000));
});
`,
});

const { client } = await startClient();

const abortController = new AbortController();

// Start the test run
const startTime = Date.now();
const testRunPromise = client.callTool({
name: 'test_run',
}, undefined, { signal: abortController.signal });

// Wait a bit for the test to start, then abort
await new Promise(resolve => setTimeout(resolve, 500));
abortController.abort();

// The call should reject with an abort error
await expect(testRunPromise).rejects.toThrow(/abort/i);

// Verify the abort happened quickly (not after 30 seconds)
const elapsed = Date.now() - startTime;
expect(elapsed).toBeLessThan(5000);
});