From c63b7823e1eba8d8254ce7b2ba843bba79731349 Mon Sep 17 00:00:00 2001 From: Pavel Feldman Date: Thu, 24 Jul 2025 12:57:01 -0700 Subject: [PATCH] chore: extract pure mcp server helpers (#751) --- src/browserServerBackend.ts | 66 ++++++++++ src/connection.ts | 84 ------------- src/extension/cdpRelay.ts | 3 +- src/extension/main.ts | 16 +-- src/httpServer.ts | 232 ++++-------------------------------- src/index.ts | 7 +- src/loop/onetool.ts | 96 +++++++-------- src/mcp/README.md | 1 + src/mcp/server.ts | 105 ++++++++++++++++ src/{ => mcp}/transport.ts | 63 ++++------ src/program.ts | 37 ++++-- src/server.ts | 59 --------- 12 files changed, 300 insertions(+), 469 deletions(-) create mode 100644 src/browserServerBackend.ts delete mode 100644 src/connection.ts create mode 100644 src/mcp/README.md create mode 100644 src/mcp/server.ts rename src/{ => mcp}/transport.ts (65%) delete mode 100644 src/server.ts diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts new file mode 100644 index 0000000..a5ab434 --- /dev/null +++ b/src/browserServerBackend.ts @@ -0,0 +1,66 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { FullConfig } from './config.js'; +import { Context } from './context.js'; +import { logUnhandledError } from './log.js'; +import { Response } from './response.js'; +import { SessionLog } from './sessionLog.js'; +import { filteredTools } from './tools.js'; +import { packageJSON } from './package.js'; + +import type { BrowserContextFactory } from './browserContextFactory.js'; +import type * as mcpServer from './mcp/server.js'; +import type { ServerBackend } from './mcp/server.js'; +import type { Tool } from './tools/tool.js'; + +export class BrowserServerBackend implements ServerBackend { + name = 'Playwright'; + version = packageJSON.version; + private _tools: Tool[]; + private _context: Context; + private _sessionLog: SessionLog | undefined; + + constructor(config: FullConfig, browserContextFactory: BrowserContextFactory) { + this._tools = filteredTools(config); + this._context = new Context(this._tools, config, browserContextFactory); + } + + async initialize() { + this._sessionLog = this._context.config.saveSession ? await SessionLog.create(this._context.config) : undefined; + } + + tools(): mcpServer.ToolSchema[] { + return this._tools.map(tool => tool.schema); + } + + async callTool(schema: mcpServer.ToolSchema, parsedArguments: any) { + const response = new Response(this._context, schema.name, parsedArguments); + const tool = this._tools.find(tool => tool.schema.name === schema.name)!; + await tool.handle(this._context, parsedArguments, response); + if (this._sessionLog) + await this._sessionLog.log(response); + return await response.serialize(); + } + + serverInitialized(version: mcpServer.ClientVersion | undefined) { + this._context.clientVersion = version; + } + + serverClosed() { + void this._context.dispose().catch(logUnhandledError); + } +} diff --git a/src/connection.ts b/src/connection.ts deleted file mode 100644 index 1f5ade1..0000000 --- a/src/connection.ts +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright (c) Microsoft Corporation. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Server } from '@modelcontextprotocol/sdk/server/index.js'; -import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js'; -import { zodToJsonSchema } from 'zod-to-json-schema'; -import { Context } from './context.js'; -import { Response } from './response.js'; -import { packageJSON } from './package.js'; -import { FullConfig } from './config.js'; -import { SessionLog } from './sessionLog.js'; -import { logUnhandledError } from './log.js'; -import type { BrowserContextFactory } from './browserContextFactory.js'; -import type { Tool } from './tools/tool.js'; - -export async function createMCPServer(config: FullConfig, tools: Tool[], browserContextFactory: BrowserContextFactory): Promise { - const context = new Context(tools, config, browserContextFactory); - const server = new Server({ name: 'Playwright', version: packageJSON.version }, { - capabilities: { - tools: {}, - } - }); - - const sessionLog = config.saveSession ? await SessionLog.create(config) : undefined; - - server.setRequestHandler(ListToolsRequestSchema, async () => { - return { - tools: tools.map(tool => ({ - name: tool.schema.name, - description: tool.schema.description, - inputSchema: zodToJsonSchema(tool.schema.inputSchema), - annotations: { - title: tool.schema.title, - readOnlyHint: tool.schema.type === 'readOnly', - destructiveHint: tool.schema.type === 'destructive', - openWorldHint: true, - }, - })) as McpTool[], - }; - }); - - server.setRequestHandler(CallToolRequestSchema, async request => { - const errorResult = (...messages: string[]) => ({ - content: [{ type: 'text', text: messages.join('\n') }], - isError: true, - }); - const tool = tools.find(tool => tool.schema.name === request.params.name); - if (!tool) - return errorResult(`Tool "${request.params.name}" not found`); - - try { - const response = new Response(context, request.params.name, request.params.arguments || {}); - await tool.handle(context, tool.schema.inputSchema.parse(request.params.arguments || {}), response); - if (sessionLog) - await sessionLog.log(response); - return await response.serialize(); - } catch (error) { - return errorResult(String(error)); - } - }); - - server.oninitialized = () => { - context.clientVersion = server.getClientVersion(); - }; - - server.onclose = () => { - void context.dispose().catch(logUnhandledError); - }; - - return server; -} diff --git a/src/extension/cdpRelay.ts b/src/extension/cdpRelay.ts index d14fdda..022cba1 100644 --- a/src/extension/cdpRelay.ts +++ b/src/extension/cdpRelay.ts @@ -27,10 +27,9 @@ import { spawn } from 'child_process'; import { WebSocket, WebSocketServer } from 'ws'; import debug from 'debug'; import * as playwright from 'playwright'; -import { httpAddressToString, startHttpServer } from '../transport.js'; // @ts-ignore const { registry } = await import('playwright-core/lib/server/registry/index'); - +import { httpAddressToString, startHttpServer } from '../httpServer.js'; import type { BrowserContextFactory } from '../browserContextFactory.js'; import type websocket from 'ws'; diff --git a/src/extension/main.ts b/src/extension/main.ts index c3fc6f6..cfabbad 100644 --- a/src/extension/main.ts +++ b/src/extension/main.ts @@ -14,22 +14,14 @@ * limitations under the License. */ -import { startHttpServer, startHttpTransport, startStdioTransport } from '../transport.js'; -import { Server } from '../server.js'; import { startCDPRelayServer } from './cdpRelay.js'; -import { filteredTools } from '../tools.js'; +import { BrowserServerBackend } from '../browserServerBackend.js'; +import * as mcpTransport from '../mcp/transport.js'; import type { FullConfig } from '../config.js'; export async function runWithExtension(config: FullConfig) { const contextFactory = await startCDPRelayServer(config.browser.launchOptions.channel || 'chrome'); - const server = new Server(config, filteredTools(config), contextFactory); - server.setupExitWatchdog(); - - if (config.server.port !== undefined) { - const httpServer = await startHttpServer(config.server); - startHttpTransport(httpServer, server); - } else { - await startStdioTransport(server); - } + const serverBackendFactory = () => new BrowserServerBackend(config, contextFactory); + await mcpTransport.start(serverBackendFactory, config.server); } diff --git a/src/httpServer.ts b/src/httpServer.ts index 9e67bef..3102bd5 100644 --- a/src/httpServer.ts +++ b/src/httpServer.ts @@ -14,219 +14,31 @@ * limitations under the License. */ -import fs from 'fs'; -import path from 'path'; +import assert from 'assert'; import http from 'http'; -import net from 'net'; -import mime from 'mime'; +import type * as net from 'net'; -import { ManualPromise } from './manualPromise.js'; - - -export type ServerRouteHandler = (request: http.IncomingMessage, response: http.ServerResponse) => void; - -export type Transport = { - sendEvent?: (method: string, params: any) => void; - close?: () => void; - onconnect: () => void; - dispatch: (method: string, params: any) => Promise; - onclose: () => void; -}; - -export class HttpServer { - private _server: http.Server; - private _urlPrefixPrecise: string = ''; - private _urlPrefixHumanReadable: string = ''; - private _port: number = 0; - private _routes: { prefix?: string, exact?: string, handler: ServerRouteHandler }[] = []; - - constructor() { - this._server = http.createServer(this._onRequest.bind(this)); - decorateServer(this._server); - } - - server() { - return this._server; - } - - routePrefix(prefix: string, handler: ServerRouteHandler) { - this._routes.push({ prefix, handler }); - } - - routePath(path: string, handler: ServerRouteHandler) { - this._routes.push({ exact: path, handler }); - } - - port(): number { - return this._port; - } - - private async _tryStart(port: number | undefined, host: string) { - const errorPromise = new ManualPromise(); - const errorListener = (error: Error) => errorPromise.reject(error); - this._server.on('error', errorListener); - - try { - this._server.listen(port, host); - await Promise.race([ - new Promise(cb => this._server!.once('listening', cb)), - errorPromise, - ]); - } finally { - this._server.removeListener('error', errorListener); - } - } - - async start(options: { port?: number, preferredPort?: number, host?: string } = {}): Promise { - const host = options.host || 'localhost'; - if (options.preferredPort) { - try { - await this._tryStart(options.preferredPort, host); - } catch (e: any) { - if (!e || !e.message || !e.message.includes('EADDRINUSE')) - throw e; - await this._tryStart(undefined, host); - } - } else { - await this._tryStart(options.port, host); - } - - const address = this._server.address(); - if (typeof address === 'string') { - this._urlPrefixPrecise = address; - this._urlPrefixHumanReadable = address; - } else { - this._port = address!.port; - const resolvedHost = address!.family === 'IPv4' ? address!.address : `[${address!.address}]`; - this._urlPrefixPrecise = `http://${resolvedHost}:${address!.port}`; - this._urlPrefixHumanReadable = `http://${host}:${address!.port}`; - } - } - - async stop() { - await new Promise(cb => this._server!.close(cb)); - } - - urlPrefix(purpose: 'human-readable' | 'precise'): string { - return purpose === 'human-readable' ? this._urlPrefixHumanReadable : this._urlPrefixPrecise; - } - - serveFile(request: http.IncomingMessage, response: http.ServerResponse, absoluteFilePath: string, headers?: { [name: string]: string }): boolean { - try { - for (const [name, value] of Object.entries(headers || {})) - response.setHeader(name, value); - if (request.headers.range) - this._serveRangeFile(request, response, absoluteFilePath); - else - this._serveFile(response, absoluteFilePath); - return true; - } catch (e) { - return false; - } - } - - _serveFile(response: http.ServerResponse, absoluteFilePath: string) { - const content = fs.readFileSync(absoluteFilePath); - response.statusCode = 200; - const contentType = mime.getType(path.extname(absoluteFilePath)) || 'application/octet-stream'; - response.setHeader('Content-Type', contentType); - response.setHeader('Content-Length', content.byteLength); - response.end(content); - } - - _serveRangeFile(request: http.IncomingMessage, response: http.ServerResponse, absoluteFilePath: string) { - const range = request.headers.range; - if (!range || !range.startsWith('bytes=') || range.includes(', ') || [...range].filter(char => char === '-').length !== 1) { - response.statusCode = 400; - return response.end('Bad request'); - } - - // Parse the range header: https://datatracker.ietf.org/doc/html/rfc7233#section-2.1 - const [startStr, endStr] = range.replace(/bytes=/, '').split('-'); - - // Both start and end (when passing to fs.createReadStream) and the range header are inclusive and start counting at 0. - let start: number; - let end: number; - const size = fs.statSync(absoluteFilePath).size; - if (startStr !== '' && endStr === '') { - // No end specified: use the whole file - start = +startStr; - end = size - 1; - } else if (startStr === '' && endStr !== '') { - // No start specified: calculate start manually - start = size - +endStr; - end = size - 1; - } else { - start = +startStr; - end = +endStr; - } - - // Handle unavailable range request - if (Number.isNaN(start) || Number.isNaN(end) || start >= size || end >= size || start > end) { - // Return the 416 Range Not Satisfiable: https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 - response.writeHead(416, { - 'Content-Range': `bytes */${size}` - }); - return response.end(); - } - - // Sending Partial Content: https://datatracker.ietf.org/doc/html/rfc7233#section-4.1 - response.writeHead(206, { - 'Content-Range': `bytes ${start}-${end}/${size}`, - 'Accept-Ranges': 'bytes', - 'Content-Length': end - start + 1, - 'Content-Type': mime.getType(path.extname(absoluteFilePath))!, +export async function startHttpServer(config: { host?: string, port?: number }): Promise { + const { host, port } = config; + const httpServer = http.createServer(); + await new Promise((resolve, reject) => { + httpServer.on('error', reject); + httpServer.listen(port, host, () => { + resolve(); + httpServer.removeListener('error', reject); }); - - const readable = fs.createReadStream(absoluteFilePath, { start, end }); - readable.pipe(response); - } - - private _onRequest(request: http.IncomingMessage, response: http.ServerResponse) { - if (request.method === 'OPTIONS') { - response.writeHead(200); - response.end(); - return; - } - - request.on('error', () => response.end()); - try { - if (!request.url) { - response.end(); - return; - } - const url = new URL('http://localhost' + request.url); - for (const route of this._routes) { - if (route.exact && url.pathname === route.exact) { - route.handler(request, response); - return; - } - if (route.prefix && url.pathname.startsWith(route.prefix)) { - route.handler(request, response); - return; - } - } - response.statusCode = 404; - response.end(); - } catch (e) { - response.end(); - } - } -} - -function decorateServer(server: net.Server) { - const sockets = new Set(); - server.on('connection', socket => { - sockets.add(socket); - socket.once('close', () => sockets.delete(socket)); }); - - const close = server.close; - server.close = (callback?: (err?: Error) => void) => { - for (const socket of sockets) - socket.destroy(); - sockets.clear(); - return close.call(server, callback); - }; + return httpServer; +} + +export function httpAddressToString(address: string | net.AddressInfo | null): string { + assert(address, 'Could not bind server socket'); + if (typeof address === 'string') + return address; + const resolvedPort = address.port; + let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; + if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') + resolvedHost = 'localhost'; + return `http://${resolvedHost}:${resolvedPort}`; } diff --git a/src/index.ts b/src/index.ts index 2434437..e5d02b6 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,10 +14,11 @@ * limitations under the License. */ -import { createMCPServer } from './connection.js'; +import { BrowserServerBackend } from './browserServerBackend.js'; import { resolveConfig } from './config.js'; import { contextFactory } from './browserContextFactory.js'; -import { filteredTools } from './tools.js'; +import * as mcpServer from './mcp/server.js'; + import type { Config } from '../config.js'; import type { BrowserContext } from 'playwright'; import type { BrowserContextFactory } from './browserContextFactory.js'; @@ -26,7 +27,7 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise): Promise { const config = await resolveConfig(userConfig); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser); - return createMCPServer(config, filteredTools(config), factory); + return mcpServer.createServer(new BrowserServerBackend(config, factory)); } class SimpleBrowserContextFactory implements BrowserContextFactory { diff --git a/src/loop/onetool.ts b/src/loop/onetool.ts index 933b180..9128e39 100644 --- a/src/loop/onetool.ts +++ b/src/loop/onetool.ts @@ -21,64 +21,64 @@ import { z } from 'zod'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; -import { FullConfig } from '../config.js'; -import { defineTool } from '../tools/tool.js'; -import { Server } from '../server.js'; -import { startHttpServer, startHttpTransport, startStdioTransport } from '../transport.js'; import { OpenAIDelegate } from './loopOpenAI.js'; import { runTask } from './loop.js'; +import { packageJSON } from '../package.js'; +import * as mcpTransport from '../mcp/transport.js'; -dotenv.config(); +import type { FullConfig } from '../config.js'; +import type { ServerBackend } from '../mcp/server.js'; +import type * as mcpServer from '../mcp/server.js'; const __filename = url.fileURLToPath(import.meta.url); -let innerClient: Client | undefined; const delegate = new OpenAIDelegate(); -const oneTool = defineTool({ - capability: 'core', - - schema: { - name: 'browser', - title: 'Perform a task with the browser', - description: 'Perform a task with the browser. It can click, type, export, capture screenshot, drag, hover, select options, etc.', - inputSchema: z.object({ - task: z.string().describe('The task to perform with the browser'), - }), - type: 'readOnly', - }, - - handle: async (context, params, response) => { - const result = await runTask(delegate!, innerClient!, params.task); - response.addResult(result); - }, -}); +const oneToolSchema: mcpServer.ToolSchema = { + name: 'browser', + title: 'Perform a task with the browser', + description: 'Perform a task with the browser. It can click, type, export, capture screenshot, drag, hover, select options, etc.', + inputSchema: z.object({ + task: z.string().describe('The task to perform with the browser'), + }), + type: 'readOnly', +}; export async function runOneTool(config: FullConfig) { - innerClient = await createInnerClient(); - const server = new Server(config, [oneTool]); - server.setupExitWatchdog(); + dotenv.config(); + const serverBackendFactory = () => new OneToolServerBackend(); + await mcpTransport.start(serverBackendFactory, config.server); +} - if (config.server.port !== undefined) { - const httpServer = await startHttpServer(config.server); - startHttpTransport(httpServer, server); - } else { - await startStdioTransport(server); +class OneToolServerBackend implements ServerBackend { + readonly name = 'Playwright'; + readonly version = packageJSON.version; + private _innerClient: Client | undefined; + + async initialize() { + const transport = new StdioClientTransport({ + command: 'node', + args: [ + path.resolve(__filename, '../../../cli.js'), + ], + stderr: 'inherit', + env: process.env as Record, + }); + + const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' }); + await client.connect(transport); + await client.ping(); + this._innerClient = client; + } + + tools(): mcpServer.ToolSchema[] { + return [oneToolSchema]; + } + + async callTool(schema: mcpServer.ToolSchema, parsedArguments: any): Promise { + const result = await runTask(delegate!, this._innerClient!, parsedArguments.task as string); + return { + content: [{ type: 'text', text: result }], + }; } } - -async function createInnerClient(): Promise { - const transport = new StdioClientTransport({ - command: 'node', - args: [ - path.resolve(__filename, '../../../cli.js'), - ], - stderr: 'inherit', - env: process.env as Record, - }); - - const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' }); - await client.connect(transport); - await client.ping(); - return client; -} diff --git a/src/mcp/README.md b/src/mcp/README.md new file mode 100644 index 0000000..64edb62 --- /dev/null +++ b/src/mcp/README.md @@ -0,0 +1 @@ +- Generic MCP utils, no dependencies on Playwright here. diff --git a/src/mcp/server.ts b/src/mcp/server.ts new file mode 100644 index 0000000..57ba3c9 --- /dev/null +++ b/src/mcp/server.ts @@ -0,0 +1,105 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from 'zod'; +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; +import { zodToJsonSchema } from 'zod-to-json-schema'; + +import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +export type ClientVersion = Implementation; + +export type ToolResponse = { + content: (TextContent | ImageContent)[]; + isError?: boolean; +}; + +export type ToolSchema = { + name: string; + title: string; + description: string; + inputSchema: Input; + type: 'readOnly' | 'destructive'; +}; + +export type ToolHandler = (toolName: string, params: any) => Promise; + +export interface ServerBackend { + name: string; + version: string; + initialize?(): Promise; + tools(): ToolSchema[]; + callTool(schema: ToolSchema, parsedArguments: any): Promise; + serverInitialized?(version: ClientVersion | undefined): void; + serverClosed?(): void; +} + +export type ServerBackendFactory = () => ServerBackend; + +export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport) { + const backend = serverBackendFactory(); + await backend.initialize?.(); + const server = createServer(backend); + await server.connect(transport); +} + +export function createServer(backend: ServerBackend): Server { + const server = new Server({ name: backend.name, version: backend.version }, { + capabilities: { + tools: {}, + } + }); + + const tools = backend.tools(); + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { tools: tools.map(tool => ({ + name: tool.name, + description: tool.description, + inputSchema: zodToJsonSchema(tool.inputSchema), + annotations: { + title: tool.title, + readOnlyHint: tool.type === 'readOnly', + destructiveHint: tool.type === 'destructive', + openWorldHint: true, + }, + })) }; + }); + + server.setRequestHandler(CallToolRequestSchema, async request => { + const errorResult = (...messages: string[]) => ({ + content: [{ type: 'text', text: messages.join('\n') }], + isError: true, + }); + const tool = tools.find(tool => tool.name === request.params.name) as ToolSchema; + if (!tool) + return errorResult(`Tool "${request.params.name}" not found`); + + try { + return await backend.callTool(tool, tool.inputSchema.parse(request.params.arguments || {})); + } catch (error) { + return errorResult(String(error)); + } + }); + + if (backend.serverInitialized) + server.oninitialized = () => backend.serverInitialized!(server.getClientVersion()); + if (backend.serverClosed) + server.onclose = () => backend.serverClosed!(); + + return server; +} diff --git a/src/transport.ts b/src/mcp/transport.ts similarity index 65% rename from src/transport.ts rename to src/mcp/transport.ts index c34ce39..5a0381f 100644 --- a/src/transport.ts +++ b/src/mcp/transport.ts @@ -14,25 +14,34 @@ * limitations under the License. */ -import http from 'node:http'; -import assert from 'node:assert'; -import crypto from 'node:crypto'; - +import http from 'http'; +import crypto from 'crypto'; import debug from 'debug'; + import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { httpAddressToString, startHttpServer } from '../httpServer.js'; +import * as mcpServer from './server.js'; -import type { AddressInfo } from 'node:net'; -import type { Server } from './server.js'; +import type { ServerBackendFactory } from './server.js'; -export async function startStdioTransport(server: Server) { - await server.createConnection(new StdioServerTransport()); +export async function start(serverBackendFactory: ServerBackendFactory, options: { host?: string; port?: number }) { + if (options.port !== undefined) { + const httpServer = await startHttpServer(options); + startHttpTransport(httpServer, serverBackendFactory); + } else { + await startStdioTransport(serverBackendFactory); + } +} + +async function startStdioTransport(serverBackendFactory: ServerBackendFactory) { + await mcpServer.connect(serverBackendFactory, new StdioServerTransport()); } const testDebug = debug('pw:mcp:test'); -async function handleSSE(server: Server, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map) { +async function handleSSE(serverBackendFactory: ServerBackendFactory, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map) { if (req.method === 'POST') { const sessionId = url.searchParams.get('sessionId'); if (!sessionId) { @@ -51,7 +60,7 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se const transport = new SSEServerTransport('/sse', res); sessions.set(transport.sessionId, transport); testDebug(`create SSE session: ${transport.sessionId}`); - await server.createConnection(transport); + await mcpServer.connect(serverBackendFactory, transport); res.on('close', () => { testDebug(`delete SSE session: ${transport.sessionId}`); sessions.delete(transport.sessionId); @@ -63,7 +72,7 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se res.end('Method not allowed'); } -async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map) { +async function handleStreamable(serverBackendFactory: ServerBackendFactory, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map) { const sessionId = req.headers['mcp-session-id'] as string | undefined; if (sessionId) { const transport = sessions.get(sessionId); @@ -80,7 +89,7 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res: sessionIdGenerator: () => crypto.randomUUID(), onsessioninitialized: async sessionId => { testDebug(`create http session: ${transport.sessionId}`); - await server.createConnection(transport); + await mcpServer.connect(serverBackendFactory, transport); sessions.set(sessionId, transport); } }); @@ -100,28 +109,15 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res: res.end('Invalid request'); } -export async function startHttpServer(config: { host?: string, port?: number }): Promise { - const { host, port } = config; - const httpServer = http.createServer(); - await new Promise((resolve, reject) => { - httpServer.on('error', reject); - httpServer.listen(port, host, () => { - resolve(); - httpServer.removeListener('error', reject); - }); - }); - return httpServer; -} - -export function startHttpTransport(httpServer: http.Server, mcpServer: Server) { +function startHttpTransport(httpServer: http.Server, serverBackendFactory: ServerBackendFactory) { const sseSessions = new Map(); const streamableSessions = new Map(); httpServer.on('request', async (req, res) => { const url = new URL(`http://localhost${req.url}`); if (url.pathname.startsWith('/sse')) - await handleSSE(mcpServer, req, res, url, sseSessions); + await handleSSE(serverBackendFactory, req, res, url, sseSessions); else - await handleStreamable(mcpServer, req, res, streamableSessions); + await handleStreamable(serverBackendFactory, req, res, streamableSessions); }); const url = httpAddressToString(httpServer.address()); const message = [ @@ -139,14 +135,3 @@ export function startHttpTransport(httpServer: http.Server, mcpServer: Server) { // eslint-disable-next-line no-console console.error(message); } - -export function httpAddressToString(address: string | AddressInfo | null): string { - assert(address, 'Could not bind server socket'); - if (typeof address === 'string') - return address; - const resolvedPort = address.port; - let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; - if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') - resolvedHost = 'localhost'; - return `http://${resolvedHost}:${resolvedPort}`; -} diff --git a/src/program.ts b/src/program.ts index 4573960..a34205c 100644 --- a/src/program.ts +++ b/src/program.ts @@ -18,12 +18,13 @@ import { program, Option } from 'commander'; // @ts-ignore import { startTraceViewerServer } from 'playwright-core/lib/server'; -import { startHttpServer, startHttpTransport, startStdioTransport } from './transport.js'; +import * as mcpTransport from './mcp/transport.js'; import { commaSeparatedList, resolveCLIConfig, semicolonSeparatedList } from './config.js'; -import { Server } from './server.js'; import { packageJSON } from './package.js'; import { runWithExtension } from './extension/main.js'; -import { filteredTools } from './tools.js'; +import { BrowserServerBackend } from './browserServerBackend.js'; +import { Context } from './context.js'; +import { contextFactory } from './browserContextFactory.js'; program .version('Version ' + packageJSON.version) @@ -56,6 +57,8 @@ program .addOption(new Option('--extension', 'Connect to a running browser instance (Edge/Chrome only). Requires the "Playwright MCP Bridge" browser extension to be installed.').hideHelp()) .addOption(new Option('--vision', 'Legacy option, use --caps=vision instead').hideHelp()) .action(async options => { + setupExitWatchdog(); + if (options.vision) { // eslint-disable-next-line no-console console.error('The --vision option is deprecated, use --caps=vision instead'); @@ -68,15 +71,9 @@ program return; } - const server = new Server(config, filteredTools(config)); - server.setupExitWatchdog(); - - if (config.server.port !== undefined) { - const httpServer = await startHttpServer(config.server); - startHttpTransport(httpServer, server); - } else { - await startStdioTransport(server); - } + const browserContextFactory = contextFactory(config.browser); + const serverBackendFactory = () => new BrowserServerBackend(config, browserContextFactory); + await mcpTransport.start(serverBackendFactory, config.server); if (config.saveTrace) { const server = await startTraceViewerServer(); @@ -87,4 +84,20 @@ program } }); +function setupExitWatchdog() { + let isExiting = false; + const handleExit = async () => { + if (isExiting) + return; + isExiting = true; + setTimeout(() => process.exit(0), 15000); + await Context.disposeAll(); + process.exit(0); + }; + + process.stdin.on('close', handleExit); + process.on('SIGINT', handleExit); + process.on('SIGTERM', handleExit); +} + void program.parseAsync(process.argv); diff --git a/src/server.ts b/src/server.ts deleted file mode 100644 index e34928f..0000000 --- a/src/server.ts +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright (c) Microsoft Corporation. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { createMCPServer } from './connection.js'; -import { Context } from './context.js'; -import { contextFactory as defaultContextFactory } from './browserContextFactory.js'; - -import type { FullConfig } from './config.js'; -import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; -import type { BrowserContextFactory } from './browserContextFactory.js'; -import type { Tool } from './tools/tool.js'; - -export class Server { - readonly config: FullConfig; - private _browserConfig: FullConfig['browser']; - private _contextFactory: BrowserContextFactory; - readonly tools: Tool[]; - - constructor(config: FullConfig, tools: Tool[], contextFactory?: BrowserContextFactory) { - this.config = config; - this.tools = tools; - this._browserConfig = config.browser; - this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig); - } - - async createConnection(transport: Transport): Promise { - const server = await createMCPServer(this.config, this.tools, this._contextFactory); - await server.connect(transport); - } - - setupExitWatchdog() { - let isExiting = false; - const handleExit = async () => { - if (isExiting) - return; - isExiting = true; - setTimeout(() => process.exit(0), 15000); - await Context.disposeAll(); - process.exit(0); - }; - - process.stdin.on('close', handleExit); - process.on('SIGINT', handleExit); - process.on('SIGTERM', handleExit); - } -}