chore: get rid of connection factory (#362)

Drive-by User-Agent sniffing and disabling of image type in Cursor.
This commit is contained in:
Pavel Feldman 2025-05-06 14:27:28 -07:00 committed by GitHub
parent 23a2e5fee7
commit e95b5b1dd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 181 additions and 154 deletions

View File

@ -163,14 +163,11 @@ The Playwright MCP server can be configured using a JSON configuration file. Her
// List of origins to block the browser to request. Origins matching both `allowedOrigins` and `blockedOrigins` will be blocked. // List of origins to block the browser to request. Origins matching both `allowedOrigins` and `blockedOrigins` will be blocked.
blockedOrigins?: string[]; blockedOrigins?: string[];
}; };
// Tool-specific configurations /**
tools?: { * Do not send image responses to the client.
browser_take_screenshot?: { */
// Disable base64-encoded image responses noImageResponses?: boolean;
omitBase64?: boolean;
}
}
} }
``` ```
@ -234,9 +231,9 @@ http.createServer(async (req, res) => {
// ... // ...
// Creates a headless Playwright MCP server with SSE transport // Creates a headless Playwright MCP server with SSE transport
const mcpServer = await createServer({ headless: true }); const connection = await createConnection({ headless: true });
const transport = new SSEServerTransport('/messages', res); const transport = new SSEServerTransport('/messages', res);
await mcpServer.connect(transport); await connection.connect(transport);
// ... // ...
}); });

16
config.d.ts vendored
View File

@ -107,19 +107,7 @@ export type Config = {
}; };
/** /**
* Configuration for specific tools. * Do not send image responses to the client.
*/ */
tools?: { noImageResponses?: boolean;
/**
* Configuration for the browser_take_screenshot tool.
*/
browser_take_screenshot?: {
/**
* Whether to disable base64-encoded image responses to the clients that
* don't support binary data or prefer to save on tokens.
*/
omitBase64?: boolean;
}
}
}; };

10
index.d.ts vendored
View File

@ -16,8 +16,14 @@
*/ */
import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { Config } from './config'; import type { Config } from './config';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export declare function createServer(config?: Config): Promise<Server>; export type Connection = {
server: Server;
connect(transport: Transport): Promise<void>;
close(): Promise<void>;
};
export declare function createConnection(config?: Config): Promise<Connection>;
export {}; export {};

View File

@ -15,5 +15,5 @@
* limitations under the License. * limitations under the License.
*/ */
import { createServer } from './lib/index'; import { createConnection } from './lib/index';
export default { createServer }; export default { createConnection };

View File

@ -39,6 +39,7 @@ export type CLIOptions = {
allowedOrigins?: string[]; allowedOrigins?: string[];
blockedOrigins?: string[]; blockedOrigins?: string[];
outputDir?: string; outputDir?: string;
noImageResponses?: boolean;
}; };
const defaultConfig: Config = { const defaultConfig: Config = {

View File

@ -19,20 +19,19 @@ import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '
import { zodToJsonSchema } from 'zod-to-json-schema'; import { zodToJsonSchema } from 'zod-to-json-schema';
import { Context } from './context.js'; import { Context } from './context.js';
import { snapshotTools, screenshotTools } from './tools.js';
import type { Tool } from './tools/tool.js';
import type { Config } from '../config.js'; import type { Config } from '../config.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
type MCPServerOptions = { import packageJSON from '../package.json' with { type: 'json' };
name: string;
version: string; export async function createConnection(config: Config): Promise<Connection> {
tools: Tool[]; const allTools = config.vision ? screenshotTools : snapshotTools;
}; const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
export function createServerWithTools(serverOptions: MCPServerOptions, config: Config): Server {
const { name, version, tools } = serverOptions;
const context = new Context(tools, config); const context = new Context(tools, config);
const server = new Server({ name, version }, { const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
capabilities: { capabilities: {
tools: {}, tools: {},
} }
@ -77,38 +76,30 @@ export function createServerWithTools(serverOptions: MCPServerOptions, config: C
} }
}); });
const oldClose = server.close.bind(server); const connection = new Connection(server, context);
return connection;
server.close = async () => {
await oldClose();
await context.close();
};
return server;
} }
export class ServerList { export class Connection {
private _servers: Server[] = []; readonly server: Server;
private _serverFactory: () => Promise<Server>; readonly context: Context;
constructor(serverFactory: () => Promise<Server>) { constructor(server: Server, context: Context) {
this._serverFactory = serverFactory; this.server = server;
this.context = context;
} }
async create() { async connect(transport: Transport) {
const server = await this._serverFactory(); await this.server.connect(transport);
this._servers.push(server); await new Promise<void>(resolve => {
return server; this.server.oninitialized = () => resolve();
});
if (this.server.getClientVersion()?.name.includes('cursor'))
this.context.config.noImageResponses = true;
} }
async close(server: Server) { async close() {
const index = this._servers.indexOf(server); await this.server.close();
if (index !== -1) await this.context.close();
this._servers.splice(index, 1);
await server.close();
}
async closeAll() {
await Promise.all(this._servers.map(server => server.close()));
} }
} }

