diff --git a/src/browserContextFactory.ts b/src/browserContextFactory.ts index e91f2be..c548bba 100644 --- a/src/browserContextFactory.ts +++ b/src/browserContextFactory.ts @@ -36,7 +36,7 @@ export function contextFactory(browserConfig: FullConfig['browser']): BrowserCon } export interface BrowserContextFactory { - createContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise }>; + createContext(clientInfo: { name: string, version: string }): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise }>; } class BaseContextFactory implements BrowserContextFactory { diff --git a/src/context.ts b/src/context.ts index 18fcc71..3a19872 100644 --- a/src/context.ts +++ b/src/context.ts @@ -336,7 +336,7 @@ ${code.join('\n')} private async _setupBrowserContext(): Promise<{ browserContext: playwright.BrowserContext, close: () => Promise }> { // TODO: move to the browser context factory to make it based on isolation mode. - const result = await this._browserContextFactory.createContext(); + const result = await this._browserContextFactory.createContext(this.clientVersion!); const { browserContext } = result; await this._setupRequestInterception(browserContext); for (const page of browserContext.pages()) diff --git a/src/extension/cdpRelay.ts b/src/extension/cdpRelay.ts index e775909..5fabd25 100644 --- a/src/extension/cdpRelay.ts +++ b/src/extension/cdpRelay.ts @@ -29,6 +29,8 @@ import debug from 'debug'; import { promisify } from 'node:util'; import { exec } from 'node:child_process'; import { httpAddressToString, startHttpServer } from '../transport.js'; +import { BrowserContextFactory } from '../browserContextFactory.js'; +import { Browser, chromium, type BrowserContext } from 'playwright'; const debugLogger = debug('pw:mcp:relay'); @@ -50,7 +52,6 @@ type CDPResponse = { export class CDPRelayServer { private _wsHost: string; - private _getClientInfo: () => { name: string, version: string }; private _cdpPath: string; private _extensionPath: string; private _wss: WebSocketServer; @@ -64,8 +65,7 @@ export class CDPRelayServer { private _extensionConnectionPromise: Promise; private _extensionConnectionResolve: (() => void) | null = null; - constructor(server: http.Server, getClientInfo: () => { name: string, version: string }) { - this._getClientInfo = getClientInfo; + constructor(server: http.Server) { this._wsHost = httpAddressToString(server.address()).replace(/^http/, 'ws'); const uuid = crypto.randomUUID(); @@ -75,7 +75,7 @@ export class CDPRelayServer { this._extensionConnectionPromise = new Promise(resolve => { this._extensionConnectionResolve = resolve; }); - this._wss = new WebSocketServer({ server, verifyClient: this._verifyClient.bind(this) }); + this._wss = new WebSocketServer({ server }); this._wss.on('connection', this._onConnection.bind(this)); } @@ -87,26 +87,19 @@ export class CDPRelayServer { return `${this._wsHost}${this._extensionPath}`; } - private async _verifyClient(info: { origin: string, req: http.IncomingMessage }, callback: (result: boolean, code?: number, message?: string) => void) { - if (info.req.url?.startsWith(this._cdpPath)) { - if (this._playwrightConnection) { - callback(false, 500, 'Another Playwright connection already established'); - return; - } - await this._connectBrowser(); - await this._extensionConnectionPromise; - callback(!!this._extensionConnection); + async ensureExtensionConnectionForMCPContext(clientInfo: { name: string, version: string }) { + if (this._extensionConnection) return; - } - callback(true); + await this._connectBrowser(clientInfo); + await this._extensionConnectionPromise; } - private async _connectBrowser() { + private async _connectBrowser(clientInfo: { name: string, version: string }) { const mcpRelayEndpoint = `${this._wsHost}${this._extensionPath}`; // Need to specify "key" in the manifest.json to make the id stable when loading from file. const url = new URL('chrome-extension://jakfalbnbhgkpmoaakfflhflbfpkailf/connect.html'); url.searchParams.set('mcpRelayUrl', mcpRelayEndpoint); - url.searchParams.set('client', JSON.stringify(this._getClientInfo())); + url.searchParams.set('client', JSON.stringify(clientInfo)); const href = url.toString(); const command = `'/Applications/Google Chrome.app/Contents/MacOS/Google Chrome' '${href}'`; try { @@ -289,18 +282,37 @@ export class CDPRelayServer { } } -export async function startCDPRelayServer({ - getClientInfo, - port, -}: { - getClientInfo: () => { name: string, version: string }; - port: number; -}) { +class ExtensionContextFactory implements BrowserContextFactory { + private _relay: CDPRelayServer; + private _browserPromise: Promise | undefined; + + constructor(relay: CDPRelayServer) { + this._relay = relay; + } + + async createContext(clientInfo: { name: string, version: string }): Promise<{ browserContext: BrowserContext, close: () => Promise }> { + // First call will establish the connection to the extension. + if (!this._browserPromise) + this._browserPromise = this._obtainBrowser(clientInfo); + const browser = await this._browserPromise; + return { + browserContext: browser.contexts()[0], + close: async () => {} + }; + } + + private async _obtainBrowser(clientInfo: { name: string, version: string }): Promise { + await this._relay.ensureExtensionConnectionForMCPContext(clientInfo); + return await chromium.connectOverCDP(this._relay.cdpEndpoint()); + } +} + +export async function startCDPRelayServer(port: number) { const httpServer = await startHttpServer({ port }); - const cdpRelayServer = new CDPRelayServer(httpServer, getClientInfo); + const cdpRelayServer = new CDPRelayServer(httpServer); process.on('exit', () => cdpRelayServer.stop()); debugLogger(`CDP relay server started, extension endpoint: ${cdpRelayServer.extensionEndpoint()}.`); - return cdpRelayServer.cdpEndpoint(); + return new ExtensionContextFactory(cdpRelayServer); } class ExtensionConnection { diff --git a/src/extension/main.ts b/src/extension/main.ts index f6c8651..93cdc62 100644 --- a/src/extension/main.ts +++ b/src/extension/main.ts @@ -15,24 +15,21 @@ */ import { resolveCLIConfig } from '../config.js'; -import { Connection } from '../connection.js'; -import { startStdioTransport } from '../transport.js'; +import { startHttpServer, startHttpTransport, startStdioTransport } from '../transport.js'; import { Server } from '../server.js'; import { startCDPRelayServer } from './cdpRelay.js'; export async function runWithExtension(options: any) { const config = await resolveCLIConfig({ }); + const contextFactory = await startCDPRelayServer(9225); - let connection: Connection | null = null; - const cdpEndpoint = await startCDPRelayServer({ - getClientInfo: () => connection!.server.getClientVersion()!, - port: 9225, - }); - // Point CDP endpoint to the relay server. - config.browser.cdpEndpoint = cdpEndpoint; - - const server = new Server(config); + const server = new Server(config, contextFactory); server.setupExitWatchdog(); - connection = await startStdioTransport(server); + if (options.port !== undefined) { + const httpServer = await startHttpServer({ port: options.port }); + startHttpTransport(httpServer, server); + } else { + await startStdioTransport(server); + } } diff --git a/src/server.ts b/src/server.ts index 8c143e1..b54367e 100644 --- a/src/server.ts +++ b/src/server.ts @@ -15,7 +15,7 @@ */ import { createConnection } from './connection.js'; -import { contextFactory } from './browserContextFactory.js'; +import { contextFactory as defaultContextFactory } from './browserContextFactory.js'; import type { FullConfig } from './config.js'; import type { Connection } from './connection.js'; @@ -28,10 +28,10 @@ export class Server { private _browserConfig: FullConfig['browser']; private _contextFactory: BrowserContextFactory; - constructor(config: FullConfig) { + constructor(config: FullConfig, contextFactory?: BrowserContextFactory) { this.config = config; this._browserConfig = config.browser; - this._contextFactory = contextFactory(this._browserConfig); + this._contextFactory = contextFactory ?? defaultContextFactory(this._browserConfig); } async createConnection(transport: Transport): Promise { diff --git a/src/transport.ts b/src/transport.ts index 14858e9..48bcd9a 100644 --- a/src/transport.ts +++ b/src/transport.ts @@ -30,7 +30,7 @@ import type { Server } from './server.js'; import type { Connection } from './connection.js'; export async function startStdioTransport(server: Server) { - return await server.createConnection(new StdioServerTransport()); + await server.createConnection(new StdioServerTransport()); } const testDebug = debug('pw:mcp:test');