chore: ping client and disconnect on connection termination (#764)

This commit is contained in:
Pavel Feldman 2025-07-25 12:17:51 -07:00 committed by GitHub
parent 26a2a6fc83
commit a9b9fb85da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 13 deletions

View File

@ -27,7 +27,7 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> { export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
const config = await resolveConfig(userConfig); const config = await resolveConfig(userConfig);
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
return mcpServer.createServer(new BrowserServerBackend(config, factory)); return mcpServer.createServer(new BrowserServerBackend(config, factory), false);
} }
class SimpleBrowserContextFactory implements BrowserContextFactory { class SimpleBrowserContextFactory implements BrowserContextFactory {

View File

@ -46,7 +46,7 @@ export class Context {
static async create(config: FullConfig) { static async create(config: FullConfig) {
const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' }); const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' });
const browserContextFactory = contextFactory(config.browser); const browserContextFactory = contextFactory(config.browser);
const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory)); const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory), false);
await client.connect(new InProcessTransport(server)); await client.connect(new InProcessTransport(server));
await client.ping(); await client.ping();
return new Context(config, client); return new Context(config, client);

View File

@ -51,14 +51,14 @@ export interface ServerBackend {
export type ServerBackendFactory = () => ServerBackend; export type ServerBackendFactory = () => ServerBackend;
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport) { export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
const backend = serverBackendFactory(); const backend = serverBackendFactory();
await backend.initialize?.(); await backend.initialize?.();
const server = createServer(backend); const server = createServer(backend, runHeartbeat);
await server.connect(transport); await server.connect(transport);
} }
export function createServer(backend: ServerBackend): Server { export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
const server = new Server({ name: backend.name, version: backend.version }, { const server = new Server({ name: backend.name, version: backend.version }, {
capabilities: { capabilities: {
tools: {}, tools: {},
@ -80,7 +80,13 @@ export function createServer(backend: ServerBackend): Server {
})) }; })) };
}); });
let heartbeatRunning = false;
server.setRequestHandler(CallToolRequestSchema, async request => { server.setRequestHandler(CallToolRequestSchema, async request => {
if (runHeartbeat && !heartbeatRunning) {
heartbeatRunning = true;
startHeartbeat(server);
}
const errorResult = (...messages: string[]) => ({ const errorResult = (...messages: string[]) => ({
content: [{ type: 'text', text: messages.join('\n') }], content: [{ type: 'text', text: messages.join('\n') }],
isError: true, isError: true,
@ -96,10 +102,30 @@ export function createServer(backend: ServerBackend): Server {
} }
}); });
if (backend.serverInitialized) addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion()));
server.oninitialized = () => backend.serverInitialized!(server.getClientVersion()); addServerListener(server, 'close', () => backend.serverClosed?.());
if (backend.serverClosed)
server.onclose = () => backend.serverClosed!();
return server; return server;
} }
const startHeartbeat = (server: Server) => {
const beat = () => {
Promise.race([
server.ping(),
new Promise((_, reject) => setTimeout(() => reject(new Error('ping timeout')), 5000)),
]).then(() => {
setTimeout(beat, 3000);
}).catch(() => {
void server.close();
});
};
beat();
};
function addServerListener(server: Server, event: 'close' | 'initialized', listener: () => void) {
const oldListener = server[`on${event}`];
server[`on${event}`] = () => {
oldListener?.();
listener();
};
}

View File

@ -36,7 +36,7 @@ export async function start(serverBackendFactory: ServerBackendFactory, options:
} }
async function startStdioTransport(serverBackendFactory: ServerBackendFactory) { async function startStdioTransport(serverBackendFactory: ServerBackendFactory) {
await mcpServer.connect(serverBackendFactory, new StdioServerTransport()); await mcpServer.connect(serverBackendFactory, new StdioServerTransport(), false);
} }
const testDebug = debug('pw:mcp:test'); const testDebug = debug('pw:mcp:test');
@ -60,7 +60,7 @@ async function handleSSE(serverBackendFactory: ServerBackendFactory, req: http.I
const transport = new SSEServerTransport('/sse', res); const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport); sessions.set(transport.sessionId, transport);
testDebug(`create SSE session: ${transport.sessionId}`); testDebug(`create SSE session: ${transport.sessionId}`);
await mcpServer.connect(serverBackendFactory, transport); await mcpServer.connect(serverBackendFactory, transport, false);
res.on('close', () => { res.on('close', () => {
testDebug(`delete SSE session: ${transport.sessionId}`); testDebug(`delete SSE session: ${transport.sessionId}`);
sessions.delete(transport.sessionId); sessions.delete(transport.sessionId);
@ -89,7 +89,7 @@ async function handleStreamable(serverBackendFactory: ServerBackendFactory, req:
sessionIdGenerator: () => crypto.randomUUID(), sessionIdGenerator: () => crypto.randomUUID(),
onsessioninitialized: async sessionId => { onsessioninitialized: async sessionId => {
testDebug(`create http session: ${transport.sessionId}`); testDebug(`create http session: ${transport.sessionId}`);
await mcpServer.connect(serverBackendFactory, transport); await mcpServer.connect(serverBackendFactory, transport, true);
sessions.set(sessionId, transport); sessions.set(sessionId, transport);
} }
}); });