mirror of
https://github.com/microsoft/playwright-mcp.git
synced 2025-07-26 08:32:26 +08:00
chore: do not double close connection (#744)
This commit is contained in:
parent
2c5eac89a8
commit
bc120baa78
7
index.d.ts
vendored
7
index.d.ts
vendored
@ -19,10 +19,5 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import type { Config } from './config.js';
|
||||
import type { BrowserContext } from 'playwright';
|
||||
|
||||
export type Connection = {
|
||||
server: Server;
|
||||
close(): Promise<void>;
|
||||
};
|
||||
|
||||
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Connection>;
|
||||
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Server>;
|
||||
export {};
|
||||
|
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { Context } from './context.js';
|
||||
@ -23,12 +23,13 @@ import { allTools } from './tools.js';
|
||||
import { packageJSON } from './package.js';
|
||||
import { FullConfig } from './config.js';
|
||||
import { SessionLog } from './sessionLog.js';
|
||||
import { logUnhandledError } from './log.js';
|
||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||
|
||||
export async function createConnection(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Connection> {
|
||||
export async function createMCPServer(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Server> {
|
||||
const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability));
|
||||
const context = new Context(tools, config, browserContextFactory);
|
||||
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, {
|
||||
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
|
||||
capabilities: {
|
||||
tools: {},
|
||||
}
|
||||
@ -72,23 +73,13 @@ export async function createConnection(config: FullConfig, browserContextFactory
|
||||
}
|
||||
});
|
||||
|
||||
return new Connection(server, context);
|
||||
}
|
||||
|
||||
export class Connection {
|
||||
readonly server: McpServer;
|
||||
readonly context: Context;
|
||||
|
||||
constructor(server: McpServer, context: Context) {
|
||||
this.server = server;
|
||||
this.context = context;
|
||||
this.server.oninitialized = () => {
|
||||
this.context.clientVersion = this.server.getClientVersion();
|
||||
};
|
||||
}
|
||||
|
||||
async close() {
|
||||
await this.server.close();
|
||||
await this.context.close();
|
||||
}
|
||||
server.oninitialized = () => {
|
||||
context.clientVersion = server.getClientVersion();
|
||||
};
|
||||
|
||||
server.onclose = () => {
|
||||
void context.dispose().catch(logUnhandledError);
|
||||
};
|
||||
|
||||
return server;
|
||||
}
|
||||
|
@ -34,11 +34,19 @@ export class Context {
|
||||
private _currentTab: Tab | undefined;
|
||||
clientVersion: { name: string; version: string; } | undefined;
|
||||
|
||||
private static _allContexts: Set<Context> = new Set();
|
||||
private _closeBrowserContextPromise: Promise<void> | undefined;
|
||||
|
||||
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) {
|
||||
this.tools = tools;
|
||||
this.config = config;
|
||||
this._browserContextFactory = browserContextFactory;
|
||||
testDebug('create context');
|
||||
Context._allContexts.add(this);
|
||||
}
|
||||
|
||||
static async disposeAll() {
|
||||
await Promise.all([...Context._allContexts].map(context => context.dispose()));
|
||||
}
|
||||
|
||||
tabs(): Tab[] {
|
||||
@ -127,10 +135,17 @@ export class Context {
|
||||
if (this._currentTab === tab)
|
||||
this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)];
|
||||
if (!this._tabs.length)
|
||||
void this.close();
|
||||
void this.closeBrowserContext();
|
||||
}
|
||||
|
||||
async close() {
|
||||
async closeBrowserContext() {
|
||||
if (!this._closeBrowserContextPromise)
|
||||
this._closeBrowserContextPromise = this._closeBrowserContextImpl();
|
||||
await this._closeBrowserContextPromise;
|
||||
this._closeBrowserContextPromise = undefined;
|
||||
}
|
||||
|
||||
private async _closeBrowserContextImpl() {
|
||||
if (!this._browserContextPromise)
|
||||
return;
|
||||
|
||||
@ -146,6 +161,11 @@ export class Context {
|
||||
});
|
||||
}
|
||||
|
||||
async dispose() {
|
||||
await this.closeBrowserContext();
|
||||
Context._allContexts.delete(this);
|
||||
}
|
||||
|
||||
private async _setupRequestInterception(context: playwright.BrowserContext) {
|
||||
if (this.config.network?.allowedOrigins?.length) {
|
||||
await context.route('**', route => route.abort('blockedbyclient'));
|
||||
@ -171,6 +191,8 @@ export class Context {
|
||||
}
|
||||
|
||||
private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
|
||||
if (this._closeBrowserContextPromise)
|
||||
throw new Error('Another browser context is being closed.');
|
||||
// TODO: move to the browser context factory to make it based on isolation mode.
|
||||
const result = await this._browserContextFactory.createContext(this.clientVersion!);
|
||||
const { browserContext } = result;
|
||||
|
@ -307,7 +307,9 @@ class ExtensionContextFactory implements BrowserContextFactory {
|
||||
const browser = await this._browserPromise;
|
||||
return {
|
||||
browserContext: browser.contexts()[0],
|
||||
close: async () => {}
|
||||
close: async () => {
|
||||
debugLogger('close() called for browser context, ignoring');
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -14,18 +14,18 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import { createConnection as createConnectionImpl } from './connection.js';
|
||||
import { createMCPServer } from './connection.js';
|
||||
import { resolveConfig } from './config.js';
|
||||
import { contextFactory } from './browserContextFactory.js';
|
||||
import type { Connection } from '../index.js';
|
||||
import type { Config } from '../config.js';
|
||||
import type { BrowserContext } from 'playwright';
|
||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
|
||||
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Connection> {
|
||||
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
|
||||
const config = await resolveConfig(userConfig);
|
||||
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
|
||||
return createConnectionImpl(config, factory);
|
||||
return createMCPServer(config, factory);
|
||||
}
|
||||
|
||||
class SimpleBrowserContextFactory implements BrowserContextFactory {
|
||||
|
@ -14,17 +14,16 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import { createConnection } from './connection.js';
|
||||
import { createMCPServer } from './connection.js';
|
||||
import { Context } from './context.js';
|
||||
import { contextFactory as defaultContextFactory } from './browserContextFactory.js';
|
||||
|
||||
import type { FullConfig } from './config.js';
|
||||
import type { Connection } from './connection.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||
|
||||
export class Server {
|
||||
readonly config: FullConfig;
|
||||
private _connectionList: Connection[] = [];
|
||||
private _browserConfig: FullConfig['browser'];
|
||||
private _contextFactory: BrowserContextFactory;
|
||||
|
||||
@ -34,11 +33,9 @@ export class Server {
|
||||
this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig);
|
||||
}
|
||||
|
||||
async createConnection(transport: Transport): Promise<Connection> {
|
||||
const connection = await createConnection(this.config, this._contextFactory);
|
||||
this._connectionList.push(connection);
|
||||
await connection.server.connect(transport);
|
||||
return connection;
|
||||
async createConnection(transport: Transport): Promise<void> {
|
||||
const server = await createMCPServer(this.config, this._contextFactory);
|
||||
await server.connect(transport);
|
||||
}
|
||||
|
||||
setupExitWatchdog() {
|
||||
@ -48,7 +45,7 @@ export class Server {
|
||||
return;
|
||||
isExiting = true;
|
||||
setTimeout(() => process.exit(0), 15000);
|
||||
await Promise.all(this._connectionList.map(connection => connection.close()));
|
||||
await Context.disposeAll();
|
||||
process.exit(0);
|
||||
};
|
||||
|
||||
|
@ -29,7 +29,7 @@ const close = defineTool({
|
||||
},
|
||||
|
||||
handle: async (context, params, response) => {
|
||||
await context.close();
|
||||
await context.closeBrowserContext();
|
||||
response.setIncludeTabs();
|
||||
response.addCode(`await page.close()`);
|
||||
},
|
||||
|
@ -23,11 +23,8 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
||||
|
||||
import { logUnhandledError } from './log.js';
|
||||
|
||||
import type { AddressInfo } from 'node:net';
|
||||
import type { Server } from './server.js';
|
||||
import type { Connection } from './connection.js';
|
||||
|
||||
export async function startStdioTransport(server: Server) {
|
||||
await server.createConnection(new StdioServerTransport());
|
||||
@ -54,11 +51,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
|
||||
const transport = new SSEServerTransport('/sse', res);
|
||||
sessions.set(transport.sessionId, transport);
|
||||
testDebug(`create SSE session: ${transport.sessionId}`);
|
||||
const connection = await server.createConnection(transport);
|
||||
await server.createConnection(transport);
|
||||
res.on('close', () => {
|
||||
testDebug(`delete SSE session: ${transport.sessionId}`);
|
||||
sessions.delete(transport.sessionId);
|
||||
void connection.close().catch(logUnhandledError);
|
||||
});
|
||||
return;
|
||||
}
|
||||
@ -67,10 +63,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
|
||||
res.end('Method not allowed');
|
||||
}
|
||||
|
||||
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, { transport: StreamableHTTPServerTransport, connection: Connection }>) {
|
||||
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
|
||||
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
||||
if (sessionId) {
|
||||
const { transport } = sessions.get(sessionId) ?? {};
|
||||
const transport = sessions.get(sessionId);
|
||||
if (!transport) {
|
||||
res.statusCode = 404;
|
||||
res.end('Session not found');
|
||||
@ -84,18 +80,16 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res:
|
||||
sessionIdGenerator: () => crypto.randomUUID(),
|
||||
onsessioninitialized: async sessionId => {
|
||||
testDebug(`create http session: ${transport.sessionId}`);
|
||||
const connection = await server.createConnection(transport);
|
||||
sessions.set(sessionId, { transport, connection });
|
||||
await server.createConnection(transport);
|
||||
sessions.set(sessionId, transport);
|
||||
}
|
||||
});
|
||||
|
||||
transport.onclose = () => {
|
||||
const result = transport.sessionId ? sessions.get(transport.sessionId) : undefined;
|
||||
if (!result)
|
||||
if (!transport.sessionId)
|
||||
return;
|
||||
sessions.delete(result.transport.sessionId!);
|
||||
sessions.delete(transport.sessionId);
|
||||
testDebug(`delete http session: ${transport.sessionId}`);
|
||||
result.connection.close().catch(logUnhandledError);
|
||||
};
|
||||
|
||||
await transport.handleRequest(req, res);
|
||||
@ -120,7 +114,7 @@ export async function startHttpServer(config: { host?: string, port?: number }):
|
||||
}
|
||||
|
||||
export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
|
||||
const sseSessions = new Map<string, SSEServerTransport>();
|
||||
const sseSessions = new Map();
|
||||
const streamableSessions = new Map();
|
||||
httpServer.on('request', async (req, res) => {
|
||||
const url = new URL(`http://localhost${req.url}`);
|
||||
|
Loading…
x
Reference in New Issue
Block a user