chore: extract pure mcp server helpers (#751)

This commit is contained in:
Pavel Feldman 2025-07-24 12:57:01 -07:00 committed by GitHub
parent bd34e9d7e9
commit c63b7823e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 300 additions and 469 deletions

View File

@ -0,0 +1,66 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { FullConfig } from './config.js';
import { Context } from './context.js';
import { logUnhandledError } from './log.js';
import { Response } from './response.js';
import { SessionLog } from './sessionLog.js';
import { filteredTools } from './tools.js';
import { packageJSON } from './package.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
import type * as mcpServer from './mcp/server.js';
import type { ServerBackend } from './mcp/server.js';
import type { Tool } from './tools/tool.js';
export class BrowserServerBackend implements ServerBackend {
name = 'Playwright';
version = packageJSON.version;
private _tools: Tool[];
private _context: Context;
private _sessionLog: SessionLog | undefined;
constructor(config: FullConfig, browserContextFactory: BrowserContextFactory) {
this._tools = filteredTools(config);
this._context = new Context(this._tools, config, browserContextFactory);
}
async initialize() {
this._sessionLog = this._context.config.saveSession ? await SessionLog.create(this._context.config) : undefined;
}
tools(): mcpServer.ToolSchema<any>[] {
return this._tools.map(tool => tool.schema);
}
async callTool(schema: mcpServer.ToolSchema<any>, parsedArguments: any) {
const response = new Response(this._context, schema.name, parsedArguments);
const tool = this._tools.find(tool => tool.schema.name === schema.name)!;
await tool.handle(this._context, parsedArguments, response);
if (this._sessionLog)
await this._sessionLog.log(response);
return await response.serialize();
}
serverInitialized(version: mcpServer.ClientVersion | undefined) {
this._context.clientVersion = version;
}
serverClosed() {
void this._context.dispose().catch(logUnhandledError);
}
}

View File

@ -1,84 +0,0 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { Context } from './context.js';
import { Response } from './response.js';
import { packageJSON } from './package.js';
import { FullConfig } from './config.js';
import { SessionLog } from './sessionLog.js';
import { logUnhandledError } from './log.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
import type { Tool } from './tools/tool.js';
export async function createMCPServer(config: FullConfig, tools: Tool<any>[], browserContextFactory: BrowserContextFactory): Promise<Server> {
const context = new Context(tools, config, browserContextFactory);
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
capabilities: {
tools: {},
}
});
const sessionLog = config.saveSession ? await SessionLog.create(config) : undefined;
server.setRequestHandler(ListToolsRequestSchema, async () => {
return {
tools: tools.map(tool => ({
name: tool.schema.name,
description: tool.schema.description,
inputSchema: zodToJsonSchema(tool.schema.inputSchema),
annotations: {
title: tool.schema.title,
readOnlyHint: tool.schema.type === 'readOnly',
destructiveHint: tool.schema.type === 'destructive',
openWorldHint: true,
},
})) as McpTool[],
};
});
server.setRequestHandler(CallToolRequestSchema, async request => {
const errorResult = (...messages: string[]) => ({
content: [{ type: 'text', text: messages.join('\n') }],
isError: true,
});
const tool = tools.find(tool => tool.schema.name === request.params.name);
if (!tool)
return errorResult(`Tool "${request.params.name}" not found`);
try {
const response = new Response(context, request.params.name, request.params.arguments || {});
await tool.handle(context, tool.schema.inputSchema.parse(request.params.arguments || {}), response);
if (sessionLog)
await sessionLog.log(response);
return await response.serialize();
} catch (error) {
return errorResult(String(error));
}
});
server.oninitialized = () => {
context.clientVersion = server.getClientVersion();
};
server.onclose = () => {
void context.dispose().catch(logUnhandledError);
};
return server;
}

View File

