From 54ed7c3200e7dc0596c3e70c31aa3cb5c037419d Mon Sep 17 00:00:00 2001 From: Pavel Feldman Date: Wed, 28 May 2025 16:55:47 -0700 Subject: [PATCH] chore: refactor server, prepare for browser reuse (#490) --- src/connection.ts | 28 +++++++++++---------------- src/context.ts | 4 ---- src/package.ts | 22 +++++++++++++++++++++ src/program.ts | 32 +++++++++---------------------- src/server.ts | 49 +++++++++++++++++++++++++++++++++++++++++++++++ src/transport.ts | 39 +++++++++++++------------------------ tests/sse.spec.ts | 49 +---------------------------------------------- 7 files changed, 105 insertions(+), 118 deletions(-) create mode 100644 src/package.ts create mode 100644 src/server.ts diff --git a/src/connection.ts b/src/connection.ts index a315246..8e49506 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -14,22 +14,22 @@ * limitations under the License. */ -import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js'; import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js'; import { zodToJsonSchema } from 'zod-to-json-schema'; -import { Context, packageJSON } from './context.js'; +import { Context } from './context.js'; import { snapshotTools, visionTools } from './tools.js'; +import { packageJSON } from './package.js'; -import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import { FullConfig } from './config.js'; -export async function createConnection(config: FullConfig): Promise { +export function createConnection(config: FullConfig): Connection { const allTools = config.vision ? visionTools : snapshotTools; const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability)); const context = new Context(tools, config); - const server = new Server({ name: 'Playwright', version: packageJSON.version }, { + const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, { capabilities: { tools: {}, } @@ -74,25 +74,19 @@ export async function createConnection(config: FullConfig): Promise } }); - const connection = new Connection(server, context); - return connection; + return new Connection(server, context); } export class Connection { - readonly server: Server; + readonly server: McpServer; readonly context: Context; - constructor(server: Server, context: Context) { + constructor(server: McpServer, context: Context) { this.server = server; this.context = context; - } - - async connect(transport: Transport) { - await this.server.connect(transport); - await new Promise(resolve => { - this.server.oninitialized = () => resolve(); - }); - this.context.clientVersion = this.server.getClientVersion(); + this.server.oninitialized = () => { + this.context.clientVersion = this.server.getClientVersion(); + }; } async close() { diff --git a/src/context.ts b/src/context.ts index 8da2dbb..01bd537 100644 --- a/src/context.ts +++ b/src/context.ts @@ -15,7 +15,6 @@ */ import fs from 'node:fs'; -import url from 'node:url'; import os from 'node:os'; import path from 'node:path'; @@ -416,6 +415,3 @@ async function createUserDataDir(browserConfig: FullConfig['browser']) { await fs.promises.mkdir(result, { recursive: true }); return result; } - -const __filename = url.fileURLToPath(import.meta.url); -export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8')); diff --git a/src/package.ts b/src/package.ts new file mode 100644 index 0000000..a6c7019 --- /dev/null +++ b/src/package.ts @@ -0,0 +1,22 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import fs from 'node:fs'; +import url from 'node:url'; +import path from 'node:path'; + +const __filename = url.fileURLToPath(import.meta.url); +export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8')); diff --git a/src/program.ts b/src/program.ts index 8638342..2c1c7a4 100644 --- a/src/program.ts +++ b/src/program.ts @@ -15,14 +15,13 @@ */ import { program } from 'commander'; - -import { startHttpTransport, startStdioTransport } from './transport.js'; -import { resolveCLIConfig } from './config.js'; // @ts-ignore import { startTraceViewerServer } from 'playwright-core/lib/server'; -import type { Connection } from './connection.js'; -import { packageJSON } from './context.js'; +import { startHttpTransport, startStdioTransport } from './transport.js'; +import { resolveCLIConfig } from './config.js'; +import { Server } from './server.js'; +import { packageJSON } from './package.js'; program .version('Version ' + packageJSON.version) @@ -54,13 +53,13 @@ program .option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)') .action(async options => { const config = await resolveCLIConfig(options); - const connectionList: Connection[] = []; - setupExitWatchdog(connectionList); + const server = new Server(config); + server.setupExitWatchdog(); if (options.port) - startHttpTransport(config, +options.port, options.host, connectionList); + startHttpTransport(server, +options.port, options.host); else - await startStdioTransport(config, connectionList); + await startStdioTransport(server); if (config.saveTrace) { const server = await startTraceViewerServer(); @@ -71,21 +70,8 @@ program } }); -function setupExitWatchdog(connectionList: Connection[]) { - const handleExit = async () => { - setTimeout(() => process.exit(0), 15000); - for (const connection of connectionList) - await connection.close(); - process.exit(0); - }; - - process.stdin.on('close', handleExit); - process.on('SIGINT', handleExit); - process.on('SIGTERM', handleExit); -} - function semicolonSeparatedList(value: string): string[] { return value.split(';').map(v => v.trim()); } -program.parse(process.argv); +void program.parseAsync(process.argv); diff --git a/src/server.ts b/src/server.ts new file mode 100644 index 0000000..b5cce8b --- /dev/null +++ b/src/server.ts @@ -0,0 +1,49 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { createConnection } from './connection.js'; + +import type { FullConfig } from './config.js'; +import type { Connection } from './connection.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +export class Server { + readonly config: FullConfig; + private _connectionList: Connection[] = []; + + constructor(config: FullConfig) { + this.config = config; + } + + async createConnection(transport: Transport): Promise { + const connection = createConnection(this.config); + this._connectionList.push(connection); + await connection.server.connect(transport); + return connection; + } + + setupExitWatchdog() { + const handleExit = async () => { + setTimeout(() => process.exit(0), 15000); + await Promise.all(this._connectionList.map(connection => connection.close())); + process.exit(0); + }; + + process.stdin.on('close', handleExit); + process.on('SIGINT', handleExit); + process.on('SIGTERM', handleExit); + } +} diff --git a/src/transport.ts b/src/transport.ts index d75bbc4..6598e5a 100644 --- a/src/transport.ts +++ b/src/transport.ts @@ -22,18 +22,13 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; -import { createConnection } from './connection.js'; +import type { Server } from './server.js'; -import type { Connection } from './connection.js'; -import type { FullConfig } from './config.js'; - -export async function startStdioTransport(config: FullConfig, connectionList: Connection[]) { - const connection = await createConnection(config); - await connection.connect(new StdioServerTransport()); - connectionList.push(connection); +export async function startStdioTransport(server: Server) { + await server.createConnection(new StdioServerTransport()); } -async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map, connectionList: Connection[]) { +async function handleSSE(server: Server, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map) { if (req.method === 'POST') { const sessionId = url.searchParams.get('sessionId'); if (!sessionId) { @@ -51,15 +46,11 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt } else if (req.method === 'GET') { const transport = new SSEServerTransport('/sse', res); sessions.set(transport.sessionId, transport); - const connection = await createConnection(config); - await connection.connect(transport); - connectionList.push(connection); + const connection = await server.createConnection(transport); res.on('close', () => { sessions.delete(transport.sessionId); - connection.close().catch(e => { - // eslint-disable-next-line no-console - console.error(e); - }); + // eslint-disable-next-line no-console + void connection.close().catch(e => console.error(e)); }); return; } @@ -68,7 +59,7 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt res.end('Method not allowed'); } -async function handleStreamable(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map, connectionList: Connection[]) { +async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map) { const sessionId = req.headers['mcp-session-id'] as string | undefined; if (sessionId) { const transport = sessions.get(sessionId); @@ -91,12 +82,8 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r if (transport.sessionId) sessions.delete(transport.sessionId); }; - const connection = await createConnection(config); - connectionList.push(connection); - await Promise.all([ - connection.connect(transport), - transport.handleRequest(req, res), - ]); + await server.createConnection(transport); + await transport.handleRequest(req, res); return; } @@ -104,15 +91,15 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r res.end('Invalid request'); } -export function startHttpTransport(config: FullConfig, port: number, hostname: string | undefined, connectionList: Connection[]) { +export function startHttpTransport(server: Server, port: number, hostname: string | undefined) { const sseSessions = new Map(); const streamableSessions = new Map(); const httpServer = http.createServer(async (req, res) => { const url = new URL(`http://localhost${req.url}`); if (url.pathname.startsWith('/mcp')) - await handleStreamable(config, req, res, streamableSessions, connectionList); + await handleStreamable(server, req, res, streamableSessions); else - await handleSSE(config, req, res, url, sseSessions, connectionList); + await handleSSE(server, req, res, url, sseSessions); }); httpServer.listen(port, hostname, () => { const address = httpServer.address(); diff --git a/tests/sse.spec.ts b/tests/sse.spec.ts index 266d456..f1c3aa7 100644 --- a/tests/sse.spec.ts +++ b/tests/sse.spec.ts @@ -15,17 +15,12 @@ */ import url from 'node:url'; -import http from 'node:http'; import { spawn } from 'node:child_process'; import path from 'node:path'; -import type { AddressInfo } from 'node:net'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; -import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { createConnection } from '@playwright/mcp'; - import { test as baseTest, expect } from './fixtures.js'; // NOTE: Can be removed when we drop Node.js 18 support and changed to import.meta.filename. @@ -55,6 +50,7 @@ test('sse transport', async ({ serverEndpoint }) => { const client = new Client({ name: 'test', version: '1.0.0' }); await client.connect(transport); await client.ping(); + await client.close(); }); test('streamable http transport', async ({ serverEndpoint }) => { @@ -64,46 +60,3 @@ test('streamable http transport', async ({ serverEndpoint }) => { await client.ping(); expect(transport.sessionId, 'has session support').toBeDefined(); }); - -test('sse transport via public API', async ({ server }, testInfo) => { - const userDataDir = testInfo.outputPath('user-data-dir'); - const sessions = new Map(); - const mcpServer = http.createServer(async (req, res) => { - if (req.method === 'GET') { - const connection = await createConnection({ - browser: { - userDataDir, - launchOptions: { headless: true } - }, - }); - const transport = new SSEServerTransport('/sse', res); - sessions.set(transport.sessionId, transport); - await connection.connect(transport); - } else if (req.method === 'POST') { - const url = new URL(`http://localhost${req.url}`); - const sessionId = url.searchParams.get('sessionId'); - if (!sessionId) { - res.statusCode = 400; - return res.end('Missing sessionId'); - } - const transport = sessions.get(sessionId); - if (!transport) { - res.statusCode = 404; - return res.end('Session not found'); - } - void transport.handlePostMessage(req, res); - } - }); - await new Promise(resolve => mcpServer.listen(0, () => resolve())); - const serverUrl = `http://localhost:${(mcpServer.address() as AddressInfo).port}/sse`; - const transport = new SSEClientTransport(new URL(serverUrl)); - const client = new Client({ name: 'test', version: '1.0.0' }); - await client.connect(transport); - await client.ping(); - expect(await client.callTool({ - name: 'browser_navigate', - arguments: { url: server.HELLO_WORLD }, - })).toContainTextContent(`- generic [ref=e1]: Hello, world!`); - await client.close(); - mcpServer.close(); -});