diff --git a/src/index.ts b/src/index.ts index e5d02b6..88c95f8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -27,7 +27,7 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise): Promise { const config = await resolveConfig(userConfig); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser); - return mcpServer.createServer(new BrowserServerBackend(config, factory)); + return mcpServer.createServer(new BrowserServerBackend(config, factory), false); } class SimpleBrowserContextFactory implements BrowserContextFactory { diff --git a/src/loopTools/context.ts b/src/loopTools/context.ts index 59d2b64..b1b4709 100644 --- a/src/loopTools/context.ts +++ b/src/loopTools/context.ts @@ -46,7 +46,7 @@ export class Context { static async create(config: FullConfig) { const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' }); const browserContextFactory = contextFactory(config.browser); - const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory)); + const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory), false); await client.connect(new InProcessTransport(server)); await client.ping(); return new Context(config, client); diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 57ba3c9..1d4e2ba 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -51,14 +51,14 @@ export interface ServerBackend { export type ServerBackendFactory = () => ServerBackend; -export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport) { +export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) { const backend = serverBackendFactory(); await backend.initialize?.(); - const server = createServer(backend); + const server = createServer(backend, runHeartbeat); await server.connect(transport); } -export function createServer(backend: ServerBackend): Server { +export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server { const server = new Server({ name: backend.name, version: backend.version }, { capabilities: { tools: {}, @@ -80,7 +80,13 @@ export function createServer(backend: ServerBackend): Server { })) }; }); + let heartbeatRunning = false; server.setRequestHandler(CallToolRequestSchema, async request => { + if (runHeartbeat && !heartbeatRunning) { + heartbeatRunning = true; + startHeartbeat(server); + } + const errorResult = (...messages: string[]) => ({ content: [{ type: 'text', text: messages.join('\n') }], isError: true, @@ -96,10 +102,30 @@ export function createServer(backend: ServerBackend): Server { } }); - if (backend.serverInitialized) - server.oninitialized = () => backend.serverInitialized!(server.getClientVersion()); - if (backend.serverClosed) - server.onclose = () => backend.serverClosed!(); - + addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion())); + addServerListener(server, 'close', () => backend.serverClosed?.()); return server; } + +const startHeartbeat = (server: Server) => { + const beat = () => { + Promise.race([ + server.ping(), + new Promise((_, reject) => setTimeout(() => reject(new Error('ping timeout')), 5000)), + ]).then(() => { + setTimeout(beat, 3000); + }).catch(() => { + void server.close(); + }); + }; + + beat(); +}; + +function addServerListener(server: Server, event: 'close' | 'initialized', listener: () => void) { + const oldListener = server[`on${event}`]; + server[`on${event}`] = () => { + oldListener?.(); + listener(); + }; +} diff --git a/src/mcp/transport.ts b/src/mcp/transport.ts index 6804294..e85a66f 100644 --- a/src/mcp/transport.ts +++ b/src/mcp/transport.ts @@ -36,7 +36,7 @@ export async function start(serverBackendFactory: ServerBackendFactory, options: } async function startStdioTransport(serverBackendFactory: ServerBackendFactory) { - await mcpServer.connect(serverBackendFactory, new StdioServerTransport()); + await mcpServer.connect(serverBackendFactory, new StdioServerTransport(), false); } const testDebug = debug('pw:mcp:test'); @@ -60,7 +60,7 @@ async function handleSSE(serverBackendFactory: ServerBackendFactory, req: http.I const transport = new SSEServerTransport('/sse', res); sessions.set(transport.sessionId, transport); testDebug(`create SSE session: ${transport.sessionId}`); - await mcpServer.connect(serverBackendFactory, transport); + await mcpServer.connect(serverBackendFactory, transport, false); res.on('close', () => { testDebug(`delete SSE session: ${transport.sessionId}`); sessions.delete(transport.sessionId); @@ -89,7 +89,7 @@ async function handleStreamable(serverBackendFactory: ServerBackendFactory, req: sessionIdGenerator: () => crypto.randomUUID(), onsessioninitialized: async sessionId => { testDebug(`create http session: ${transport.sessionId}`); - await mcpServer.connect(serverBackendFactory, transport); + await mcpServer.connect(serverBackendFactory, transport, true); sessions.set(sessionId, transport); } });