@ -27,10 +27,9 @@ import { spawn } from 'child_process';
import { WebSocket, WebSocketServer } from 'ws'; import { WebSocket, WebSocketServer } from 'ws';
import debug from 'debug'; import debug from 'debug';
import * as playwright from 'playwright'; import * as playwright from 'playwright';
import { httpAddressToString, startHttpServer } from '../transport.js';
// @ts-ignore // @ts-ignore
const { registry } = await import('playwright-core/lib/server/registry/index'); const { registry } = await import('playwright-core/lib/server/registry/index');
import { httpAddressToString, startHttpServer } from '../httpServer.js';
import type { BrowserContextFactory } from '../browserContextFactory.js'; import type { BrowserContextFactory } from '../browserContextFactory.js';
import type websocket from 'ws'; import type websocket from 'ws';

View File

@ -14,22 +14,14 @@
* limitations under the License. * limitations under the License.
*/ */
import { startHttpServer, startHttpTransport, startStdioTransport } from '../transport.js';
import { Server } from '../server.js';
import { startCDPRelayServer } from './cdpRelay.js'; import { startCDPRelayServer } from './cdpRelay.js';
import { filteredTools } from '../tools.js'; import { BrowserServerBackend } from '../browserServerBackend.js';
import * as mcpTransport from '../mcp/transport.js';
import type { FullConfig } from '../config.js'; import type { FullConfig } from '../config.js';
export async function runWithExtension(config: FullConfig) { export async function runWithExtension(config: FullConfig) {
const contextFactory = await startCDPRelayServer(config.browser.launchOptions.channel || 'chrome'); const contextFactory = await startCDPRelayServer(config.browser.launchOptions.channel || 'chrome');
const server = new Server(config, filteredTools(config), contextFactory); const serverBackendFactory = () => new BrowserServerBackend(config, contextFactory);
server.setupExitWatchdog(); await mcpTransport.start(serverBackendFactory, config.server);
if (config.server.port !== undefined) {
const httpServer = await startHttpServer(config.server);
startHttpTransport(httpServer, server);
} else {
await startStdioTransport(server);
}
} }

View File

