mirror of
https://github.com/microsoft/playwright-mcp.git
synced 2025-07-27 00:52:27 +08:00
chore: do not double close connection (#744)
This commit is contained in:
parent
2c5eac89a8
commit
bc120baa78
7
index.d.ts
vendored
7
index.d.ts
vendored
@ -19,10 +19,5 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
|||||||
import type { Config } from './config.js';
|
import type { Config } from './config.js';
|
||||||
import type { BrowserContext } from 'playwright';
|
import type { BrowserContext } from 'playwright';
|
||||||
|
|
||||||
export type Connection = {
|
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Server>;
|
||||||
server: Server;
|
|
||||||
close(): Promise<void>;
|
|
||||||
};
|
|
||||||
|
|
||||||
export declare function createConnection(config?: Config, contextGetter?: () => Promise<BrowserContext>): Promise<Connection>;
|
|
||||||
export {};
|
export {};
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js';
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||||
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
|
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||||
import { Context } from './context.js';
|
import { Context } from './context.js';
|
||||||
@ -23,12 +23,13 @@ import { allTools } from './tools.js';
|
|||||||
import { packageJSON } from './package.js';
|
import { packageJSON } from './package.js';
|
||||||
import { FullConfig } from './config.js';
|
import { FullConfig } from './config.js';
|
||||||
import { SessionLog } from './sessionLog.js';
|
import { SessionLog } from './sessionLog.js';
|
||||||
|
import { logUnhandledError } from './log.js';
|
||||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||||
|
|
||||||
export async function createConnection(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Connection> {
|
export async function createMCPServer(config: FullConfig, browserContextFactory: BrowserContextFactory): Promise<Server> {
|
||||||
const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability));
|
const tools = allTools.filter(tool => tool.capability.startsWith('core') || config.capabilities?.includes(tool.capability));
|
||||||
const context = new Context(tools, config, browserContextFactory);
|
const context = new Context(tools, config, browserContextFactory);
|
||||||
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, {
|
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
|
||||||
capabilities: {
|
capabilities: {
|
||||||
tools: {},
|
tools: {},
|
||||||
}
|
}
|
||||||
@ -72,23 +73,13 @@ export async function createConnection(config: FullConfig, browserContextFactory
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return new Connection(server, context);
|
server.oninitialized = () => {
|
||||||
}
|
context.clientVersion = server.getClientVersion();
|
||||||
|
};
|
||||||
export class Connection {
|
|
||||||
readonly server: McpServer;
|
server.onclose = () => {
|
||||||
readonly context: Context;
|
void context.dispose().catch(logUnhandledError);
|
||||||
|
};
|
||||||
constructor(server: McpServer, context: Context) {
|
|
||||||
this.server = server;
|
return server;
|
||||||
this.context = context;
|
|
||||||
this.server.oninitialized = () => {
|
|
||||||
this.context.clientVersion = this.server.getClientVersion();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
async close() {
|
|
||||||
await this.server.close();
|
|
||||||
await this.context.close();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -34,11 +34,19 @@ export class Context {
|
|||||||
private _currentTab: Tab | undefined;
|
private _currentTab: Tab | undefined;
|
||||||
clientVersion: { name: string; version: string; } | undefined;
|
clientVersion: { name: string; version: string; } | undefined;
|
||||||
|
|
||||||
|
private static _allContexts: Set<Context> = new Set();
|
||||||
|
private _closeBrowserContextPromise: Promise<void> | undefined;
|
||||||
|
|
||||||
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) {
|
constructor(tools: Tool[], config: FullConfig, browserContextFactory: BrowserContextFactory) {
|
||||||
this.tools = tools;
|
this.tools = tools;
|
||||||
this.config = config;
|
this.config = config;
|
||||||
this._browserContextFactory = browserContextFactory;
|
this._browserContextFactory = browserContextFactory;
|
||||||
testDebug('create context');
|
testDebug('create context');
|
||||||
|
Context._allContexts.add(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
static async disposeAll() {
|
||||||
|
await Promise.all([...Context._allContexts].map(context => context.dispose()));
|
||||||
}
|
}
|
||||||
|
|
||||||
tabs(): Tab[] {
|
tabs(): Tab[] {
|
||||||
@ -127,10 +135,17 @@ export class Context {
|
|||||||
if (this._currentTab === tab)
|
if (this._currentTab === tab)
|
||||||
this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)];
|
this._currentTab = this._tabs[Math.min(index, this._tabs.length - 1)];
|
||||||
if (!this._tabs.length)
|
if (!this._tabs.length)
|
||||||
void this.close();
|
void this.closeBrowserContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
async close() {
|
async closeBrowserContext() {
|
||||||
|
if (!this._closeBrowserContextPromise)
|
||||||
|
this._closeBrowserContextPromise = this._closeBrowserContextImpl();
|
||||||
|
await this._closeBrowserContextPromise;
|
||||||
|
this._closeBrowserContextPromise = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async _closeBrowserContextImpl() {
|
||||||
if (!this._browserContextPromise)
|
if (!this._browserContextPromise)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
@ -146,6 +161,11 @@ export class Context {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async dispose() {
|
||||||
|
await this.closeBrowserContext();
|
||||||
|
Context._allContexts.delete(this);
|
||||||
|
}
|
||||||
|
|
||||||
private async _setupRequestInterception(context: playwright.BrowserContext) {
|
private async _setupRequestInterception(context: playwright.BrowserContext) {
|
||||||
if (this.config.network?.allowedOrigins?.length) {
|
if (this.config.network?.allowedOrigins?.length) {
|
||||||
await context.route('**', route => route.abort('blockedbyclient'));
|
await context.route('**', route => route.abort('blockedbyclient'));
|
||||||
@ -171,6 +191,8 @@ export class Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
|
private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise<void> }> {
|
||||||
|
if (this._closeBrowserContextPromise)
|
||||||
|
throw new Error('Another browser context is being closed.');
|
||||||
// TODO: move to the browser context factory to make it based on isolation mode.
|
// TODO: move to the browser context factory to make it based on isolation mode.
|
||||||
const result = await this._browserContextFactory.createContext(this.clientVersion!);
|
const result = await this._browserContextFactory.createContext(this.clientVersion!);
|
||||||
const { browserContext } = result;
|
const { browserContext } = result;
|
||||||
|
@ -307,7 +307,9 @@ class ExtensionContextFactory implements BrowserContextFactory {
|
|||||||
const browser = await this._browserPromise;
|
const browser = await this._browserPromise;
|
||||||
return {
|
return {
|
||||||
browserContext: browser.contexts()[0],
|
browserContext: browser.contexts()[0],
|
||||||
close: async () => {}
|
close: async () => {
|
||||||
|
debugLogger('close() called for browser context, ignoring');
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,18 +14,18 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { createConnection as createConnectionImpl } from './connection.js';
|
import { createMCPServer } from './connection.js';
|
||||||
import { resolveConfig } from './config.js';
|
import { resolveConfig } from './config.js';
|
||||||
import { contextFactory } from './browserContextFactory.js';
|
import { contextFactory } from './browserContextFactory.js';
|
||||||
import type { Connection } from '../index.js';
|
|
||||||
import type { Config } from '../config.js';
|
import type { Config } from '../config.js';
|
||||||
import type { BrowserContext } from 'playwright';
|
import type { BrowserContext } from 'playwright';
|
||||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||||
|
import type { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||||
|
|
||||||
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Connection> {
|
export async function createConnection(userConfig: Config = {}, contextGetter?: () => Promise<BrowserContext>): Promise<Server> {
|
||||||
const config = await resolveConfig(userConfig);
|
const config = await resolveConfig(userConfig);
|
||||||
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
|
const factory = contextGetter ? new SimpleBrowserContextFactory(contextGetter) : contextFactory(config.browser);
|
||||||
return createConnectionImpl(config, factory);
|
return createMCPServer(config, factory);
|
||||||
}
|
}
|
||||||
|
|
||||||
class SimpleBrowserContextFactory implements BrowserContextFactory {
|
class SimpleBrowserContextFactory implements BrowserContextFactory {
|
||||||
|
@ -14,17 +14,16 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { createConnection } from './connection.js';
|
import { createMCPServer } from './connection.js';
|
||||||
|
import { Context } from './context.js';
|
||||||
import { contextFactory as defaultContextFactory } from './browserContextFactory.js';
|
import { contextFactory as defaultContextFactory } from './browserContextFactory.js';
|
||||||
|
|
||||||
import type { FullConfig } from './config.js';
|
import type { FullConfig } from './config.js';
|
||||||
import type { Connection } from './connection.js';
|
|
||||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||||
import type { BrowserContextFactory } from './browserContextFactory.js';
|
import type { BrowserContextFactory } from './browserContextFactory.js';
|
||||||
|
|
||||||
export class Server {
|
export class Server {
|
||||||
readonly config: FullConfig;
|
readonly config: FullConfig;
|
||||||
private _connectionList: Connection[] = [];
|
|
||||||
private _browserConfig: FullConfig['browser'];
|
private _browserConfig: FullConfig['browser'];
|
||||||
private _contextFactory: BrowserContextFactory;
|
private _contextFactory: BrowserContextFactory;
|
||||||
|
|
||||||
@ -34,11 +33,9 @@ export class Server {
|
|||||||
this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig);
|
this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
async createConnection(transport: Transport): Promise<Connection> {
|
async createConnection(transport: Transport): Promise<void> {
|
||||||
const connection = await createConnection(this.config, this._contextFactory);
|
const server = await createMCPServer(this.config, this._contextFactory);
|
||||||
this._connectionList.push(connection);
|
await server.connect(transport);
|
||||||
await connection.server.connect(transport);
|
|
||||||
return connection;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setupExitWatchdog() {
|
setupExitWatchdog() {
|
||||||
@ -48,7 +45,7 @@ export class Server {
|
|||||||
return;
|
return;
|
||||||
isExiting = true;
|
isExiting = true;
|
||||||
setTimeout(() => process.exit(0), 15000);
|
setTimeout(() => process.exit(0), 15000);
|
||||||
await Promise.all(this._connectionList.map(connection => connection.close()));
|
await Context.disposeAll();
|
||||||
process.exit(0);
|
process.exit(0);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ const close = defineTool({
|
|||||||
},
|
},
|
||||||
|
|
||||||
handle: async (context, params, response) => {
|
handle: async (context, params, response) => {
|
||||||
await context.close();
|
await context.closeBrowserContext();
|
||||||
response.setIncludeTabs();
|
response.setIncludeTabs();
|
||||||
response.addCode(`await page.close()`);
|
response.addCode(`await page.close()`);
|
||||||
},
|
},
|
||||||
|
@ -23,11 +23,8 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
|
|||||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
||||||
|
|
||||||
import { logUnhandledError } from './log.js';
|
|
||||||
|
|
||||||
import type { AddressInfo } from 'node:net';
|
import type { AddressInfo } from 'node:net';
|
||||||
import type { Server } from './server.js';
|
import type { Server } from './server.js';
|
||||||
import type { Connection } from './connection.js';
|
|
||||||
|
|
||||||
export async function startStdioTransport(server: Server) {
|
export async function startStdioTransport(server: Server) {
|
||||||
await server.createConnection(new StdioServerTransport());
|
await server.createConnection(new StdioServerTransport());
|
||||||
@ -54,11 +51,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
|
|||||||
const transport = new SSEServerTransport('/sse', res);
|
const transport = new SSEServerTransport('/sse', res);
|
||||||
sessions.set(transport.sessionId, transport);
|
sessions.set(transport.sessionId, transport);
|
||||||
testDebug(`create SSE session: ${transport.sessionId}`);
|
testDebug(`create SSE session: ${transport.sessionId}`);
|
||||||
const connection = await server.createConnection(transport);
|
await server.createConnection(transport);
|
||||||
res.on('close', () => {
|
res.on('close', () => {
|
||||||
testDebug(`delete SSE session: ${transport.sessionId}`);
|
testDebug(`delete SSE session: ${transport.sessionId}`);
|
||||||
sessions.delete(transport.sessionId);
|
sessions.delete(transport.sessionId);
|
||||||
void connection.close().catch(logUnhandledError);
|
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -67,10 +63,10 @@ async function handleSSE(server: Server, req: http.IncomingMessage, res: http.Se
|
|||||||
res.end('Method not allowed');
|
res.end('Method not allowed');
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, { transport: StreamableHTTPServerTransport, connection: Connection }>) {
|
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
|
||||||
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
||||||
if (sessionId) {
|
if (sessionId) {
|
||||||
const { transport } = sessions.get(sessionId) ?? {};
|
const transport = sessions.get(sessionId);
|
||||||
if (!transport) {
|
if (!transport) {
|
||||||
res.statusCode = 404;
|
res.statusCode = 404;
|
||||||
res.end('Session not found');
|
res.end('Session not found');
|
||||||
@ -84,18 +80,16 @@ async function handleStreamable(server: Server, req: http.IncomingMessage, res:
|
|||||||
sessionIdGenerator: () => crypto.randomUUID(),
|
sessionIdGenerator: () => crypto.randomUUID(),
|
||||||
onsessioninitialized: async sessionId => {
|
onsessioninitialized: async sessionId => {
|
||||||
testDebug(`create http session: ${transport.sessionId}`);
|
testDebug(`create http session: ${transport.sessionId}`);
|
||||||
const connection = await server.createConnection(transport);
|
await server.createConnection(transport);
|
||||||
sessions.set(sessionId, { transport, connection });
|
sessions.set(sessionId, transport);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
transport.onclose = () => {
|
transport.onclose = () => {
|
||||||
const result = transport.sessionId ? sessions.get(transport.sessionId) : undefined;
|
if (!transport.sessionId)
|
||||||
if (!result)
|
|
||||||
return;
|
return;
|
||||||
sessions.delete(result.transport.sessionId!);
|
sessions.delete(transport.sessionId);
|
||||||
testDebug(`delete http session: ${transport.sessionId}`);
|
testDebug(`delete http session: ${transport.sessionId}`);
|
||||||
result.connection.close().catch(logUnhandledError);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
await transport.handleRequest(req, res);
|
await transport.handleRequest(req, res);
|
||||||
@ -120,7 +114,7 @@ export async function startHttpServer(config: { host?: string, port?: number }):
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
|
export function startHttpTransport(httpServer: http.Server, mcpServer: Server) {
|
||||||
const sseSessions = new Map<string, SSEServerTransport>();
|
const sseSessions = new Map();
|
||||||
const streamableSessions = new Map();
|
const streamableSessions = new Map();
|
||||||
httpServer.on('request', async (req, res) => {
|
httpServer.on('request', async (req, res) => {
|
||||||
const url = new URL(`http://localhost${req.url}`);
|
const url = new URL(`http://localhost${req.url}`);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user