mirror of
https://github.com/microsoft/playwright-mcp.git
synced 2025-07-27 09:02:26 +08:00
chore: ping client and disconnect on connection termination (#764)
This commit is contained in:
parent
26a2a6fc83
commit
a9b9fb85da
@ -27,7 +27,7 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
|||||||
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
|
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
|
||||||
const config = await resolveConfig(userConfig);
|
const config = await resolveConfig(userConfig);
|
||||||
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
|
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
|
||||||
return mcpServer.createServer(new BrowserServerBackend(config, factory));
|
return mcpServer.createServer(new BrowserServerBackend(config, factory), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
class SimpleBrowserContextFactory implements BrowserContextFactory {
|
class SimpleBrowserContextFactory implements BrowserContextFactory {
|
||||||
|
@ -46,7 +46,7 @@ export class Context {
|
|||||||
static async create(config: FullConfig) {
|
static async create(config: FullConfig) {
|
||||||
const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' });
|
const client = new Client({ name: 'Playwright Proxy', version: '1.0.0' });
|
||||||
const browserContextFactory = contextFactory(config.browser);
|
const browserContextFactory = contextFactory(config.browser);
|
||||||
const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory));
|
const server = mcpServer.createServer(new BrowserServerBackend(config, browserContextFactory), false);
|
||||||
await client.connect(new InProcessTransport(server));
|
await client.connect(new InProcessTransport(server));
|
||||||
await client.ping();
|
await client.ping();
|
||||||
return new Context(config, client);
|
return new Context(config, client);
|
||||||
|
@ -51,14 +51,14 @@ export interface ServerBackend {
|
|||||||
|
|
||||||
export type ServerBackendFactory = () => ServerBackend;
|
export type ServerBackendFactory = () => ServerBackend;
|
||||||
|
|
||||||
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport) {
|
export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
|
||||||
const backend = serverBackendFactory();
|
const backend = serverBackendFactory();
|
||||||
await backend.initialize?.();
|
await backend.initialize?.();
|
||||||
const server = createServer(backend);
|
const server = createServer(backend, runHeartbeat);
|
||||||
await server.connect(transport);
|
await server.connect(transport);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createServer(backend: ServerBackend): Server {
|
export function createServer(backend: ServerBackend, runHeartbeat: boolean): Server {
|
||||||
const server = new Server({ name: backend.name, version: backend.version }, {
|
const server = new Server({ name: backend.name, version: backend.version }, {
|
||||||
capabilities: {
|
capabilities: {
|
||||||
tools: {},
|
tools: {},
|
||||||
@ -80,7 +80,13 @@ export function createServer(backend: ServerBackend): Server {
|
|||||||
})) };
|
})) };
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let heartbeatRunning = false;
|
||||||
server.setRequestHandler(CallToolRequestSchema, async request => {
|
server.setRequestHandler(CallToolRequestSchema, async request => {
|
||||||
|
if (runHeartbeat && !heartbeatRunning) {
|
||||||
|
heartbeatRunning = true;
|
||||||
|
startHeartbeat(server);
|
||||||
|
}
|
||||||
|
|
||||||
const errorResult = (...messages: string[]) => ({
|
const errorResult = (...messages: string[]) => ({
|
||||||
content: [{ type: 'text', text: messages.join('\n') }],
|
content: [{ type: 'text', text: messages.join('\n') }],
|
||||||
isError: true,
|
isError: true,
|
||||||
@ -96,10 +102,30 @@ export function createServer(backend: ServerBackend): Server {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if (backend.serverInitialized)
|
addServerListener(server, 'initialized', () => backend.serverInitialized?.(server.getClientVersion()));
|
||||||
server.oninitialized = () => backend.serverInitialized!(server.getClientVersion());
|
addServerListener(server, 'close', () => backend.serverClosed?.());
|
||||||
if (backend.serverClosed)
|
|
||||||
server.onclose = () => backend.serverClosed!();
|
|
||||||
|
|
||||||
return server;
|
return server;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const startHeartbeat = (server: Server) => {
|
||||||
|
const beat = () => {
|
||||||
|
Promise.race([
|
||||||
|
server.ping(),
|
||||||
|
new Promise((_, reject) => setTimeout(() => reject(new Error('ping timeout')), 5000)),
|
||||||
|
]).then(() => {
|
||||||
|
setTimeout(beat, 3000);
|
||||||
|
}).catch(() => {
|
||||||
|
void server.close();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
beat();
|
||||||
|
};
|
||||||
|
|
||||||
|
function addServerListener(server: Server, event: 'close' | 'initialized', listener: () => void) {
|
||||||
|
const oldListener = server[`on${event}`];
|
||||||
|
server[`on${event}`] = () => {
|
||||||
|
oldListener?.();
|
||||||
|
listener();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
@ -36,7 +36,7 @@ export async function start(serverBackendFactory: ServerBackendFactory, options:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function startStdioTransport(serverBackendFactory: ServerBackendFactory) {
|
async function startStdioTransport(serverBackendFactory: ServerBackendFactory) {
|
||||||
await mcpServer.connect(serverBackendFactory, new StdioServerTransport());
|
await mcpServer.connect(serverBackendFactory, new StdioServerTransport(), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
const testDebug = debug('pw:mcp:test');
|
const testDebug = debug('pw:mcp:test');
|
||||||
@ -60,7 +60,7 @@ async function handleSSE(serverBackendFactory: ServerBackendFactory, req: http.I
|
|||||||
const transport = new SSEServerTransport('/sse', res);
|
const transport = new SSEServerTransport('/sse', res);
|
||||||
sessions.set(transport.sessionId, transport);
|
sessions.set(transport.sessionId, transport);
|
||||||
testDebug(`create SSE session: ${transport.sessionId}`);
|
testDebug(`create SSE session: ${transport.sessionId}`);
|
||||||
await mcpServer.connect(serverBackendFactory, transport);
|
await mcpServer.connect(serverBackendFactory, transport, false);
|
||||||
res.on('close', () => {
|
res.on('close', () => {
|
||||||
testDebug(`delete SSE session: ${transport.sessionId}`);
|
testDebug(`delete SSE session: ${transport.sessionId}`);
|
||||||
sessions.delete(transport.sessionId);
|
sessions.delete(transport.sessionId);
|
||||||
@ -89,7 +89,7 @@ async function handleStreamable(serverBackendFactory: ServerBackendFactory, req:
|
|||||||
sessionIdGenerator: () => crypto.randomUUID(),
|
sessionIdGenerator: () => crypto.randomUUID(),
|
||||||
onsessioninitialized: async sessionId => {
|
onsessioninitialized: async sessionId => {
|
||||||
testDebug(`create http session: ${transport.sessionId}`);
|
testDebug(`create http session: ${transport.sessionId}`);
|
||||||
await mcpServer.connect(serverBackendFactory, transport);
|
await mcpServer.connect(serverBackendFactory, transport, true);
|
||||||
sessions.set(sessionId, transport);
|
sessions.set(sessionId, transport);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
Loading…
x
Reference in New Issue
Block a user