@ -14,219 +14,31 @@
* limitations under the License. * limitations under the License.
*/ */
import fs from 'fs'; import assert from 'assert';
import path from 'path';
import http from 'http'; import http from 'http';
import net from 'net';
import mime from 'mime'; import type * as net from 'net';
import { ManualPromise } from './manualPromise.js'; export async function startHttpServer(config: { host?: string, port?: number }): Promise<http.Server> {
const { host, port } = config;
const httpServer = http.createServer();
export type ServerRouteHandler = (request: http.IncomingMessage, response: http.ServerResponse) => void; await new Promise<void>((resolve, reject) => {
httpServer.on('error', reject);
export type Transport = { httpServer.listen(port, host, () => {
sendEvent?: (method: string, params: any) => void; resolve();
close?: () => void; httpServer.removeListener('error', reject);
onconnect: () => void;
dispatch: (method: string, params: any) => Promise<any>;
onclose: () => void;
};
export class HttpServer {
private _server: http.Server;
private _urlPrefixPrecise: string = '';
private _urlPrefixHumanReadable: string = '';
private _port: number = 0;
private _routes: { prefix?: string, exact?: string, handler: ServerRouteHandler }[] = [];
constructor() {
this._server = http.createServer(this._onRequest.bind(this));
decorateServer(this._server);
}
server() {
return this._server;
}
routePrefix(prefix: string, handler: ServerRouteHandler) {
this._routes.push({ prefix, handler });
}
routePath(path: string, handler: ServerRouteHandler) {
this._routes.push({ exact: path, handler });
}
port(): number {
return this._port;
}
private async _tryStart(port: number | undefined, host: string) {
const errorPromise = new ManualPromise();
const errorListener = (error: Error) => errorPromise.reject(error);
this._server.on('error', errorListener);
try {
this._server.listen(port, host);
await Promise.race([
new Promise(cb => this._server!.once('listening', cb)),
errorPromise,
]);
} finally {
this._server.removeListener('error', errorListener);
}
}
async start(options: { port?: number, preferredPort?: number, host?: string } = {}): Promise<void> {
const host = options.host || 'localhost';
if (options.preferredPort) {
try {
await this._tryStart(options.preferredPort, host);
} catch (e: any) {
if (!e || !e.message || !e.message.includes('EADDRINUSE'))
throw e;
await this._tryStart(undefined, host);
}
} else {
await this._tryStart(options.port, host);
}
const address = this._server.address();
if (typeof address === 'string') {
this._urlPrefixPrecise = address;
this._urlPrefixHumanReadable = address;
} else {
this._port = address!.port;
const resolvedHost = address!.family === 'IPv4' ? address!.address : `[${address!.address}]`;
this._urlPrefixPrecise = `http://${resolvedHost}:${address!.port}`;
this._urlPrefixHumanReadable = `http://${host}:${address!.port}`;
}
}
async stop() {
await new Promise(cb => this._server!.close(cb));
}
urlPrefix(purpose: 'human-readable' | 'precise'): string {
return purpose === 'human-readable' ? this._urlPrefixHumanReadable : this._urlPrefixPrecise;
}
serveFile(request: http.IncomingMessage, response: http.ServerResponse, absoluteFilePath: string, headers?: { [name: string]: string }): boolean {
try {
for (const [name, value] of Object.entries(headers || {}))
response.setHeader(name, value);
if (request.headers.range)
this._serveRangeFile(request, response, absoluteFilePath);
else
this._serveFile(response, absoluteFilePath);
return true;
} catch (e) {
return false;
}
}
_serveFile(response: http.ServerResponse, absoluteFilePath: string) {
const content = fs.readFileSync(absoluteFilePath);
response.statusCode = 200;
const contentType = mime.getType(path.extname(absoluteFilePath)) || 'application/octet-stream';
response.setHeader('Content-Type', contentType);
response.setHeader('Content-Length', content.byteLength);
response.end(content);
}
_serveRangeFile(request: http.IncomingMessage, response: http.ServerResponse, absoluteFilePath: string) {
const range = request.headers.range;
if (!range || !range.startsWith('bytes=') || range.includes(', ') || [...range].filter(char => char === '-').length !== 1) {
response.statusCode = 400;
return response.end('Bad request');
}
// Parse the range header: https://datatracker.ietf.org/doc/html/rfc7233#section-2.1
const [startStr, endStr] = range.replace(/bytes=/, '').split('-');
// Both start and end (when passing to fs.createReadStream) and the range header are inclusive and start counting at 0.
let start: number;
let end: number;
const size = fs.statSync(absoluteFilePath).size;
if (startStr !== '' && endStr === '') {
// No end specified: use the whole file
start = +startStr;
end = size - 1;
} else if (startStr === '' && endStr !== '') {
// No start specified: calculate start manually
start = size - +endStr;
end = size - 1;
} else {
start = +startStr;
end = +endStr;
}
// Handle unavailable range request
if (Number.isNaN(start) || Number.isNaN(end) || start >= size || end >= size || start > end) {
// Return the 416 Range Not Satisfiable: https://datatracker.ietf.org/doc/html/rfc7233#section-4.4
response.writeHead(416, {
'Content-Range': `bytes */${size}`
});
return response.end();
}
// Sending Partial Content: https://datatracker.ietf.org/doc/html/rfc7233#section-4.1
response.writeHead(206, {
'Content-Range': `bytes ${start}-${end}/${size}`,
'Accept-Ranges': 'bytes',
'Content-Length': end - start + 1,
'Content-Type': mime.getType(path.extname(absoluteFilePath))!,
}); });
const readable = fs.createReadStream(absoluteFilePath, { start, end });
readable.pipe(response);
}
private _onRequest(request: http.IncomingMessage, response: http.ServerResponse) {
if (request.method === 'OPTIONS') {
response.writeHead(200);
response.end();
return;
}
request.on('error', () => response.end());
try {
if (!request.url) {
response.end();
return;
}
const url = new URL('http://localhost' + request.url);
for (const route of this._routes) {
if (route.exact && url.pathname === route.exact) {
route.handler(request, response);
return;
}
if (route.prefix && url.pathname.startsWith(route.prefix)) {
route.handler(request, response);
return;
}
}
response.statusCode = 404;
response.end();
} catch (e) {
response.end();
}
}
}
function decorateServer(server: net.Server) {
const sockets = new Set<net.Socket>();
server.on('connection', socket => {
sockets.add(socket);
socket.once('close', () => sockets.delete(socket));
}); });
return httpServer;
const close = server.close; }
server.close = (callback?: (err?: Error) => void) => {
for (const socket of sockets) export function httpAddressToString(address: string | net.AddressInfo | null): string {
socket.destroy(); assert(address, 'Could not bind server socket');
sockets.clear(); if (typeof address === 'string')
return close.call(server, callback); return address;
}; const resolvedPort = address.port;
let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`;
if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]')
resolvedHost = 'localhost';
return `http://${resolvedHost}:${resolvedPort}`;
} }

