From 53e3e379912d83a681ad4c1699b170b8565fed04 Mon Sep 17 00:00:00 2001 From: Yury Semikhatsky Date: Tue, 22 Jul 2025 22:23:00 -0700 Subject: [PATCH] chore(extension): terminate all connections when tab closes (#741) --- extension/src/background.ts | 26 ++++++++++------ extension/src/relayConnection.ts | 18 ++++------- src/extension/cdpRelay.ts | 53 ++++++++++++++++++++++---------- 3 files changed, 58 insertions(+), 39 deletions(-) diff --git a/extension/src/background.ts b/extension/src/background.ts index a1fe2ac..205d847 100644 --- a/extension/src/background.ts +++ b/extension/src/background.ts @@ -40,7 +40,7 @@ class TabShareExtension { sendResponse({ success: false, error: 'No tab id' }); return true; } - this._connectTab(tabId, message.mcpRelayUrl!).then( + this._connectTab(sender.tab!, message.mcpRelayUrl!).then( () => sendResponse({ success: true }), (error: any) => sendResponse({ success: false, error: error.message })); return true; // Return true to indicate that the response will be sent asynchronously @@ -48,9 +48,9 @@ class TabShareExtension { return false; } - private async _connectTab(tabId: number, mcpRelayUrl: string): Promise { + private async _connectTab(tab: chrome.tabs.Tab, mcpRelayUrl: string): Promise { try { - debugLog(`Connecting tab ${tabId} to bridge at ${mcpRelayUrl}`); + debugLog(`Connecting tab ${tab.id} to bridge at ${mcpRelayUrl}`); const socket = new WebSocket(mcpRelayUrl); await new Promise((resolve, reject) => { socket.onopen = () => resolve(); @@ -58,7 +58,7 @@ class TabShareExtension { setTimeout(() => reject(new Error('Connection timeout')), 5000); }); - const connection = new RelayConnection(socket); + const connection = new RelayConnection(socket, tab.id!); const connectionClosed = (m: string) => { debugLog(m); if (this._activeConnection === connection) { @@ -70,12 +70,15 @@ class TabShareExtension { socket.onerror = error => connectionClosed(`WebSocket error: ${error}`); this._activeConnection = connection; - connection.setConnectedTabId(tabId); - await this._setConnectedTabId(tabId); - debugLog(`Tab ${tabId} connected successfully`); + await Promise.all([ + this._setConnectedTabId(tab.id!), + chrome.tabs.update(tab.id!, { active: true }), + chrome.windows.update(tab.windowId, { focused: true }), + ]); + debugLog(`Connected to MCP bridge`); } catch (error: any) { - debugLog(`Failed to connect tab ${tabId}:`, error.message); await this._setConnectedTabId(null); + debugLog(`Failed to connect tab ${tab.id}:`, error.message); throw error; } } @@ -96,8 +99,11 @@ class TabShareExtension { } private async _onTabRemoved(tabId: number): Promise { - if (this._connectedTabId === tabId) - this._activeConnection!.setConnectedTabId(null); + if (this._connectedTabId !== tabId) + return; + this._activeConnection?.close('Browser tab closed'); + this._activeConnection = undefined; + this._connectedTabId = null; } private async _onTabUpdated(tabId: number, changeInfo: chrome.tabs.TabChangeInfo, tab: chrome.tabs.Tab): Promise { diff --git a/extension/src/relayConnection.ts b/extension/src/relayConnection.ts index 0a9b964..2913571 100644 --- a/extension/src/relayConnection.ts +++ b/extension/src/relayConnection.ts @@ -37,12 +37,13 @@ type ProtocolResponse = { }; export class RelayConnection { - private _debuggee: chrome.debugger.Debuggee = {}; + private _debuggee: chrome.debugger.Debuggee; private _ws: WebSocket; private _eventListener: (source: chrome.debugger.DebuggerSession, method: string, params: any) => void; private _detachListener: (source: chrome.debugger.Debuggee, reason: string) => void; - constructor(ws: WebSocket) { + constructor(ws: WebSocket, tabId: number) { + this._debuggee = { tabId }; this._ws = ws; this._ws.onmessage = this._onMessage.bind(this); // Store listeners for cleanup @@ -52,18 +53,10 @@ export class RelayConnection { chrome.debugger.onDetach.addListener(this._detachListener); } - setConnectedTabId(tabId: number | null): void { - if (!tabId) { - this._debuggee = { }; - return; - } - this._debuggee = { tabId }; - } - - close(message?: string): void { + close(message: string): void { chrome.debugger.onEvent.removeListener(this._eventListener); chrome.debugger.onDetach.removeListener(this._detachListener); - this._ws.close(1000, message || 'Connection closed'); + this._ws.close(1000, message); } private async _detachDebugger(): Promise { @@ -95,6 +88,7 @@ export class RelayConnection { reason, }, }); + this._debuggee = { }; } private _onMessage(event: MessageEvent): void { diff --git a/src/extension/cdpRelay.ts b/src/extension/cdpRelay.ts index 1cc8ec3..c6b9bbb 100644 --- a/src/extension/cdpRelay.ts +++ b/src/extension/cdpRelay.ts @@ -125,8 +125,8 @@ export class CDPRelayServer { } stop(): void { - this._playwrightConnection?.close(); - this._extensionConnection?.close(); + this._closePlaywrightConnection('Server stopped'); + this._closeExtensionConnection('Server stopped'); } private _onConnection(ws: WebSocket, request: http.IncomingMessage): void { @@ -153,11 +153,11 @@ export class CDPRelayServer { } }); ws.on('close', () => { - if (this._playwrightConnection === ws) { - this._playwrightConnection = null; - this._closeExtensionConnection(); - debugLogger('Playwright MCP disconnected'); - } + if (this._playwrightConnection !== ws) + return; + this._playwrightConnection = null; + this._closeExtensionConnection('Playwright client disconnected'); + debugLogger('Playwright WebSocket closed'); }); ws.on('error', error => { debugLogger('Playwright WebSocket error:', error); @@ -165,24 +165,37 @@ export class CDPRelayServer { debugLogger('Playwright MCP connected'); } - private _closeExtensionConnection() { + private _closeExtensionConnection(reason: string) { + this._extensionConnection?.close(reason); + this._resetExtensionConnection(); + } + + private _resetExtensionConnection() { this._connectedTabInfo = undefined; - this._extensionConnection?.close(); this._extensionConnection = null; this._extensionConnectionPromise = new Promise(resolve => { this._extensionConnectionResolve = resolve; }); } + private _closePlaywrightConnection(reason: string) { + if (this._playwrightConnection?.readyState === WebSocket.OPEN) + this._playwrightConnection.close(1000, reason); + this._playwrightConnection = null; + } + private _handleExtensionConnection(ws: WebSocket): void { if (this._extensionConnection) { ws.close(1000, 'Another extension connection already established'); return; } this._extensionConnection = new ExtensionConnection(ws); - this._extensionConnection.onclose = c => { - if (this._extensionConnection === c) - this._extensionConnection = null; + this._extensionConnection.onclose = (c, reason) => { + debugLogger('Extension WebSocket closed:', reason, c === this._extensionConnection); + if (this._extensionConnection !== c) + return; + this._resetExtensionConnection(); + this._closePlaywrightConnection(`Extension disconnected: ${reason}`); }; this._extensionConnection.onmessage = this._handleExtensionMessage.bind(this); this._extensionConnectionResolve?.(); @@ -300,7 +313,12 @@ class ExtensionContextFactory implements BrowserContextFactory { private async _obtainBrowser(clientInfo: { name: string, version: string }): Promise { await this._relay.ensureExtensionConnectionForMCPContext(clientInfo); - return await playwright.chromium.connectOverCDP(this._relay.cdpEndpoint()); + const browser = await playwright.chromium.connectOverCDP(this._relay.cdpEndpoint()); + browser.on('disconnected', () => { + this._browserPromise = undefined; + debugLogger('Browser disconnected'); + }); + return browser; } } @@ -326,7 +344,7 @@ class ExtensionConnection { private _lastId = 0; onmessage?: (method: string, params: any) => void; - onclose?: (self: ExtensionConnection) => void; + onclose?: (self: ExtensionConnection, reason: string) => void; constructor(ws: WebSocket) { this._ws = ws; @@ -346,10 +364,10 @@ class ExtensionConnection { }); } - close(message?: string) { + close(message: string) { debugLogger('closing extension connection:', message); - this._ws.close(1000, message ?? 'Connection closed'); - this.onclose?.(this); + if (this._ws.readyState === WebSocket.OPEN) + this._ws.close(1000, message); } private _onMessage(event: websocket.RawData) { @@ -391,6 +409,7 @@ class ExtensionConnection { private _onClose(event: websocket.CloseEvent) { debugLogger(` code=${event.code} reason=${event.reason}`); this._dispose(); + this.onclose?.(this, event.reason); } private _onError(event: websocket.ErrorEvent) {