View File

@ -14,62 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
import { createServerWithTools } from './server.js'; import { Connection } from './connection.js';
import common from './tools/common.js';
import console from './tools/console.js';
import dialogs from './tools/dialogs.js';
import files from './tools/files.js';
import install from './tools/install.js';
import keyboard from './tools/keyboard.js';
import navigate from './tools/navigate.js';
import network from './tools/network.js';
import pdf from './tools/pdf.js';
import snapshot from './tools/snapshot.js';
import tabs from './tools/tabs.js';
import screen from './tools/screen.js';
import testing from './tools/testing.js';
import type { Tool } from './tools/tool.js';
import type { Config } from '../config.js'; import type { Config } from '../config.js';
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
const snapshotTools: Tool<any>[] = [ export async function createConnection(config: Config = {}): Promise<Connection> {
...common(true), return createConnection(config);
...console,
...dialogs(true),
...files(true),
...install,
...keyboard(true),
...navigate(true),
...network,
...pdf,
...snapshot,
...tabs(true),
...testing,
];
const screenshotTools: Tool<any>[] = [
...common(false),
...console,
...dialogs(false),
...files(false),
...install,
...keyboard(false),
...navigate(false),
...network,
...pdf,
...screen,
...tabs(false),
...testing,
];
import packageJSON from '../package.json' with { type: 'json' };
export async function createServer(config: Config = {}): Promise<Server> {
const allTools = config.vision ? screenshotTools : snapshotTools;
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
return createServerWithTools({
name: 'Playwright',
version: packageJSON.version,
tools,
}, config);
} }

View File