View File

@ -14,10 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
import { createMCPServer } from './connection.js'; import { BrowserServerBackend } from './browserServerBackend.js';
import { resolveConfig } from './config.js'; import { resolveConfig } from './config.js';
import { contextFactory } from './browserContextFactory.js'; import { contextFactory } from './browserContextFactory.js';
import { filteredTools } from './tools.js'; import * as mcpServer from './mcp/server.js';
import type { Config } from '../config.js'; import type { Config } from '../config.js';
import type { BrowserContext } from 'playwright'; import type { BrowserContext } from 'playwright';
import type { BrowserContextFactory } from './browserContextFactory.js'; import type { BrowserContextFactory } from './browserContextFactory.js';
@ -26,7 +27,7 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> { export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
const config = await resolveConfig(userConfig); const config = await resolveConfig(userConfig);
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
return createMCPServer(config, filteredTools(config), factory); return mcpServer.createServer(new BrowserServerBackend(config, factory));
} }
class SimpleBrowserContextFactory implements BrowserContextFactory { class SimpleBrowserContextFactory implements BrowserContextFactory {

View File

@ -21,64 +21,64 @@ import { z } from 'zod';
import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { FullConfig } from '../config.js';
import { defineTool } from '../tools/tool.js';
import { Server } from '../server.js';
import { startHttpServer, startHttpTransport, startStdioTransport } from '../transport.js';
import { OpenAIDelegate } from './loopOpenAI.js'; import { OpenAIDelegate } from './loopOpenAI.js';
import { runTask } from './loop.js'; import { runTask } from './loop.js';
import { packageJSON } from '../package.js';
import * as mcpTransport from '../mcp/transport.js';
dotenv.config(); import type { FullConfig } from '../config.js';
import type { ServerBackend } from '../mcp/server.js';
import type * as mcpServer from '../mcp/server.js';
const __filename = url.fileURLToPath(import.meta.url); const __filename = url.fileURLToPath(import.meta.url);
let innerClient: Client | undefined;
const delegate = new OpenAIDelegate(); const delegate = new OpenAIDelegate();
const oneTool = defineTool({ const oneToolSchema: mcpServer.ToolSchema<any> = {
capability: 'core', name: 'browser',
title: 'Perform a task with the browser',
schema: { description: 'Perform a task with the browser. It can click, type, export, capture screenshot, drag, hover, select options, etc.',
name: 'browser', inputSchema: z.object({
title: 'Perform a task with the browser', task: z.string().describe('The task to perform with the browser'),
description: 'Perform a task with the browser. It can click, type, export, capture screenshot, drag, hover, select options, etc.', }),
inputSchema: z.object({ type: 'readOnly',
task: z.string().describe('The task to perform with the browser'), };
}),
type: 'readOnly',
},
handle: async (context, params, response) => {
const result = await runTask(delegate!, innerClient!, params.task);
response.addResult(result);
},
});
export async function runOneTool(config: FullConfig) { export async function runOneTool(config: FullConfig) {
innerClient = await createInnerClient(); dotenv.config();
const server = new Server(config, [oneTool]); const serverBackendFactory = () => new OneToolServerBackend();
server.setupExitWatchdog(); await mcpTransport.start(serverBackendFactory, config.server);
}
if (config.server.port !== undefined) { class OneToolServerBackend implements ServerBackend {
const httpServer = await startHttpServer(config.server); readonly name = 'Playwright';
startHttpTransport(httpServer, server); readonly version = packageJSON.version;
} else { private _innerClient: Client | undefined;
await startStdioTransport(server);
async initialize() {
const transport = new StdioClientTransport({
command: 'node',
args: [
path.resolve(__filename, '../../../cli.js'),
],
stderr: 'inherit',
env: process.env as Record<string, string>,
});
const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' });
await client.connect(transport);
await client.ping();
this._innerClient = client;
}
tools(): mcpServer.ToolSchema<any>[] {
return [oneToolSchema];
}
async callTool(schema: mcpServer.ToolSchema<any>, parsedArguments: any): Promise<mcpServer.ToolResponse> {
const result = await runTask(delegate!, this._innerClient!, parsedArguments.task as string);
return {
content: [{ type: 'text', text: result }],
};
} }
} }
async function createInnerClient(): Promise<Client> {
const transport = new StdioClientTransport({
command: 'node',
args: [
path.resolve(__filename, '../../../cli.js'),
],
stderr: 'inherit',
env: process.env as Record<string, string>,
});
const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' });
await client.connect(transport);
await client.ping();
return client;
}

1
src/mcp/README.md Normal file
View File

@ -0,0 +1 @@
- Generic MCP utils, no dependencies on Playwright here.

105
src/mcp/server.ts Normal file
View File

@ -0,0 +1,105 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { z } from 'zod';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import type { ImageContent, Implementation, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type ClientVersion = Implementation;
export type ToolResponse = {
content: (TextContent | ImageContent)[];
isError?: boolean;
};
export type ToolSchema<Input extends z.Schema> = {
name: string;
title: string;
description: string;
inputSchema: Input;
type: 'readOnly' | 'destructive';
};
export type ToolHandler = (toolName: string, params: any) => Promise<ToolResponse>;
export interface ServerBackend {
name: string;
version: string;
initialize?(): Promise<void>;
tools(): ToolSchema<any>[];
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
serverInitialized?(version: ClientVersion | undefined): void;
serverClosed?(): void;
}
export type ServerBackendFactory = () => ServerBackend;
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport) {
const backend = serverBackendFactory();
await backend.initialize?.();
const server = createServer(backend);
await server.connect(transport);
}
export function createServer(backend: ServerBackend): Server {
const server = new Server({ name: backend.name, version: backend.version }, {
capabilities: {
tools: {},
}
});
const tools = backend.tools();
server.setRequestHandler(ListToolsRequestSchema, async () => {
return { tools: tools.map(tool => ({
name: tool.name,
description: tool.description,
inputSchema: zodToJsonSchema(tool.inputSchema),
annotations: {
title: tool.title,
readOnlyHint: tool.type === 'readOnly',
destructiveHint: tool.type === 'destructive',
openWorldHint: true,
},
})) };
});
server.setRequestHandler(CallToolRequestSchema, async request => {
const errorResult = (...messages: string[]) => ({
content: [{ type: 'text', text: messages.join('\n') }],
isError: true,
});
const tool = tools.find(tool => tool.name === request.params.name) as ToolSchema<any>;
if (!tool)
return errorResult(`Tool "${request.params.name}" not found`);
try {
return await backend.callTool(tool, tool.inputSchema.parse(request.params.arguments || {}));
} catch (error) {
return errorResult(String(error));
}
});
if (backend.serverInitialized)
server.oninitialized = () => backend.serverInitialized!(server.getClientVersion());
if (backend.serverClosed)
server.onclose = () => backend.serverClosed!();
return server;
}

View File

@ -14,25 +14,34 @@
* limitations under the License. * limitations under the License.
*/ */
import http from 'node:http'; import http from 'http';
import assert from 'node:assert'; import crypto from 'crypto';
import crypto from 'node:crypto';
import debug from 'debug'; import debug from 'debug';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { httpAddressToString, startHttpServer } from '../httpServer.js';
import * as mcpServer from './server.js';
import type { AddressInfo } from 'node:net'; import type { ServerBackendFactory } from './server.js';
import type { Server } from './server.js';
export async function startStdioTransport(server: Server) { export async function start(serverBackendFactory: ServerBackendFactory, options: { host?: string; port?: number }) {
await server.createConnection(new StdioServerTransport()); if (options.port !== undefined) {
const httpServer = await startHttpServer(options);
startHttpTransport(httpServer, serverBackendFactory);
} else {
await startStdioTransport(serverBackendFactory);
}
}
async function startStdioTransport(serverBackendFactory: ServerBackendFactory) {
await mcpServer.connect(serverBackendFactory, new StdioServerTransport());
} }
const testDebug = debug('pw:mcp:test'); const testDebug = debug('pw:mcp:test');
async function handleSSE(server: Server, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>) { async function handleSSE(serverBackendFactory: ServerBackendFactory, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>) {
if (req.method === 'POST') { if (req.method === 'POST') {
const sessionId = url.searchParams.get('sessionId'); const sessionId = url.searchParams.get('sessionId');
if (!sessionId) { if (!sessionId) {
@ -51,7 +60,7 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
const transport = new SSEServerTransport('/sse', res); const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport); sessions.set(transport.sessionId, transport);
testDebug(`create SSE session: ${transport.sessionId}`); testDebug(`create SSE session: ${transport.sessionId}`);
await server.createConnection(transport); await mcpServer.connect(serverBackendFactory, transport);
res.on('close', () => { res.on('close', () => {
testDebug(`delete SSE session: ${transport.sessionId}`); testDebug(`delete SSE session: ${transport.sessionId}`);
sessions.delete(transport.sessionId); sessions.delete(transport.sessionId);
@ -63,7 +72,7 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
res.end('Method not allowed'); res.end('Method not allowed');
} }
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) { async function handleStreamable(serverBackendFactory: ServerBackendFactory, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
const sessionId = req.headers['mcp-session-id'] as string | undefined; const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) { if (sessionId) {
const transport = sessions.get(sessionId); const transport = sessions.get(sessionId);
@ -80,7 +89,7 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res:
sessionIdGenerator: () => crypto.randomUUID(), sessionIdGenerator: () => crypto.randomUUID(),
onsessioninitialized: async sessionId => { onsessioninitialized: async sessionId => {
testDebug(`create http session: ${transport.sessionId}`); testDebug(`create http session: ${transport.sessionId}`);
await server.createConnection(transport); await mcpServer.connect(serverBackendFactory, transport);
sessions.set(sessionId, transport); sessions.set(sessionId, transport);
} }
}); });
@ -100,28 +109,15 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res:
res.end('Invalid request'); res.end('Invalid request');
} }
export async function startHttpServer(config: { host?: string, port?: number }): Promise<http.Server> { function startHttpTransport(httpServer: http.Server, serverBackendFactory: ServerBackendFactory) {
const { host, port } = config;
const httpServer = http.createServer();
await new Promise<void>((resolve, reject) => {
httpServer.on('error', reject);
httpServer.listen(port, host, () => {
resolve();
httpServer.removeListener('error', reject);
});
});
return httpServer;
}
export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
const sseSessions = new Map(); const sseSessions = new Map();
const streamableSessions = new Map(); const streamableSessions = new Map();
httpServer.on('request', async (req, res) => { httpServer.on('request', async (req, res) => {
const url = new URL(`http://localhost${req.url}`); const url = new URL(`http://localhost${req.url}`);
if (url.pathname.startsWith('/sse')) if (url.pathname.startsWith('/sse'))
await handleSSE(mcpServer, req, res, url, sseSessions); await handleSSE(serverBackendFactory, req, res, url, sseSessions);
else else
await handleStreamable(mcpServer, req, res, streamableSessions); await handleStreamable(serverBackendFactory, req, res, streamableSessions);
}); });
const url = httpAddressToString(httpServer.address()); const url = httpAddressToString(httpServer.address());
const message = [ const message = [
@ -139,14 +135,3 @@ export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.error(message); console.error(message);
} }
export function httpAddressToString(address: string | AddressInfo | null): string {
assert(address, 'Could not bind server socket');
if (typeof address === 'string')
return address;
const resolvedPort = address.port;
let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`;
if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]')
resolvedHost = 'localhost';
return `http://${resolvedHost}:${resolvedPort}`;
}

View File

@ -18,12 +18,13 @@ import { program, Option } from 'commander';
// @ts-ignore // @ts-ignore
import { startTraceViewerServer } from 'playwright-core/lib/server'; import { startTraceViewerServer } from 'playwright-core/lib/server';
import { startHttpServer, startHttpTransport, startStdioTransport } from './transport.js'; import * as mcpTransport from './mcp/transport.js';
import { commaSeparatedList, resolveCLIConfig, semicolonSeparatedList } from './config.js'; import { commaSeparatedList, resolveCLIConfig, semicolonSeparatedList } from './config.js';
import { Server } from './server.js';
import { packageJSON } from './package.js'; import { packageJSON } from './package.js';
import { runWithExtension } from './extension/main.js'; import { runWithExtension } from './extension/main.js';
import { filteredTools } from './tools.js'; import { BrowserServerBackend } from './browserServerBackend.js';
import { Context } from './context.js';
import { contextFactory } from './browserContextFactory.js';
program program
.version('Version ' + packageJSON.version) .version('Version ' + packageJSON.version)
@ -56,6 +57,8 @@ program
.addOption(new Option('--extension', 'Connect to a running browser instance (Edge/Chrome only). Requires the "Playwright MCP Bridge" browser extension to be installed.').hideHelp()) .addOption(new Option('--extension', 'Connect to a running browser instance (Edge/Chrome only). Requires the "Playwright MCP Bridge" browser extension to be installed.').hideHelp())
.addOption(new Option('--vision', 'Legacy option, use --caps=vision instead').hideHelp()) .addOption(new Option('--vision', 'Legacy option, use --caps=vision instead').hideHelp())
.action(async options => { .action(async options => {
setupExitWatchdog();
if (options.vision) { if (options.vision) {
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
console.error('The --vision option is deprecated, use --caps=vision instead'); console.error('The --vision option is deprecated, use --caps=vision instead');
@ -68,15 +71,9 @@ program
return; return;
} }
const server = new Server(config, filteredTools(config)); const browserContextFactory = contextFactory(config.browser);
server.setupExitWatchdog(); const serverBackendFactory = () => new BrowserServerBackend(config, browserContextFactory);
await mcpTransport.start(serverBackendFactory, config.server);
if (config.server.port !== undefined) {
const httpServer = await startHttpServer(config.server);
startHttpTransport(httpServer, server);
} else {
await startStdioTransport(server);
}
if (config.saveTrace) { if (config.saveTrace) {
const server = await startTraceViewerServer(); const server = await startTraceViewerServer();
@ -87,4 +84,20 @@ program
} }
}); });
function setupExitWatchdog() {
let isExiting = false;
const handleExit = async () => {
if (isExiting)
return;
isExiting = true;
setTimeout(() => process.exit(0), 15000);
await Context.disposeAll();
process.exit(0);
};
process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);
}
void program.parseAsync(process.argv); void program.parseAsync(process.argv);

View File

@ -1,59 +0,0 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { createMCPServer } from './connection.js';
import { Context } from './context.js';
import { contextFactory as defaultContextFactory } from './browserContextFactory.js';
import type { FullConfig } from './config.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
import type { Tool } from './tools/tool.js';
export class Server {
readonly config: FullConfig;
private _browserConfig: FullConfig['browser'];
private _contextFactory: BrowserContextFactory;
readonly tools: Tool<any>[];
constructor(config: FullConfig, tools: Tool<any>[], contextFactory?: BrowserContextFactory) {
this.config = config;
this.tools = tools;
this._browserConfig = config.browser;
this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig);
}
async createConnection(transport: Transport): Promise<void> {
const server = await createMCPServer(this.config, this.tools, this._contextFactory);
await server.connect(transport);
}
setupExitWatchdog() {
let isExiting = false;
const handleExit = async () => {
if (isExiting)
return;
isExiting = true;
setTimeout(() => process.exit(0), 15000);
await Context.disposeAll();
process.exit(0);
};
process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);
}
}