From bc120baa78dc42bbabb132d916ce80d77a0d7f6e Mon Sep 17 00:00:00 2001 From: Yury Semikhatsky Date: Wed, 23 Jul 2025 17:41:15 -0700 Subject: [PATCH] chore: do not double close connection (#744) --- index.d.ts | 7 +------ src/connection.ts | 35 +++++++++++++---------------------- src/context.ts | 26 ++++++++++++++++++++++++-- src/extension/cdpRelay.ts | 4 +++- src/index.ts | 8 ++++---- src/server.ts | 15 ++++++--------- src/tools/common.ts | 2 +- src/transport.ts | 22 ++++++++-------------- 8 files changed, 60 insertions(+), 59 deletions(-) diff --git a/index.d.ts b/index.d.ts index 9ea8bcc..f7a1dde 100644 --- a/index.d.ts +++ b/index.d.ts @@ -19,10 +19,5 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { Config } from './config.js'; import type { BrowserContext } from 'playwright'; -export type Connection = { - server: Server; - close(): Promise; -}; - -export declare function createConnection(config?: Config, contextGetter?: () => Promise): Promise; +export declare function createConnection(config?: Config, contextGetter?: () => Promise): Promise; export {}; diff --git a/src/connection.ts b/src/connection.ts index 69fd434..d318bba 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js'; +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js'; import { zodToJsonSchema } from 'zod-to-json-schema'; import { Context } from './context.js'; @@ -23,12 +23,13 @@ import { allTools } from './tools.js'; import { packageJSON } from './package.js'; import { FullConfig } from './config.js'; import { SessionLog } from './sessionLog.js'; +import { logUnhandledError } from './log.js'; import type { BrowserContextFactory } from './browserContextFactory.js'; -export async function createConnection(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise { +export async function createMCPServer(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise { const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability)); const context = new Context(tools, config, browserContextFactory); - const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, { + const server = new Server({ name: 'Playwright', version: packageJSON.version }, { capabilities: { tools: {}, } @@ -72,23 +73,13 @@ export async function createConnection(config: FullConfig, browserContextFactory } }); - return new Connection(server, context); -} - -export class Connection { - readonly server: McpServer; - readonly context: Context; - - constructor(server: McpServer, context: Context) { - this.server = server; - this.context = context; - this.server.oninitialized = () => { - this.context.clientVersion = this.server.getClientVersion(); - }; - } - - async close() { - await this.server.close(); - await this.context.close(); - } + server.oninitialized = () => { + context.clientVersion = server.getClientVersion(); + }; + + server.onclose = () => { + void context.dispose().catch(logUnhandledError); + }; + + return server; } diff --git a/src/context.ts b/src/context.ts index d18faa0..c8377e2 100644 --- a/src/context.ts +++ b/src/context.ts @@ -34,11 +34,19 @@ export class Context { private _currentTab: Tab | undefined; clientVersion: { name: string; version: string; } | undefined; + private static _allContexts: Set = new Set(); + private _closeBrowserContextPromise: Promise | undefined; + constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) { this.tools = tools; this.config = config; this._browserContextFactory = browserContextFactory; testDebug('create context'); + Context._allContexts.add(this); + } + + static async disposeAll() { + await Promise.all([...Context._allContexts].map(context => context.dispose())); } tabs(): Tab[] { @@ -127,10 +135,17 @@ export class Context { if (this._currentTab === tab) this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)]; if (!this._tabs.length) - void this.close(); + void this.closeBrowserContext(); } - async close() { + async closeBrowserContext() { + if (!this._closeBrowserContextPromise) + this._closeBrowserContextPromise = this._closeBrowserContextImpl(); + await this._closeBrowserContextPromise; + this._closeBrowserContextPromise = undefined; + } + + private async _closeBrowserContextImpl() { if (!this._browserContextPromise) return; @@ -146,6 +161,11 @@ export class Context { }); } + async dispose() { + await this.closeBrowserContext(); + Context._allContexts.delete(this); + } + private async _setupRequestInterception(context: playwright.BrowserContext) { if (this.config.network?.allowedOrigins?.length) { await context.route('**', route => route.abort('blockedbyclient')); @@ -171,6 +191,8 @@ export class Context { } private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise }> { + if (this._closeBrowserContextPromise) + throw new Error('Another browser context is being closed.'); // TODO: move to the browser context factory to make it based on isolation mode. const result = await this._browserContextFactory.createContext(this.clientVersion!); const { browserContext } = result; diff --git a/src/extension/cdpRelay.ts b/src/extension/cdpRelay.ts index c6b9bbb..5366167 100644 --- a/src/extension/cdpRelay.ts +++ b/src/extension/cdpRelay.ts @@ -307,7 +307,9 @@ class ExtensionContextFactory implements BrowserContextFactory { const browser = await this._browserPromise; return { browserContext: browser.contexts()[0], - close: async () => {} + close: async () => { + debugLogger('close() called for browser context, ignoring'); + } }; } diff --git a/src/index.ts b/src/index.ts index 417492d..1751657 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,18 +14,18 @@ * limitations under the License. */ -import { createConnection as createConnectionImpl } from './connection.js'; +import { createMCPServer } from './connection.js'; import { resolveConfig } from './config.js'; import { contextFactory } from './browserContextFactory.js'; -import type { Connection } from '../index.js'; import type { Config } from '../config.js'; import type { BrowserContext } from 'playwright'; import type { BrowserContextFactory } from './browserContextFactory.js'; +import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; -export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise): Promise { +export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise): Promise { const config = await resolveConfig(userConfig); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser); - return createConnectionImpl(config, factory); + return createMCPServer(config, factory); } class SimpleBrowserContextFactory implements BrowserContextFactory { diff --git a/src/server.ts b/src/server.ts index 4c20154..e1927a4 100644 --- a/src/server.ts +++ b/src/server.ts @@ -14,17 +14,16 @@ * limitations under the License. */ -import { createConnection } from './connection.js'; +import { createMCPServer } from './connection.js'; +import { Context } from './context.js'; import { contextFactory as defaultContextFactory } from './browserContextFactory.js'; import type { FullConfig } from './config.js'; -import type { Connection } from './connection.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { BrowserContextFactory } from './browserContextFactory.js'; export class Server { readonly config: FullConfig; - private _connectionList: Connection[] = []; private _browserConfig: FullConfig['browser']; private _contextFactory: BrowserContextFactory; @@ -34,11 +33,9 @@ export class Server { this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig); } - async createConnection(transport: Transport): Promise { - const connection = await createConnection(this.config, this._contextFactory); - this._connectionList.push(connection); - await connection.server.connect(transport); - return connection; + async createConnection(transport: Transport): Promise { + const server = await createMCPServer(this.config, this._contextFactory); + await server.connect(transport); } setupExitWatchdog() { @@ -48,7 +45,7 @@ export class Server { return; isExiting = true; setTimeout(() => process.exit(0), 15000); - await Promise.all(this._connectionList.map(connection => connection.close())); + await Context.disposeAll(); process.exit(0); }; diff --git a/src/tools/common.ts b/src/tools/common.ts index ba9847a..337f4ba 100644 --- a/src/tools/common.ts +++ b/src/tools/common.ts @@ -29,7 +29,7 @@ const close = defineTool({ }, handle: async (context, params, response) => { - await context.close(); + await context.closeBrowserContext(); response.setIncludeTabs(); response.addCode(`await page.close()`); }, diff --git a/src/transport.ts b/src/transport.ts index 48bcd9a..c34ce39 100644 --- a/src/transport.ts +++ b/src/transport.ts @@ -23,11 +23,8 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; -import { logUnhandledError } from './log.js'; - import type { AddressInfo } from 'node:net'; import type { Server } from './server.js'; -import type { Connection } from './connection.js'; export async function startStdioTransport(server: Server) { await server.createConnection(new StdioServerTransport()); @@ -54,11 +51,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se const transport = new SSEServerTransport('/sse', res); sessions.set(transport.sessionId, transport); testDebug(`create SSE session: ${transport.sessionId}`); - const connection = await server.createConnection(transport); + await server.createConnection(transport); res.on('close', () => { testDebug(`delete SSE session: ${transport.sessionId}`); sessions.delete(transport.sessionId); - void connection.close().catch(logUnhandledError); }); return; } @@ -67,10 +63,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se res.end('Method not allowed'); } -async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map) { +async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map) { const sessionId = req.headers['mcp-session-id'] as string | undefined; if (sessionId) { - const { transport } = sessions.get(sessionId) ?? {}; + const transport = sessions.get(sessionId); if (!transport) { res.statusCode = 404; res.end('Session not found'); @@ -84,18 +80,16 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res: sessionIdGenerator: () => crypto.randomUUID(), onsessioninitialized: async sessionId => { testDebug(`create http session: ${transport.sessionId}`); - const connection = await server.createConnection(transport); - sessions.set(sessionId, { transport, connection }); + await server.createConnection(transport); + sessions.set(sessionId, transport); } }); transport.onclose = () => { - const result = transport.sessionId ? sessions.get(transport.sessionId) : undefined; - if (!result) + if (!transport.sessionId) return; - sessions.delete(result.transport.sessionId!); + sessions.delete(transport.sessionId); testDebug(`delete http session: ${transport.sessionId}`); - result.connection.close().catch(logUnhandledError); }; await transport.handleRequest(req, res); @@ -120,7 +114,7 @@ export async function startHttpServer(config: { host?: string, port?: number }): } export function startHttpTransport(httpServer: http.Server, mcpServer: Server) { - const sseSessions = new Map(); + const sseSessions = new Map(); const streamableSessions = new Map(); httpServer.on('request', async (req, res) => { const url = new URL(`http://localhost${req.url}`);