chore: refactor server, prepare for browser reuse (#490)

This commit is contained in:
Pavel Feldman 2025-05-28 16:55:47 -07:00 committed by GitHub
parent 3cd74a824a
commit 54ed7c3200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 105 additions and 118 deletions

View File

@ -14,22 +14,22 @@
* limitations under the License.
*/
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { Context, packageJSON } from './context.js';
import { Context } from './context.js';
import { snapshotTools, visionTools } from './tools.js';
import { packageJSON } from './package.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { FullConfig } from './config.js';
export async function createConnection(config: FullConfig): Promise<Connection> {
export function createConnection(config: FullConfig): Connection {
const allTools = config.vision ? visionTools : snapshotTools;
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
const context = new Context(tools, config);
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, {
capabilities: {
tools: {},
}
@ -74,25 +74,19 @@ export async function createConnection(config: FullConfig): Promise<Connection>
}
});
const connection = new Connection(server, context);
return connection;
return new Connection(server, context);
}
export class Connection {
readonly server: Server;
readonly server: McpServer;
readonly context: Context;
constructor(server: Server, context: Context) {
constructor(server: McpServer, context: Context) {
this.server = server;
this.context = context;
}
async connect(transport: Transport) {
await this.server.connect(transport);
await new Promise<void>(resolve => {
this.server.oninitialized = () => resolve();
});
this.context.clientVersion = this.server.getClientVersion();
this.server.oninitialized = () => {
this.context.clientVersion = this.server.getClientVersion();
};
}
async close() {

View File

@ -15,7 +15,6 @@
*/
import fs from 'node:fs';
import url from 'node:url';
import os from 'node:os';
import path from 'node:path';
@ -416,6 +415,3 @@ async function createUserDataDir(browserConfig: FullConfig['browser']) {
await fs.promises.mkdir(result, { recursive: true });
return result;
}
const __filename = url.fileURLToPath(import.meta.url);
export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8'));

22
src/package.ts Normal file
View File

@ -0,0 +1,22 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import fs from 'node:fs';
import url from 'node:url';
import path from 'node:path';
const __filename = url.fileURLToPath(import.meta.url);
export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8'));

View File

@ -15,14 +15,13 @@
*/
import { program } from 'commander';
import { startHttpTransport, startStdioTransport } from './transport.js';
import { resolveCLIConfig } from './config.js';
// @ts-ignore
import { startTraceViewerServer } from 'playwright-core/lib/server';
import type { Connection } from './connection.js';
import { packageJSON } from './context.js';
import { startHttpTransport, startStdioTransport } from './transport.js';
import { resolveCLIConfig } from './config.js';
import { Server } from './server.js';
import { packageJSON } from './package.js';
program
.version('Version ' + packageJSON.version)
@ -54,13 +53,13 @@ program
.option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)')
.action(async options => {
const config = await resolveCLIConfig(options);
const connectionList: Connection[] = [];
setupExitWatchdog(connectionList);
const server = new Server(config);
server.setupExitWatchdog();
if (options.port)
startHttpTransport(config, +options.port, options.host, connectionList);
startHttpTransport(server, +options.port, options.host);
else
await startStdioTransport(config, connectionList);
await startStdioTransport(server);
if (config.saveTrace) {
const server = await startTraceViewerServer();
@ -71,21 +70,8 @@ program
}
});
function setupExitWatchdog(connectionList: Connection[]) {
const handleExit = async () => {
setTimeout(() => process.exit(0), 15000);
for (const connection of connectionList)
await connection.close();
process.exit(0);
};
process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);
}
function semicolonSeparatedList(value: string): string[] {
return value.split(';').map(v => v.trim());
}
program.parse(process.argv);
void program.parseAsync(process.argv);

49
src/server.ts Normal file
View File

@ -0,0 +1,49 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { createConnection } from './connection.js';
import type { FullConfig } from './config.js';
import type { Connection } from './connection.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export class Server {
readonly config: FullConfig;
private _connectionList: Connection[] = [];
constructor(config: FullConfig) {
this.config = config;
}
async createConnection(transport: Transport): Promise<Connection> {
const connection = createConnection(this.config);
this._connectionList.push(connection);
await connection.server.connect(transport);
return connection;
}
setupExitWatchdog() {
const handleExit = async () => {
setTimeout(() => process.exit(0), 15000);
await Promise.all(this._connectionList.map(connection => connection.close()));
process.exit(0);
};
process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);
}
}

View File

