chore: do not double close connection (#744)

This commit is contained in:
Yury Semikhatsky 2025-07-23 17:41:15 -07:00 committed by GitHub
parent 2c5eac89a8
commit bc120baa78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 60 additions and 59 deletions

7
index.d.ts vendored
View File

@ -19,10 +19,5 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { Config } from './config.js'; import type { Config } from './config.js';
import type { BrowserContext } from 'playwright'; import type { BrowserContext } from 'playwright';
export type Connection = { export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Server>;
server: Server;
close(): Promise<void>;
};
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Connection>;
export {}; export {};

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js'; import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js'; import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema'; import { zodToJsonSchema } from 'zod-to-json-schema';
import { Context } from './context.js'; import { Context } from './context.js';
@ -23,12 +23,13 @@ import { allTools } from './tools.js';
import { packageJSON } from './package.js'; import { packageJSON } from './package.js';
import { FullConfig } from './config.js'; import { FullConfig } from './config.js';
import { SessionLog } from './sessionLog.js'; import { SessionLog } from './sessionLog.js';
import { logUnhandledError } from './log.js';
import type { BrowserContextFactory } from './browserContextFactory.js'; import type { BrowserContextFactory } from './browserContextFactory.js';
export async function createConnection(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Connection> { export async function createMCPServer(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Server> {
const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability)); const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability));
const context = new Context(tools, config, browserContextFactory); const context = new Context(tools, config, browserContextFactory);
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, { const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
capabilities: { capabilities: {
tools: {}, tools: {},
} }
@ -72,23 +73,13 @@ export async function createConnection(config: FullConfig, browserContextFactory
} }
}); });
return new Connection(server, context); server.oninitialized = () => {
} context.clientVersion = server.getClientVersion();
};
export class Connection {
readonly server: McpServer; server.onclose = () => {
readonly context: Context; void context.dispose().catch(logUnhandledError);
};
constructor(server: McpServer, context: Context) {
this.server = server; return server;
this.context = context;
this.server.oninitialized = () => {
this.context.clientVersion = this.server.getClientVersion();
};
}
async close() {
await this.server.close();
await this.context.close();
}
} }

View File

