From e95b5b1dd6ef940f4233acfa38d3229bdd613e74 Mon Sep 17 00:00:00 2001 From: Pavel Feldman Date: Tue, 6 May 2025 14:27:28 -0700 Subject: [PATCH] chore: get rid of connection factory (#362) Drive-by User-Agent sniffing and disabling of image type in Cursor. --- README.md | 17 ++++----- config.d.ts | 16 ++------- index.d.ts | 10 ++++-- index.js | 4 +-- src/config.ts | 1 + src/{server.ts => connection.ts} | 61 ++++++++++++++------------------ src/index.ts | 60 +++---------------------------- src/program.ts | 20 +++++------ src/tools.ts | 61 ++++++++++++++++++++++++++++++++ src/tools/snapshot.ts | 2 +- src/transport.ts | 43 +++++++++++++--------- tests/fixtures.ts | 4 +-- tests/screenshot.spec.ts | 36 +++++++++++++++---- 13 files changed, 181 insertions(+), 154 deletions(-) rename src/{server.ts => connection.ts} (67%) create mode 100644 src/tools.ts diff --git a/README.md b/README.md index 3f281f8..cf57dd7 100644 --- a/README.md +++ b/README.md @@ -163,14 +163,11 @@ The Playwright MCP server can be configured using a JSON configuration file. Her // List of origins to block the browser to request. Origins matching both `allowedOrigins` and `blockedOrigins` will be blocked. blockedOrigins?: string[]; }; - - // Tool-specific configurations - tools?: { - browser_take_screenshot?: { - // Disable base64-encoded image responses - omitBase64?: boolean; - } - } + + /** + * Do not send image responses to the client. + */ + noImageResponses?: boolean; } ``` @@ -234,9 +231,9 @@ http.createServer(async (req, res) => { // ... // Creates a headless Playwright MCP server with SSE transport - const mcpServer = await createServer({ headless: true }); + const connection = await createConnection({ headless: true }); const transport = new SSEServerTransport('/messages', res); - await mcpServer.connect(transport); + await connection.connect(transport); // ... }); diff --git a/config.d.ts b/config.d.ts index 053c969..6b01c22 100644 --- a/config.d.ts +++ b/config.d.ts @@ -107,19 +107,7 @@ export type Config = { }; /** - * Configuration for specific tools. + * Do not send image responses to the client. */ - tools?: { - /** - * Configuration for the browser_take_screenshot tool. - */ - browser_take_screenshot?: { - - /** - * Whether to disable base64-encoded image responses to the clients that - * don't support binary data or prefer to save on tokens. - */ - omitBase64?: boolean; - } - } + noImageResponses?: boolean; }; diff --git a/index.d.ts b/index.d.ts index 900c478..f66ca71 100644 --- a/index.d.ts +++ b/index.d.ts @@ -16,8 +16,14 @@ */ import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; - import type { Config } from './config'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; -export declare function createServer(config?: Config): Promise; +export type Connection = { + server: Server; + connect(transport: Transport): Promise; + close(): Promise; +}; + +export declare function createConnection(config?: Config): Promise; export {}; diff --git a/index.js b/index.js index 61f7c07..d4b94fb 100755 --- a/index.js +++ b/index.js @@ -15,5 +15,5 @@ * limitations under the License. */ -import { createServer } from './lib/index'; -export default { createServer }; +import { createConnection } from './lib/index'; +export default { createConnection }; diff --git a/src/config.ts b/src/config.ts index 99820d6..5fc7643 100644 --- a/src/config.ts +++ b/src/config.ts @@ -39,6 +39,7 @@ export type CLIOptions = { allowedOrigins?: string[]; blockedOrigins?: string[]; outputDir?: string; + noImageResponses?: boolean; }; const defaultConfig: Config = { diff --git a/src/server.ts b/src/connection.ts similarity index 67% rename from src/server.ts rename to src/connection.ts index 405af78..95cf039 100644 --- a/src/server.ts +++ b/src/connection.ts @@ -19,20 +19,19 @@ import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from ' import { zodToJsonSchema } from 'zod-to-json-schema'; import { Context } from './context.js'; +import { snapshotTools, screenshotTools } from './tools.js'; -import type { Tool } from './tools/tool.js'; import type { Config } from '../config.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; -type MCPServerOptions = { - name: string; - version: string; - tools: Tool[]; -}; +import packageJSON from '../package.json' with { type: 'json' }; + +export async function createConnection(config: Config): Promise { + const allTools = config.vision ? screenshotTools : snapshotTools; + const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability)); -export function createServerWithTools(serverOptions: MCPServerOptions, config: Config): Server { - const { name, version, tools } = serverOptions; const context = new Context(tools, config); - const server = new Server({ name, version }, { + const server = new Server({ name: 'Playwright', version: packageJSON.version }, { capabilities: { tools: {}, } @@ -77,38 +76,30 @@ export function createServerWithTools(serverOptions: MCPServerOptions, config: C } }); - const oldClose = server.close.bind(server); - - server.close = async () => { - await oldClose(); - await context.close(); - }; - - return server; + const connection = new Connection(server, context); + return connection; } -export class ServerList { - private _servers: Server[] = []; - private _serverFactory: () => Promise; +export class Connection { + readonly server: Server; + readonly context: Context; - constructor(serverFactory: () => Promise) { - this._serverFactory = serverFactory; + constructor(server: Server, context: Context) { + this.server = server; + this.context = context; } - async create() { - const server = await this._serverFactory(); - this._servers.push(server); - return server; + async connect(transport: Transport) { + await this.server.connect(transport); + await new Promise(resolve => { + this.server.oninitialized = () => resolve(); + }); + if (this.server.getClientVersion()?.name.includes('cursor')) + this.context.config.noImageResponses = true; } - async close(server: Server) { - const index = this._servers.indexOf(server); - if (index !== -1) - this._servers.splice(index, 1); - await server.close(); - } - - async closeAll() { - await Promise.all(this._servers.map(server => server.close())); + async close() { + await this.server.close(); + await this.context.close(); } } diff --git a/src/index.ts b/src/index.ts index 25ff2d4..9fc7633 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,62 +14,10 @@ * limitations under the License. */ -import { createServerWithTools } from './server.js'; -import common from './tools/common.js'; -import console from './tools/console.js'; -import dialogs from './tools/dialogs.js'; -import files from './tools/files.js'; -import install from './tools/install.js'; -import keyboard from './tools/keyboard.js'; -import navigate from './tools/navigate.js'; -import network from './tools/network.js'; -import pdf from './tools/pdf.js'; -import snapshot from './tools/snapshot.js'; -import tabs from './tools/tabs.js'; -import screen from './tools/screen.js'; -import testing from './tools/testing.js'; -import type { Tool } from './tools/tool.js'; +import { Connection } from './connection.js'; + import type { Config } from '../config.js'; -import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; -const snapshotTools: Tool[] = [ - ...common(true), - ...console, - ...dialogs(true), - ...files(true), - ...install, - ...keyboard(true), - ...navigate(true), - ...network, - ...pdf, - ...snapshot, - ...tabs(true), - ...testing, -]; - -const screenshotTools: Tool[] = [ - ...common(false), - ...console, - ...dialogs(false), - ...files(false), - ...install, - ...keyboard(false), - ...navigate(false), - ...network, - ...pdf, - ...screen, - ...tabs(false), - ...testing, -]; - -import packageJSON from '../package.json' with { type: 'json' }; - -export async function createServer(config: Config = {}): Promise { - const allTools = config.vision ? screenshotTools : snapshotTools; - const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability)); - return createServerWithTools({ - name: 'Playwright', - version: packageJSON.version, - tools, - }, config); +export async function createConnection(config: Config = {}): Promise { + return createConnection(config); } diff --git a/src/program.ts b/src/program.ts index 1753490..5d4403b 100644 --- a/src/program.ts +++ b/src/program.ts @@ -16,13 +16,11 @@ import { program } from 'commander'; -import { createServer } from './index.js'; -import { ServerList } from './server.js'; - import { startHttpTransport, startStdioTransport } from './transport.js'; - import { resolveConfig } from './config.js'; +import type { Connection } from './connection.js'; + import packageJSON from '../package.json' with { type: 'json' }; program @@ -40,23 +38,25 @@ program .option('--allowed-origins ', 'Semicolon-separated list of origins to allow the browser to request. Default is to allow all.', semicolonSeparatedList) .option('--blocked-origins ', 'Semicolon-separated list of origins to block the browser from requesting. Blocklist is evaluated before allowlist. If used without the allowlist, requests not matching the blocklist are still allowed.', semicolonSeparatedList) .option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)') + .option('--no-image-responses', 'Do not send image responses to the client.') .option('--output-dir ', 'Path to the directory for output files.') .option('--config ', 'Path to the configuration file.') .action(async options => { const config = await resolveConfig(options); - const serverList = new ServerList(() => createServer(config)); - setupExitWatchdog(serverList); + const connectionList: Connection[] = []; + setupExitWatchdog(connectionList); if (options.port) - startHttpTransport(+options.port, options.host, serverList); + startHttpTransport(config, +options.port, options.host, connectionList); else - await startStdioTransport(serverList); + await startStdioTransport(config, connectionList); }); -function setupExitWatchdog(serverList: ServerList) { +function setupExitWatchdog(connectionList: Connection[]) { const handleExit = async () => { setTimeout(() => process.exit(0), 15000); - await serverList.closeAll(); + for (const connection of connectionList) + await connection.close(); process.exit(0); }; diff --git a/src/tools.ts b/src/tools.ts new file mode 100644 index 0000000..8613d92 --- /dev/null +++ b/src/tools.ts @@ -0,0 +1,61 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import common from './tools/common.js'; +import console from './tools/console.js'; +import dialogs from './tools/dialogs.js'; +import files from './tools/files.js'; +import install from './tools/install.js'; +import keyboard from './tools/keyboard.js'; +import navigate from './tools/navigate.js'; +import network from './tools/network.js'; +import pdf from './tools/pdf.js'; +import snapshot from './tools/snapshot.js'; +import tabs from './tools/tabs.js'; +import screen from './tools/screen.js'; +import testing from './tools/testing.js'; + +import type { Tool } from './tools/tool.js'; + +export const snapshotTools: Tool[] = [ + ...common(true), + ...console, + ...dialogs(true), + ...files(true), + ...install, + ...keyboard(true), + ...navigate(true), + ...network, + ...pdf, + ...snapshot, + ...tabs(true), + ...testing, +]; + +export const screenshotTools: Tool[] = [ + ...common(false), + ...console, + ...dialogs(false), + ...files(false), + ...install, + ...keyboard(false), + ...navigate(false), + ...network, + ...pdf, + ...screen, + ...tabs(false), + ...testing, +]; diff --git a/src/tools/snapshot.ts b/src/tools/snapshot.ts index b969b88..410d2f9 100644 --- a/src/tools/snapshot.ts +++ b/src/tools/snapshot.ts @@ -258,7 +258,7 @@ const screenshot = defineTool({ else code.push(`await page.screenshot(${javascript.formatObject(options)});`); - const includeBase64 = !context.config.tools?.browser_take_screenshot?.omitBase64; + const includeBase64 = !context.config.noImageResponses; const action = async () => { const screenshot = locator ? await locator.screenshot(options) : await tab.page.screenshot(options); return { diff --git a/src/transport.ts b/src/transport.ts index 9038972..8d38fc3 100644 --- a/src/transport.ts +++ b/src/transport.ts @@ -18,17 +18,22 @@ import http from 'node:http'; import assert from 'node:assert'; import crypto from 'node:crypto'; -import { ServerList } from './server.js'; -import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; -export async function startStdioTransport(serverList: ServerList) { - const server = await serverList.create(); - await server.connect(new StdioServerTransport()); +import { createConnection } from './connection.js'; + +import type { Config } from '../config.js'; +import type { Connection } from './connection.js'; + +export async function startStdioTransport(config: Config, connectionList: Connection[]) { + const connection = await createConnection(config); + await connection.connect(new StdioServerTransport()); + connectionList.push(connection); } -async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, url: URL, serverList: ServerList, sessions: Map) { +async function handleSSE(config: Config, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map, connectionList: Connection[]) { if (req.method === 'POST') { const sessionId = url.searchParams.get('sessionId'); if (!sessionId) { @@ -46,22 +51,24 @@ async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, ur } else if (req.method === 'GET') { const transport = new SSEServerTransport('/sse', res); sessions.set(transport.sessionId, transport); - const server = await serverList.create(); + const connection = await createConnection(config); + await connection.connect(transport); + connectionList.push(connection); res.on('close', () => { sessions.delete(transport.sessionId); - serverList.close(server).catch(e => { + connection.close().catch(e => { // eslint-disable-next-line no-console console.error(e); }); }); - return await server.connect(transport); + return; } res.statusCode = 405; res.end('Method not allowed'); } -async function handleStreamable(req: http.IncomingMessage, res: http.ServerResponse, serverList: ServerList, sessions: Map) { +async function handleStreamable(config: Config, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map, connectionList: Connection[]) { const sessionId = req.headers['mcp-session-id'] as string | undefined; if (sessionId) { const transport = sessions.get(sessionId); @@ -84,24 +91,28 @@ async function handleStreamable(req: http.IncomingMessage, res: http.ServerRespo if (transport.sessionId) sessions.delete(transport.sessionId); }; - const server = await serverList.create(); - await server.connect(transport); - return await transport.handleRequest(req, res); + const connection = await createConnection(config); + connectionList.push(connection); + await Promise.all([ + connection.connect(transport), + transport.handleRequest(req, res), + ]); + return; } res.statusCode = 400; res.end('Invalid request'); } -export function startHttpTransport(port: number, hostname: string | undefined, serverList: ServerList) { +export function startHttpTransport(config: Config, port: number, hostname: string | undefined, connectionList: Connection[]) { const sseSessions = new Map(); const streamableSessions = new Map(); const httpServer = http.createServer(async (req, res) => { const url = new URL(`http://localhost${req.url}`); if (url.pathname.startsWith('/mcp')) - await handleStreamable(req, res, serverList, streamableSessions); + await handleStreamable(config, req, res, streamableSessions, connectionList); else - await handleSSE(req, res, url, serverList, sseSessions); + await handleSSE(config, req, res, url, sseSessions, connectionList); }); httpServer.listen(port, hostname, () => { const address = httpServer.address(); diff --git a/tests/fixtures.ts b/tests/fixtures.ts index 7b7dfee..1697147 100644 --- a/tests/fixtures.ts +++ b/tests/fixtures.ts @@ -34,7 +34,7 @@ export type TestOptions = { type TestFixtures = { client: Client; visionClient: Client; - startClient: (options?: { args?: string[], config?: Config }) => Promise; + startClient: (options?: { clientName?: string, args?: string[], config?: Config }) => Promise; wsEndpoint: string; cdpEndpoint: (port?: number) => Promise; server: TestServer; @@ -79,7 +79,7 @@ export const test = baseTest.extend( command: 'node', args: [path.join(path.dirname(__filename), '../cli.js'), ...args], }); - client = new Client({ name: 'test', version: '1.0.0' }); + client = new Client({ name: options?.clientName ?? 'test', version: '1.0.0' }); await client.connect(transport); await client.ping(); return client; diff --git a/tests/screenshot.spec.ts b/tests/screenshot.spec.ts index b769a02..5f99b79 100644 --- a/tests/screenshot.spec.ts +++ b/tests/screenshot.spec.ts @@ -116,14 +116,10 @@ test('browser_take_screenshot (outputDir)', async ({ startClient }, testInfo) => expect([...fs.readdirSync(outputDir)]).toHaveLength(1); }); -test('browser_take_screenshot (omitBase64)', async ({ startClient }) => { +test('browser_take_screenshot (noImageResponses)', async ({ startClient }) => { const client = await startClient({ config: { - tools: { - browser_take_screenshot: { - omitBase64: true, - }, - }, + noImageResponses: true, }, }); @@ -151,3 +147,31 @@ test('browser_take_screenshot (omitBase64)', async ({ startClient }) => { ], }); }); + +test('browser_take_screenshot (cursor)', async ({ startClient }) => { + const client = await startClient({ clientName: 'cursor:vscode' }); + + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { + url: 'data:text/html,TitleHello, world!', + }, + })).toContainTextContent(`Navigate to data:text/html`); + + await client.callTool({ + name: 'browser_take_screenshot', + arguments: {}, + }); + + expect(await client.callTool({ + name: 'browser_take_screenshot', + arguments: {}, + })).toEqual({ + content: [ + { + text: expect.stringContaining(`Screenshot viewport and save it as`), + type: 'text', + }, + ], + }); +});