@ -16,13 +16,11 @@
import { program } from 'commander'; import { program } from 'commander';
import { createServer } from './index.js';
import { ServerList } from './server.js';
import { startHttpTransport, startStdioTransport } from './transport.js'; import { startHttpTransport, startStdioTransport } from './transport.js';
import { resolveConfig } from './config.js'; import { resolveConfig } from './config.js';
import type { Connection } from './connection.js';
import packageJSON from '../package.json' with { type: 'json' }; import packageJSON from '../package.json' with { type: 'json' };
program program
@ -40,23 +38,25 @@ program
.option('--allowed-origins <origins>', 'Semicolon-separated list of origins to allow the browser to request. Default is to allow all.', semicolonSeparatedList) .option('--allowed-origins <origins>', 'Semicolon-separated list of origins to allow the browser to request. Default is to allow all.', semicolonSeparatedList)
.option('--blocked-origins <origins>', 'Semicolon-separated list of origins to block the browser from requesting. Blocklist is evaluated before allowlist. If used without the allowlist, requests not matching the blocklist are still allowed.', semicolonSeparatedList) .option('--blocked-origins <origins>', 'Semicolon-separated list of origins to block the browser from requesting. Blocklist is evaluated before allowlist. If used without the allowlist, requests not matching the blocklist are still allowed.', semicolonSeparatedList)
.option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)') .option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)')
.option('--no-image-responses', 'Do not send image responses to the client.')
.option('--output-dir <path>', 'Path to the directory for output files.') .option('--output-dir <path>', 'Path to the directory for output files.')
.option('--config <path>', 'Path to the configuration file.') .option('--config <path>', 'Path to the configuration file.')
.action(async options => { .action(async options => {
const config = await resolveConfig(options); const config = await resolveConfig(options);
const serverList = new ServerList(() => createServer(config)); const connectionList: Connection[] = [];
setupExitWatchdog(serverList); setupExitWatchdog(connectionList);
if (options.port) if (options.port)
startHttpTransport(+options.port, options.host, serverList); startHttpTransport(config, +options.port, options.host, connectionList);
else else
await startStdioTransport(serverList); await startStdioTransport(config, connectionList);
}); });
function setupExitWatchdog(serverList: ServerList) { function setupExitWatchdog(connectionList: Connection[]) {
const handleExit = async () => { const handleExit = async () => {
setTimeout(() => process.exit(0), 15000); setTimeout(() => process.exit(0), 15000);
await serverList.closeAll(); for (const connection of connectionList)
await connection.close();
process.exit(0); process.exit(0);
}; };

61
src/tools.ts Normal file
View File

@ -0,0 +1,61 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import common from './tools/common.js';
import console from './tools/console.js';
import dialogs from './tools/dialogs.js';
import files from './tools/files.js';
import install from './tools/install.js';
import keyboard from './tools/keyboard.js';
import navigate from './tools/navigate.js';
import network from './tools/network.js';
import pdf from './tools/pdf.js';
import snapshot from './tools/snapshot.js';
import tabs from './tools/tabs.js';
import screen from './tools/screen.js';
import testing from './tools/testing.js';
import type { Tool } from './tools/tool.js';
export const snapshotTools: Tool<any>[] = [
...common(true),
...console,
...dialogs(true),
...files(true),
...install,
...keyboard(true),
...navigate(true),
...network,
...pdf,
...snapshot,
...tabs(true),
...testing,
];
export const screenshotTools: Tool<any>[] = [
...common(false),
...console,
...dialogs(false),
...files(false),
...install,
...keyboard(false),
...navigate(false),
...network,
...pdf,
...screen,
...tabs(false),
...testing,
];

View File

@ -258,7 +258,7 @@ const screenshot = defineTool({
else else
code.push(`await page.screenshot(${javascript.formatObject(options)});`); code.push(`await page.screenshot(${javascript.formatObject(options)});`);
const includeBase64 = !context.config.tools?.browser_take_screenshot?.omitBase64; const includeBase64 = !context.config.noImageResponses;
const action = async () => { const action = async () => {
const screenshot = locator ? await locator.screenshot(options) : await tab.page.screenshot(options); const screenshot = locator ? await locator.screenshot(options) : await tab.page.screenshot(options);
return { return {

View File

@ -18,17 +18,22 @@ import http from 'node:http';
import assert from 'node:assert'; import assert from 'node:assert';
import crypto from 'node:crypto'; import crypto from 'node:crypto';
import { ServerList } from './server.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
export async function startStdioTransport(serverList: ServerList) { import { createConnection } from './connection.js';
const server = await serverList.create();
await server.connect(new StdioServerTransport()); import type { Config } from '../config.js';
import type { Connection } from './connection.js';
export async function startStdioTransport(config: Config, connectionList: Connection[]) {
const connection = await createConnection(config);
await connection.connect(new StdioServerTransport());
connectionList.push(connection);
} }
async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, url: URL, serverList: ServerList, sessions: Map<string, SSEServerTransport>) { async function handleSSE(config: Config, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>, connectionList: Connection[]) {
if (req.method === 'POST') { if (req.method === 'POST') {
const sessionId = url.searchParams.get('sessionId'); const sessionId = url.searchParams.get('sessionId');
if (!sessionId) { if (!sessionId) {
@ -46,22 +51,24 @@ async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, ur
} else if (req.method === 'GET') { } else if (req.method === 'GET') {
const transport = new SSEServerTransport('/sse', res); const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport); sessions.set(transport.sessionId, transport);
const server = await serverList.create(); const connection = await createConnection(config);
await connection.connect(transport);
connectionList.push(connection);
res.on('close', () => { res.on('close', () => {
sessions.delete(transport.sessionId); sessions.delete(transport.sessionId);
serverList.close(server).catch(e => { connection.close().catch(e => {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.error(e); console.error(e);
}); });
}); });
return await server.connect(transport); return;
} }
res.statusCode = 405; res.statusCode = 405;
res.end('Method not allowed'); res.end('Method not allowed');
} }
async function handleStreamable(req: http.IncomingMessage, res: http.ServerResponse, serverList: ServerList, sessions: Map<string, StreamableHTTPServerTransport>) { async function handleStreamable(config: Config, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>, connectionList: Connection[]) {
const sessionId = req.headers['mcp-session-id'] as string | undefined; const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) { if (sessionId) {
const transport = sessions.get(sessionId); const transport = sessions.get(sessionId);
@ -84,24 +91,28 @@ async function handleStreamable(req: http.IncomingMessage, res: http.ServerRespo
if (transport.sessionId) if (transport.sessionId)
sessions.delete(transport.sessionId); sessions.delete(transport.sessionId);
}; };
const server = await serverList.create(); const connection = await createConnection(config);
await server.connect(transport); connectionList.push(connection);
return await transport.handleRequest(req, res); await Promise.all([
connection.connect(transport),
transport.handleRequest(req, res),
]);
return;
} }
res.statusCode = 400; res.statusCode = 400;
res.end('Invalid request'); res.end('Invalid request');
} }
export function startHttpTransport(port: number, hostname: string | undefined, serverList: ServerList) { export function startHttpTransport(config: Config, port: number, hostname: string | undefined, connectionList: Connection[]) {
const sseSessions = new Map<string, SSEServerTransport>(); const sseSessions = new Map<string, SSEServerTransport>();
const streamableSessions = new Map<string, StreamableHTTPServerTransport>(); const streamableSessions = new Map<string, StreamableHTTPServerTransport>();
const httpServer = http.createServer(async (req, res) => { const httpServer = http.createServer(async (req, res) => {
const url = new URL(`http://localhost${req.url}`); const url = new URL(`http://localhost${req.url}`);
if (url.pathname.startsWith('/mcp')) if (url.pathname.startsWith('/mcp'))
await handleStreamable(req, res, serverList, streamableSessions); await handleStreamable(config, req, res, streamableSessions, connectionList);
else else
await handleSSE(req, res, url, serverList, sseSessions); await handleSSE(config, req, res, url, sseSessions, connectionList);
}); });
httpServer.listen(port, hostname, () => { httpServer.listen(port, hostname, () => {
const address = httpServer.address(); const address = httpServer.address();

View File

@ -34,7 +34,7 @@ export type TestOptions = {
type TestFixtures = { type TestFixtures = {
client: Client; client: Client;
visionClient: Client; visionClient: Client;
startClient: (options?: { args?: string[], config?: Config }) => Promise<Client>; startClient: (options?: { clientName?: string, args?: string[], config?: Config }) => Promise<Client>;
wsEndpoint: string; wsEndpoint: string;
cdpEndpoint: (port?: number) => Promise<string>; cdpEndpoint: (port?: number) => Promise<string>;
server: TestServer; server: TestServer;
@ -79,7 +79,7 @@ export const test = baseTest.extend<TestFixtures & TestOptions, WorkerFixtures>(
command: 'node', command: 'node',
args: [path.join(path.dirname(__filename), '../cli.js'), ...args], args: [path.join(path.dirname(__filename), '../cli.js'), ...args],
}); });
client = new Client({ name: 'test', version: '1.0.0' }); client = new Client({ name: options?.clientName ?? 'test', version: '1.0.0' });
await client.connect(transport); await client.connect(transport);
await client.ping(); await client.ping();
return client; return client;

View File

@ -116,14 +116,10 @@ test('browser_take_screenshot (outputDir)', async ({ startClient }, testInfo) =>
expect([...fs.readdirSync(outputDir)]).toHaveLength(1); expect([...fs.readdirSync(outputDir)]).toHaveLength(1);
}); });
test('browser_take_screenshot (omitBase64)', async ({ startClient }) => { test('browser_take_screenshot (noImageResponses)', async ({ startClient }) => {
const client = await startClient({ const client = await startClient({
config: { config: {
tools: { noImageResponses: true,
browser_take_screenshot: {
omitBase64: true,
},
},
}, },
}); });
@ -151,3 +147,31 @@ test('browser_take_screenshot (omitBase64)', async ({ startClient }) => {
], ],
}); });
}); });
test('browser_take_screenshot (cursor)', async ({ startClient }) => {
const client = await startClient({ clientName: 'cursor:vscode' });
expect(await client.callTool({
name: 'browser_navigate',
arguments: {
url: 'data:text/html,<html><title>Title</title><body>Hello, world!</body></html>',
},
})).toContainTextContent(`Navigate to data:text/html`);
await client.callTool({
name: 'browser_take_screenshot',
arguments: {},
});
expect(await client.callTool({
name: 'browser_take_screenshot',
arguments: {},
})).toEqual({
content: [
{
text: expect.stringContaining(`Screenshot viewport and save it as`),
type: 'text',
},
],
});
});