chore: do not double close connection (#744)

This commit is contained in:
Yury Semikhatsky 2025-07-23 17:41:15 -07:00 committed by GitHub
parent 2c5eac89a8
commit bc120baa78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 60 additions and 59 deletions

7
index.d.ts vendored
View File

@ -19,10 +19,5 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { Config } from './config.js';
import type { BrowserContext } from 'playwright';
export type Connection = {
server: Server;
close(): Promise<void>;
};
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Connection>;
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Server>;
export {};

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { Context } from './context.js';
@ -23,12 +23,13 @@ import { allTools } from './tools.js';
import { packageJSON } from './package.js';
import { FullConfig } from './config.js';
import { SessionLog } from './sessionLog.js';
import { logUnhandledError } from './log.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
export async function createConnection(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Connection> {
export async function createMCPServer(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Server> {
const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability));
const context = new Context(tools, config, browserContextFactory);
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, {
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
capabilities: {
tools: {},
}
@ -72,23 +73,13 @@ export async function createConnection(config: FullConfig, browserContextFactory
}
});
return new Connection(server, context);
}
export class Connection {
readonly server: McpServer;
readonly context: Context;
constructor(server: McpServer, context: Context) {
this.server = server;
this.context = context;
this.server.oninitialized = () => {
this.context.clientVersion = this.server.getClientVersion();
};
}
async close() {
await this.server.close();
await this.context.close();
}
server.oninitialized = () => {
context.clientVersion = server.getClientVersion();
};
server.onclose = () => {
void context.dispose().catch(logUnhandledError);
};
return server;
}

View File

@ -34,11 +34,19 @@ export class Context {
private _currentTab: Tab | undefined;
clientVersion: { name: string; version: string; } | undefined;
private static _allContexts: Set<Context> = new Set();
private _closeBrowserContextPromise: Promise<void> | undefined;
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) {
this.tools = tools;
this.config = config;
this._browserContextFactory = browserContextFactory;
testDebug('create context');
Context._allContexts.add(this);
}
static async disposeAll() {
await Promise.all([...Context._allContexts].map(context => context.dispose()));
}
tabs(): Tab[] {
@ -127,10 +135,17 @@ export class Context {
if (this._currentTab === tab)
this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)];
if (!this._tabs.length)
void this.close();
void this.closeBrowserContext();
}
async close() {
async closeBrowserContext() {
if (!this._closeBrowserContextPromise)
this._closeBrowserContextPromise = this._closeBrowserContextImpl();
await this._closeBrowserContextPromise;
this._closeBrowserContextPromise = undefined;
}
private async _closeBrowserContextImpl() {
if (!this._browserContextPromise)
return;
@ -146,6 +161,11 @@ export class Context {
});
}
async dispose() {
await this.closeBrowserContext();
Context._allContexts.delete(this);
}
private async _setupRequestInterception(context: playwright.BrowserContext) {
if (this.config.network?.allowedOrigins?.length) {
await context.route('**', route => route.abort('blockedbyclient'));
@ -171,6 +191,8 @@ export class Context {
}
private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
if (this._closeBrowserContextPromise)
throw new Error('Another browser context is being closed.');
// TODO: move to the browser context factory to make it based on isolation mode.
const result = await this._browserContextFactory.createContext(this.clientVersion!);
const { browserContext } = result;

View File

@ -307,7 +307,9 @@ class ExtensionContextFactory implements BrowserContextFactory {
const browser = await this._browserPromise;
return {
browserContext: browser.contexts()[0],
close: async () => {}
close: async () => {
debugLogger('close() called for browser context, ignoring');
}
};
}

View File

@ -14,18 +14,18 @@
* limitations under the License.
*/
import { createConnection as createConnectionImpl } from './connection.js';
import { createMCPServer } from './connection.js';
import { resolveConfig } from './config.js';
import { contextFactory } from './browserContextFactory.js';
import type { Connection } from '../index.js';
import type { Config } from '../config.js';
import type { BrowserContext } from 'playwright';
import type { BrowserContextFactory } from './browserContextFactory.js';
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Connection> {
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
const config = await resolveConfig(userConfig);
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
return createConnectionImpl(config, factory);
return createMCPServer(config, factory);
}
class SimpleBrowserContextFactory implements BrowserContextFactory {

View File

@ -14,17 +14,16 @@
* limitations under the License.
*/
import { createConnection } from './connection.js';
import { createMCPServer } from './connection.js';
import { Context } from './context.js';
import { contextFactory as defaultContextFactory } from './browserContextFactory.js';
import type { FullConfig } from './config.js';
import type { Connection } from './connection.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
export class Server {
readonly config: FullConfig;
private _connectionList: Connection[] = [];
private _browserConfig: FullConfig['browser'];
private _contextFactory: BrowserContextFactory;
@ -34,11 +33,9 @@ export class Server {
this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig);
}
async createConnection(transport: Transport): Promise<Connection> {
const connection = await createConnection(this.config, this._contextFactory);
this._connectionList.push(connection);
await connection.server.connect(transport);
return connection;
async createConnection(transport: Transport): Promise<void> {
const server = await createMCPServer(this.config, this._contextFactory);
await server.connect(transport);
}
setupExitWatchdog() {
@ -48,7 +45,7 @@ export class Server {
return;
isExiting = true;
setTimeout(() => process.exit(0), 15000);
await Promise.all(this._connectionList.map(connection => connection.close()));
await Context.disposeAll();
process.exit(0);
};

View File

@ -29,7 +29,7 @@ const close = defineTool({
},
handle: async (context, params, response) => {
await context.close();
await context.closeBrowserContext();
response.setIncludeTabs();
response.addCode(`await page.close()`);
},

View File

@ -23,11 +23,8 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { logUnhandledError } from './log.js';
import type { AddressInfo } from 'node:net';
import type { Server } from './server.js';
import type { Connection } from './connection.js';
export async function startStdioTransport(server: Server) {
await server.createConnection(new StdioServerTransport());
@ -54,11 +51,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
testDebug(`create SSE session: ${transport.sessionId}`);
const connection = await server.createConnection(transport);
await server.createConnection(transport);
res.on('close', () => {
testDebug(`delete SSE session: ${transport.sessionId}`);
sessions.delete(transport.sessionId);
void connection.close().catch(logUnhandledError);
});
return;
}
@ -67,10 +63,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
res.end('Method not allowed');
}
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, { transport: StreamableHTTPServerTransport, connection: Connection }>) {
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) {
const { transport } = sessions.get(sessionId) ?? {};
const transport = sessions.get(sessionId);
if (!transport) {
res.statusCode = 404;
res.end('Session not found');
@ -84,18 +80,16 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res:
sessionIdGenerator: () => crypto.randomUUID(),
onsessioninitialized: async sessionId => {
testDebug(`create http session: ${transport.sessionId}`);
const connection = await server.createConnection(transport);
sessions.set(sessionId, { transport, connection });
await server.createConnection(transport);
sessions.set(sessionId, transport);
}
});
transport.onclose = () => {
const result = transport.sessionId ? sessions.get(transport.sessionId) : undefined;
if (!result)
if (!transport.sessionId)
return;
sessions.delete(result.transport.sessionId!);
sessions.delete(transport.sessionId);
testDebug(`delete http session: ${transport.sessionId}`);
result.connection.close().catch(logUnhandledError);
};
await transport.handleRequest(req, res);
@ -120,7 +114,7 @@ export async function startHttpServer(config: { host?: string, port?: number }):
}
export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
const sseSessions = new Map<string, SSEServerTransport>();
const sseSessions = new Map();
const streamableSessions = new Map();
httpServer.on('request', async (req, res) => {
const url = new URL(`http://localhost${req.url}`);