@ -34,11 +34,19 @@ export class Context {
private _currentTab: Tab | undefined; private _currentTab: Tab | undefined;
clientVersion: { name: string; version: string; } | undefined; clientVersion: { name: string; version: string; } | undefined;
private static _allContexts: Set<Context> = new Set();
private _closeBrowserContextPromise: Promise<void> | undefined;
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) { constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) {
this.tools = tools; this.tools = tools;
this.config = config; this.config = config;
this._browserContextFactory = browserContextFactory; this._browserContextFactory = browserContextFactory;
testDebug('create context'); testDebug('create context');
Context._allContexts.add(this);
}
static async disposeAll() {
await Promise.all([...Context._allContexts].map(context => context.dispose()));
} }
tabs(): Tab[] { tabs(): Tab[] {
@ -127,10 +135,17 @@ export class Context {
if (this._currentTab === tab) if (this._currentTab === tab)
this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)]; this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)];
if (!this._tabs.length) if (!this._tabs.length)
void this.close(); void this.closeBrowserContext();
} }
async close() { async closeBrowserContext() {
if (!this._closeBrowserContextPromise)
this._closeBrowserContextPromise = this._closeBrowserContextImpl();
await this._closeBrowserContextPromise;
this._closeBrowserContextPromise = undefined;
}
private async _closeBrowserContextImpl() {
if (!this._browserContextPromise) if (!this._browserContextPromise)
return; return;
@ -146,6 +161,11 @@ export class Context {
}); });
} }
async dispose() {
await this.closeBrowserContext();
Context._allContexts.delete(this);
}
private async _setupRequestInterception(context: playwright.BrowserContext) { private async _setupRequestInterception(context: playwright.BrowserContext) {
if (this.config.network?.allowedOrigins?.length) { if (this.config.network?.allowedOrigins?.length) {
await context.route('**', route => route.abort('blockedbyclient')); await context.route('**', route => route.abort('blockedbyclient'));
@ -171,6 +191,8 @@ export class Context {
} }
private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> { private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
if (this._closeBrowserContextPromise)
throw new Error('Another browser context is being closed.');
// TODO: move to the browser context factory to make it based on isolation mode. // TODO: move to the browser context factory to make it based on isolation mode.
const result = await this._browserContextFactory.createContext(this.clientVersion!); const result = await this._browserContextFactory.createContext(this.clientVersion!);
const { browserContext } = result; const { browserContext } = result;

View File

@ -307,7 +307,9 @@ class ExtensionContextFactory implements BrowserContextFactory {
const browser = await this._browserPromise; const browser = await this._browserPromise;
return { return {
browserContext: browser.contexts()[0], browserContext: browser.contexts()[0],
close: async () => {} close: async () => {
debugLogger('close() called for browser context, ignoring');
}
}; };
} }

View File

@ -14,18 +14,18 @@
* limitations under the License. * limitations under the License.
*/ */
import { createConnection as createConnectionImpl } from './connection.js'; import { createMCPServer } from './connection.js';
import { resolveConfig } from './config.js'; import { resolveConfig } from './config.js';
import { contextFactory } from './browserContextFactory.js'; import { contextFactory } from './browserContextFactory.js';
import type { Connection } from '../index.js';
import type { Config } from '../config.js'; import type { Config } from '../config.js';
import type { BrowserContext } from 'playwright'; import type { BrowserContext } from 'playwright';
import type { BrowserContextFactory } from './browserContextFactory.js'; import type { BrowserContextFactory } from './browserContextFactory.js';
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Connection> { export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
const config = await resolveConfig(userConfig); const config = await resolveConfig(userConfig);
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser); const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
return createConnectionImpl(config, factory); return createMCPServer(config, factory);
} }
class SimpleBrowserContextFactory implements BrowserContextFactory { class SimpleBrowserContextFactory implements BrowserContextFactory {

View File

@ -14,17 +14,16 @@
* limitations under the License. * limitations under the License.
*/ */
import { createConnection } from './connection.js'; import { createMCPServer } from './connection.js';
import { Context } from './context.js';
import { contextFactory as defaultContextFactory } from './browserContextFactory.js'; import { contextFactory as defaultContextFactory } from './browserContextFactory.js';
import type { FullConfig } from './config.js'; import type { FullConfig } from './config.js';
import type { Connection } from './connection.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { BrowserContextFactory } from './browserContextFactory.js'; import type { BrowserContextFactory } from './browserContextFactory.js';
export class Server { export class Server {
readonly config: FullConfig; readonly config: FullConfig;
private _connectionList: Connection[] = [];
private _browserConfig: FullConfig['browser']; private _browserConfig: FullConfig['browser'];
private _contextFactory: BrowserContextFactory; private _contextFactory: BrowserContextFactory;
@ -34,11 +33,9 @@ export class Server {
this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig); this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig);
} }
async createConnection(transport: Transport): Promise<Connection> { async createConnection(transport: Transport): Promise<void> {
const connection = await createConnection(this.config, this._contextFactory); const server = await createMCPServer(this.config, this._contextFactory);
this._connectionList.push(connection); await server.connect(transport);
await connection.server.connect(transport);
return connection;
} }
setupExitWatchdog() { setupExitWatchdog() {
@ -48,7 +45,7 @@ export class Server {
return; return;
isExiting = true; isExiting = true;
setTimeout(() => process.exit(0), 15000); setTimeout(() => process.exit(0), 15000);
await Promise.all(this._connectionList.map(connection => connection.close())); await Context.disposeAll();
process.exit(0); process.exit(0);
}; };

View File

@ -29,7 +29,7 @@ const close = defineTool({
}, },
handle: async (context, params, response) => { handle: async (context, params, response) => {
await context.close(); await context.closeBrowserContext();
response.setIncludeTabs(); response.setIncludeTabs();
response.addCode(`await page.close()`); response.addCode(`await page.close()`);
}, },

View File

@ -23,11 +23,8 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { logUnhandledError } from './log.js';
import type { AddressInfo } from 'node:net'; import type { AddressInfo } from 'node:net';
import type { Server } from './server.js'; import type { Server } from './server.js';
import type { Connection } from './connection.js';
export async function startStdioTransport(server: Server) { export async function startStdioTransport(server: Server) {
await server.createConnection(new StdioServerTransport()); await server.createConnection(new StdioServerTransport());
@ -54,11 +51,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
const transport = new SSEServerTransport('/sse', res); const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport); sessions.set(transport.sessionId, transport);
testDebug(`create SSE session: ${transport.sessionId}`); testDebug(`create SSE session: ${transport.sessionId}`);
const connection = await server.createConnection(transport); await server.createConnection(transport);
res.on('close', () => { res.on('close', () => {
testDebug(`delete SSE session: ${transport.sessionId}`); testDebug(`delete SSE session: ${transport.sessionId}`);
sessions.delete(transport.sessionId); sessions.delete(transport.sessionId);
void connection.close().catch(logUnhandledError);
}); });
return; return;
} }
@ -67,10 +63,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
res.end('Method not allowed'); res.end('Method not allowed');
} }
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, { transport: StreamableHTTPServerTransport, connection: Connection }>) { async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
const sessionId = req.headers['mcp-session-id'] as string | undefined; const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) { if (sessionId) {
const { transport } = sessions.get(sessionId) ?? {}; const transport = sessions.get(sessionId);
if (!transport) { if (!transport) {
res.statusCode = 404; res.statusCode = 404;
res.end('Session not found'); res.end('Session not found');
@ -84,18 +80,16 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res:
sessionIdGenerator: () => crypto.randomUUID(), sessionIdGenerator: () => crypto.randomUUID(),
onsessioninitialized: async sessionId => { onsessioninitialized: async sessionId => {
testDebug(`create http session: ${transport.sessionId}`); testDebug(`create http session: ${transport.sessionId}`);
const connection = await server.createConnection(transport); await server.createConnection(transport);
sessions.set(sessionId, { transport, connection }); sessions.set(sessionId, transport);
} }
}); });
transport.onclose = () => { transport.onclose = () => {
const result = transport.sessionId ? sessions.get(transport.sessionId) : undefined; if (!transport.sessionId)
if (!result)
return; return;
sessions.delete(result.transport.sessionId!); sessions.delete(transport.sessionId);
testDebug(`delete http session: ${transport.sessionId}`); testDebug(`delete http session: ${transport.sessionId}`);
result.connection.close().catch(logUnhandledError);
}; };
await transport.handleRequest(req, res); await transport.handleRequest(req, res);
@ -120,7 +114,7 @@ export async function startHttpServer(config: { host?: string, port?: number }):
} }
export function startHttpTransport(httpServer: http.Server, mcpServer: Server) { export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
const sseSessions = new Map<string, SSEServerTransport>(); const sseSessions = new Map();
const streamableSessions = new Map(); const streamableSessions = new Map();
httpServer.on('request', async (req, res) => { httpServer.on('request', async (req, res) => {
const url = new URL(`http://localhost${req.url}`); const url = new URL(`http://localhost${req.url}`);