chore: get rid of connection factory (#362)

Drive-by: sniff the client name (User-Agent style) and disable image responses when the client is Cursor.
This commit is contained in:
Pavel Feldman 2025-05-06 14:27:28 -07:00 committed by GitHub
parent 23a2e5fee7
commit e95b5b1dd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 181 additions and 154 deletions

View File

@ -163,14 +163,11 @@ The Playwright MCP server can be configured using a JSON configuration file. Her
// List of origins the browser is blocked from requesting. Origins matching both `allowedOrigins` and `blockedOrigins` will be blocked.
blockedOrigins?: string[];
};
// Tool-specific configurations
tools?: {
browser_take_screenshot?: {
// Disable base64-encoded image responses
omitBase64?: boolean;
}
}
/**
* Do not send image responses to the client.
*/
noImageResponses?: boolean;
}
```
@ -234,9 +231,9 @@ http.createServer(async (req, res) => {
// ...
// Creates a headless Playwright MCP server with SSE transport
const mcpServer = await createServer({ headless: true });
const connection = await createConnection({ headless: true });
const transport = new SSEServerTransport('/messages', res);
await mcpServer.connect(transport);
await connection.connect(transport);
// ...
});

16
config.d.ts vendored
View File

@ -107,19 +107,7 @@ export type Config = {
};
/**
* Configuration for specific tools.
* Do not send image responses to the client.
*/
tools?: {
/**
* Configuration for the browser_take_screenshot tool.
*/
browser_take_screenshot?: {
/**
* Whether to disable base64-encoded image responses to the clients that
* don't support binary data or prefer to save on tokens.
*/
omitBase64?: boolean;
}
}
noImageResponses?: boolean;
};

10
index.d.ts vendored
View File

@ -16,8 +16,14 @@
*/
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { Config } from './config';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export declare function createServer(config?: Config): Promise<Server>;
/**
 * Handle returned by `createConnection`.
 */
export type Connection = {
/** The MCP SDK `Server` instance that belongs to this connection. */
server: Server;
/** Attaches this connection to the given MCP transport. */
connect(transport: Transport): Promise<void>;
/** Closes this connection. */
close(): Promise<void>;
};
/**
 * Creates a Playwright MCP connection.
 *
 * @param config Optional configuration; see `Config` in `./config`.
 * @returns A promise for the created `Connection`.
 */
export declare function createConnection(config?: Config): Promise<Connection>;
export {};

View File

@ -15,5 +15,5 @@
* limitations under the License.
*/
import { createServer } from './lib/index';
export default { createServer };
import { createConnection } from './lib/index';
export default { createConnection };

View File

@ -39,6 +39,7 @@ export type CLIOptions = {
allowedOrigins?: string[];
blockedOrigins?: string[];
outputDir?: string;
noImageResponses?: boolean;
};
const defaultConfig: Config = {

View File

@ -19,20 +19,19 @@ import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '
import { zodToJsonSchema } from 'zod-to-json-schema';
import { Context } from './context.js';
import { snapshotTools, screenshotTools } from './tools.js';
import type { Tool } from './tools/tool.js';
import type { Config } from '../config.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
type MCPServerOptions = {
name: string;
version: string;
tools: Tool[];
};
import packageJSON from '../package.json' with { type: 'json' };
export async function createConnection(config: Config): Promise<Connection> {
const allTools = config.vision ? screenshotTools : snapshotTools;
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
export function createServerWithTools(serverOptions: MCPServerOptions, config: Config): Server {
const { name, version, tools } = serverOptions;
const context = new Context(tools, config);
const server = new Server({ name, version }, {
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
capabilities: {
tools: {},
}
@ -77,38 +76,30 @@ export function createServerWithTools(serverOptions: MCPServerOptions, config: C
}
});
const oldClose = server.close.bind(server);
server.close = async () => {
await oldClose();
await context.close();
};
return server;
const connection = new Connection(server, context);
return connection;
}
export class ServerList {
private _servers: Server[] = [];
private _serverFactory: () => Promise<Server>;
export class Connection {
readonly server: Server;
readonly context: Context;
constructor(serverFactory: () => Promise<Server>) {
this._serverFactory = serverFactory;
/**
 * Stores the MCP server and the context on the connection.
 *
 * @param server Server instance exposed as `this.server`.
 * @param context Context instance exposed as `this.context`.
 */
constructor(server: Server, context: Context) {
this.server = server;
this.context = context;
}
async create() {
const server = await this._serverFactory();
this._servers.push(server);
return server;
/**
 * Connects the MCP server to `transport` and resolves after the server's
 * `oninitialized` callback has fired.
 *
 * Once initialized, a client whose reported name contains 'cursor' gets
 * `noImageResponses` forced on in this connection's config.
 *
 * @param transport Transport to attach the server to.
 */
async connect(transport: Transport) {
  // Install the handler BEFORE connecting: if it were installed only after
  // `server.connect()` resolved, an initialization that completed in the
  // meantime would be missed and this method would never resolve.
  const initialized = new Promise<void>(resolve => {
    this.server.oninitialized = () => resolve();
  });
  await this.server.connect(transport);
  await initialized;
  if (this.server.getClientVersion()?.name.includes('cursor'))
    this.context.config.noImageResponses = true;
}
async close(server: Server) {
const index = this._servers.indexOf(server);
if (index !== -1)
this._servers.splice(index, 1);
await server.close();
}
async closeAll() {
await Promise.all(this._servers.map(server => server.close()));
/**
 * Closes the MCP server and then the context.
 *
 * The context is closed even when closing the server fails, so a failing
 * server shutdown does not leak the context. The server error (if any) is
 * still propagated to the caller.
 */
async close() {
  try {
    await this.server.close();
  } finally {
    await this.context.close();
  }
}
}

View File

@ -14,62 +14,10 @@
* limitations under the License.
*/
import { createServerWithTools } from './server.js';
import common from './tools/common.js';
import console from './tools/console.js';
import dialogs from './tools/dialogs.js';
import files from './tools/files.js';
import install from './tools/install.js';
import keyboard from './tools/keyboard.js';
import navigate from './tools/navigate.js';
import network from './tools/network.js';
import pdf from './tools/pdf.js';
import snapshot from './tools/snapshot.js';
import tabs from './tools/tabs.js';
import screen from './tools/screen.js';
import testing from './tools/testing.js';
import type { Tool } from './tools/tool.js';
import { Connection, createConnection as createConnectionImpl } from './connection.js';
import type { Config } from '../config.js';
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
const snapshotTools: Tool<any>[] = [
...common(true),
...console,
...dialogs(true),
...files(true),
...install,
...keyboard(true),
...navigate(true),
...network,
...pdf,
...snapshot,
...tabs(true),
...testing,
];
const screenshotTools: Tool<any>[] = [
...common(false),
...console,
...dialogs(false),
...files(false),
...install,
...keyboard(false),
...navigate(false),
...network,
...pdf,
...screen,
...tabs(false),
...testing,
];
import packageJSON from '../package.json' with { type: 'json' };
export async function createServer(config: Config = {}): Promise<Server> {
const allTools = config.vision ? screenshotTools : snapshotTools;
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
return createServerWithTools({
name: 'Playwright',
version: packageJSON.version,
tools,
}, config);
/**
 * Public entry point: creates a Playwright MCP connection.
 *
 * Delegates to the implementation exported by `./connection.js`, imported
 * under the alias `createConnectionImpl`. Calling `createConnection` here
 * would resolve to this very function and recurse without end.
 *
 * @param config Configuration for the connection; defaults to `{}`.
 * @returns The created connection.
 */
export async function createConnection(config: Config = {}): Promise<Connection> {
  return createConnectionImpl(config);
}

View File

@ -16,13 +16,11 @@
import { program } from 'commander';
import { createServer } from './index.js';
import { ServerList } from './server.js';
import { startHttpTransport, startStdioTransport } from './transport.js';
import { resolveConfig } from './config.js';
import type { Connection } from './connection.js';
import packageJSON from '../package.json' with { type: 'json' };
program
@ -40,23 +38,25 @@ program
.option('--allowed-origins <origins>', 'Semicolon-separated list of origins to allow the browser to request. Default is to allow all.', semicolonSeparatedList)
.option('--blocked-origins <origins>', 'Semicolon-separated list of origins to block the browser from requesting. Blocklist is evaluated before allowlist. If used without the allowlist, requests not matching the blocklist are still allowed.', semicolonSeparatedList)
.option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)')
.option('--no-image-responses', 'Do not send image responses to the client.')
.option('--output-dir <path>', 'Path to the directory for output files.')
.option('--config <path>', 'Path to the configuration file.')
.action(async options => {
const config = await resolveConfig(options);
const serverList = new ServerList(() => createServer(config));
setupExitWatchdog(serverList);
const connectionList: Connection[] = [];
setupExitWatchdog(connectionList);
if (options.port)
startHttpTransport(+options.port, options.host, serverList);
startHttpTransport(config, +options.port, options.host, connectionList);
else
await startStdioTransport(serverList);
await startStdioTransport(config, connectionList);
});
function setupExitWatchdog(serverList: ServerList) {
function setupExitWatchdog(connectionList: Connection[]) {
const handleExit = async () => {
setTimeout(() => process.exit(0), 15000);
await serverList.closeAll();
for (const connection of connectionList)
await connection.close();
process.exit(0);
};

61
src/tools.ts Normal file
View File

@ -0,0 +1,61 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import common from './tools/common.js';
import console from './tools/console.js';
import dialogs from './tools/dialogs.js';
import files from './tools/files.js';
import install from './tools/install.js';
import keyboard from './tools/keyboard.js';
import navigate from './tools/navigate.js';
import network from './tools/network.js';
import pdf from './tools/pdf.js';
import snapshot from './tools/snapshot.js';
import tabs from './tools/tabs.js';
import screen from './tools/screen.js';
import testing from './tools/testing.js';
import type { Tool } from './tools/tool.js';
/**
 * Tool list selected by `createConnection` when `config.vision` is not set.
 * Compared with `screenshotTools`, it contains the `snapshot` tools instead of
 * the `screen` tools and passes `true` to the tool factories.
 * NOTE(review): the boolean presumably switches each factory into snapshot
 * mode — confirm against the factories in ./tools/.
 */
export const snapshotTools: Tool<any>[] = [
...common(true),
...console,
...dialogs(true),
...files(true),
...install,
...keyboard(true),
...navigate(true),
...network,
...pdf,
...snapshot,
...tabs(true),
...testing,
];
/**
 * Tool list selected by `createConnection` when `config.vision` is set.
 * Compared with `snapshotTools`, it contains the `screen` tools instead of
 * the `snapshot` tools and passes `false` to the tool factories.
 * NOTE(review): the boolean presumably switches each factory out of snapshot
 * mode — confirm against the factories in ./tools/.
 */
export const screenshotTools: Tool<any>[] = [
...common(false),
...console,
...dialogs(false),
...files(false),
...install,
...keyboard(false),
...navigate(false),
...network,
...pdf,
...screen,
...tabs(false),
...testing,
];

View File

@ -258,7 +258,7 @@ const screenshot = defineTool({
else
code.push(`await page.screenshot(${javascript.formatObject(options)});`);
const includeBase64 = !context.config.tools?.browser_take_screenshot?.omitBase64;
const includeBase64 = !context.config.noImageResponses;
const action = async () => {
const screenshot = locator ? await locator.screenshot(options) : await tab.page.screenshot(options);
return {

View File

@ -18,17 +18,22 @@ import http from 'node:http';
import assert from 'node:assert';
import crypto from 'node:crypto';
import { ServerList } from './server.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
export async function startStdioTransport(serverList: ServerList) {
const server = await serverList.create();
await server.connect(new StdioServerTransport());
import { createConnection } from './connection.js';
import type { Config } from '../config.js';
import type { Connection } from './connection.js';
/**
 * Creates a connection for `config` and serves it over stdio.
 *
 * @param config Configuration used to create the connection.
 * @param connectionList Shared list of live connections; the new connection
 *   is appended to it.
 */
export async function startStdioTransport(config: Config, connectionList: Connection[]) {
  const connection = await createConnection(config);
  // Register before connecting (as handleStreamable does): `connect()` only
  // resolves after the client has initialized, and the connection must be
  // reachable through the list during that wait so it can still be closed.
  connectionList.push(connection);
  await connection.connect(new StdioServerTransport());
}
async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, url: URL, serverList: ServerList, sessions: Map<string, SSEServerTransport>) {
async function handleSSE(config: Config, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>, connectionList: Connection[]) {
if (req.method === 'POST') {
const sessionId = url.searchParams.get('sessionId');
if (!sessionId) {
@ -46,22 +51,24 @@ async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, ur
} else if (req.method === 'GET') {
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
const server = await serverList.create();
const connection = await createConnection(config);
await connection.connect(transport);
connectionList.push(connection);
res.on('close', () => {
sessions.delete(transport.sessionId);
serverList.close(server).catch(e => {
connection.close().catch(e => {
// eslint-disable-next-line no-console
console.error(e);
});
});
return await server.connect(transport);
return;
}
res.statusCode = 405;
res.end('Method not allowed');
}
async function handleStreamable(req: http.IncomingMessage, res: http.ServerResponse, serverList: ServerList, sessions: Map<string, StreamableHTTPServerTransport>) {
async function handleStreamable(config: Config, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>, connectionList: Connection[]) {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) {
const transport = sessions.get(sessionId);
@ -84,24 +91,28 @@ async function handleStreamable(req: http.IncomingMessage, res: http.ServerRespo
if (transport.sessionId)
sessions.delete(transport.sessionId);
};
const server = await serverList.create();
await server.connect(transport);
return await transport.handleRequest(req, res);
const connection = await createConnection(config);
connectionList.push(connection);
await Promise.all([
connection.connect(transport),
transport.handleRequest(req, res),
]);
return;
}
res.statusCode = 400;
res.end('Invalid request');
}
export function startHttpTransport(port: number, hostname: string | undefined, serverList: ServerList) {
export function startHttpTransport(config: Config, port: number, hostname: string | undefined, connectionList: Connection[]) {
const sseSessions = new Map<string, SSEServerTransport>();
const streamableSessions = new Map<string, StreamableHTTPServerTransport>();
const httpServer = http.createServer(async (req, res) => {
const url = new URL(`http://localhost${req.url}`);
if (url.pathname.startsWith('/mcp'))
await handleStreamable(req, res, serverList, streamableSessions);
await handleStreamable(config, req, res, streamableSessions, connectionList);
else
await handleSSE(req, res, url, serverList, sseSessions);
await handleSSE(config, req, res, url, sseSessions, connectionList);
});
httpServer.listen(port, hostname, () => {
const address = httpServer.address();

View File

@ -34,7 +34,7 @@ export type TestOptions = {
type TestFixtures = {
client: Client;
visionClient: Client;
startClient: (options?: { args?: string[], config?: Config }) => Promise<Client>;
startClient: (options?: { clientName?: string, args?: string[], config?: Config }) => Promise<Client>;
wsEndpoint: string;
cdpEndpoint: (port?: number) => Promise<string>;
server: TestServer;
@ -79,7 +79,7 @@ export const test = baseTest.extend<TestFixtures & TestOptions, WorkerFixtures>(
command: 'node',
args: [path.join(path.dirname(__filename), '../cli.js'), ...args],
});
client = new Client({ name: 'test', version: '1.0.0' });
client = new Client({ name: options?.clientName ?? 'test', version: '1.0.0' });
await client.connect(transport);
await client.ping();
return client;

View File

@ -116,14 +116,10 @@ test('browser_take_screenshot (outputDir)', async ({ startClient }, testInfo) =>
expect([...fs.readdirSync(outputDir)]).toHaveLength(1);
});
test('browser_take_screenshot (omitBase64)', async ({ startClient }) => {
test('browser_take_screenshot (noImageResponses)', async ({ startClient }) => {
const client = await startClient({
config: {
tools: {
browser_take_screenshot: {
omitBase64: true,
},
},
noImageResponses: true,
},
});
@ -151,3 +147,31 @@ test('browser_take_screenshot (omitBase64)', async ({ startClient }) => {
],
});
});
// A client that reports a name containing 'cursor' must receive text-only
// screenshot results, without any explicit configuration.
test('browser_take_screenshot (cursor)', async ({ startClient }) => {
  const client = await startClient({ clientName: 'cursor:vscode' });

  const navigationResult = await client.callTool({
    name: 'browser_navigate',
    arguments: {
      url: 'data:text/html,<html><title>Title</title><body>Hello, world!</body></html>',
    },
  });
  expect(navigationResult).toContainTextContent(`Navigate to data:text/html`);

  // The result of the first screenshot call is not inspected.
  await client.callTool({
    name: 'browser_take_screenshot',
    arguments: {},
  });

  const screenshotResult = await client.callTool({
    name: 'browser_take_screenshot',
    arguments: {},
  });
  const textOnlyContent = [
    {
      text: expect.stringContaining(`Screenshot viewport and save it as`),
      type: 'text',
    },
  ];
  expect(screenshotResult).toEqual({ content: textOnlyContent });
});