@ -22,18 +22,13 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { createConnection } from './connection.js';
import type { Server } from './server.js';
import type { Connection } from './connection.js';
import type { FullConfig } from './config.js';
export async function startStdioTransport(config: FullConfig, connectionList: Connection[]) {
const connection = await createConnection(config);
await connection.connect(new StdioServerTransport());
connectionList.push(connection);
export async function startStdioTransport(server: Server) {
await server.createConnection(new StdioServerTransport());
}
async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>, connectionList: Connection[]) {
async function handleSSE(server: Server, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>) {
if (req.method === 'POST') {
const sessionId = url.searchParams.get('sessionId');
if (!sessionId) {
@ -51,15 +46,11 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt
} else if (req.method === 'GET') {
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
const connection = await createConnection(config);
await connection.connect(transport);
connectionList.push(connection);
const connection = await server.createConnection(transport);
res.on('close', () => {
sessions.delete(transport.sessionId);
connection.close().catch(e => {
// eslint-disable-next-line no-console
console.error(e);
});
// eslint-disable-next-line no-console
void connection.close().catch(e => console.error(e));
});
return;
}
@ -68,7 +59,7 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt
res.end('Method not allowed');
}
async function handleStreamable(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>, connectionList: Connection[]) {
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) {
const transport = sessions.get(sessionId);
@ -91,12 +82,8 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r
if (transport.sessionId)
sessions.delete(transport.sessionId);
};
const connection = await createConnection(config);
connectionList.push(connection);
await Promise.all([
connection.connect(transport),
transport.handleRequest(req, res),
]);
await server.createConnection(transport);
await transport.handleRequest(req, res);
return;
}
@ -104,15 +91,15 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r
res.end('Invalid request');
}
export function startHttpTransport(config: FullConfig, port: number, hostname: string | undefined, connectionList: Connection[]) {
export function startHttpTransport(server: Server, port: number, hostname: string | undefined) {
const sseSessions = new Map<string, SSEServerTransport>();
const streamableSessions = new Map<string, StreamableHTTPServerTransport>();
const httpServer = http.createServer(async (req, res) => {
const url = new URL(`http://localhost${req.url}`);
if (url.pathname.startsWith('/mcp'))
await handleStreamable(config, req, res, streamableSessions, connectionList);
await handleStreamable(server, req, res, streamableSessions);
else
await handleSSE(config, req, res, url, sseSessions, connectionList);
await handleSSE(server, req, res, url, sseSessions);
});
httpServer.listen(port, hostname, () => {
const address = httpServer.address();

View File

@ -15,17 +15,12 @@
*/
import url from 'node:url';
import http from 'node:http';
import { spawn } from 'node:child_process';
import path from 'node:path';
import type { AddressInfo } from 'node:net';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { createConnection } from '@playwright/mcp';
import { test as baseTest, expect } from './fixtures.js';
// NOTE: Can be removed when we drop Node.js 18 support and changed to import.meta.filename.
@ -55,6 +50,7 @@ test('sse transport', async ({ serverEndpoint }) => {
const client = new Client({ name: 'test', version: '1.0.0' });
await client.connect(transport);
await client.ping();
await client.close();
});
test('streamable http transport', async ({ serverEndpoint }) => {
@ -64,46 +60,3 @@ test('streamable http transport', async ({ serverEndpoint }) => {
await client.ping();
expect(transport.sessionId, 'has session support').toBeDefined();
});
test('sse transport via public API', async ({ server }, testInfo) => {
const userDataDir = testInfo.outputPath('user-data-dir');
const sessions = new Map<string, SSEServerTransport>();
const mcpServer = http.createServer(async (req, res) => {
if (req.method === 'GET') {
const connection = await createConnection({
browser: {
userDataDir,
launchOptions: { headless: true }
},
});
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
await connection.connect(transport);
} else if (req.method === 'POST') {
const url = new URL(`http://localhost${req.url}`);
const sessionId = url.searchParams.get('sessionId');
if (!sessionId) {
res.statusCode = 400;
return res.end('Missing sessionId');
}
const transport = sessions.get(sessionId);
if (!transport) {
res.statusCode = 404;
return res.end('Session not found');
}
void transport.handlePostMessage(req, res);
}
});
await new Promise<void>(resolve => mcpServer.listen(0, () => resolve()));
const serverUrl = `http://localhost:${(mcpServer.address() as AddressInfo).port}/sse`;
const transport = new SSEClientTransport(new URL(serverUrl));
const client = new Client({ name: 'test', version: '1.0.0' });
await client.connect(transport);
await client.ping();
expect(await client.callTool({
name: 'browser_navigate',
arguments: { url: server.HELLO_WORLD },
})).toContainTextContent(`- generic [ref=e1]: Hello, world!`);
await client.close();
mcpServer.close();
});