diff --git a/package.json b/package.json index ed29800..a69c189 100644 --- a/package.json +++ b/package.json @@ -20,6 +20,8 @@ "watch": "tsc --watch", "test": "playwright test", "ctest": "playwright test --project=chrome", + "ftest": "playwright test --project=firefox", + "wtest": "playwright test --project=webkit", "clean": "rm -rf lib", "npm-publish": "npm run clean && npm run build && npm run test && npm publish" }, diff --git a/src/context.ts b/src/context.ts index 1feb1d0..26a04d4 100644 --- a/src/context.ts +++ b/src/context.ts @@ -18,9 +18,10 @@ import * as playwright from 'playwright'; import yaml from 'yaml'; import { waitForCompletion } from './tools/utils'; +import { ManualPromise } from './manualPromise'; import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types'; -import type { ModalState, Tool } from './tools/tool'; +import type { ModalState, Tool, ToolActionResult } from './tools/tool'; export type ContextOptions = { browserName?: 'chromium' | 'firefox' | 'webkit'; @@ -32,6 +33,10 @@ export type ContextOptions = { type PageOrFrameLocator = playwright.Page | playwright.FrameLocator; +type PendingAction = { + dialogShown: ManualPromise; +}; + export class Context { readonly tools: Tool[]; readonly options: ContextOptions; @@ -40,6 +45,7 @@ export class Context { private _tabs: Tab[] = []; private _currentTab: Tab | undefined; private _modalStates: (ModalState & { tab: Tab })[] = []; + private _pendingAction: PendingAction | undefined; constructor(tools: Tool[], options: ContextOptions) { this.tools = tools; @@ -120,6 +126,7 @@ export class Context { // Tab management is done outside of the action() call. const toolResult = await tool.handle(this, params); const { code, action, waitForNetwork, captureSnapshot, resultOverride } = toolResult; + const racingAction = action ? () => this._raceAgainstModalDialogs(action) : undefined; if (resultOverride) return resultOverride; @@ -138,11 +145,11 @@ export class Context { let actionResult: { content?: (ImageContent | TextContent)[] } | undefined; try { if (waitForNetwork) - actionResult = await waitForCompletion(tab.page, async () => action?.()) ?? undefined; + actionResult = await waitForCompletion(this, tab.page, async () => racingAction?.()) ?? undefined; else - actionResult = await action?.() ?? undefined; + actionResult = await racingAction?.() ?? undefined; } finally { - if (captureSnapshot) + if (captureSnapshot && !this._javaScriptBlocked()) await tab.captureSnapshot(); } @@ -190,6 +197,43 @@ ${code.join('\n')} }; } + async waitForTimeout(time: number) { + if (this._currentTab && !this._javaScriptBlocked()) + await this._currentTab.page.waitForTimeout(time); + else + await new Promise(f => setTimeout(f, time)); + } + + private async _raceAgainstModalDialogs(action: () => Promise): Promise { + this._pendingAction = { + dialogShown: new ManualPromise(), + }; + + let result: ToolActionResult | undefined; + try { + await Promise.race([ + action().then(r => result = r), + this._pendingAction.dialogShown, + ]); + } finally { + this._pendingAction = undefined; + } + return result; + } + + private _javaScriptBlocked(): boolean { + return this._modalStates.some(state => state.type === 'dialog'); + } + + dialogShown(tab: Tab, dialog: playwright.Dialog) { + this.setModalState({ + type: 'dialog', + description: `"${dialog.type()}" dialog with message "${dialog.message()}"`, + dialog, + }, tab); + this._pendingAction?.dialogShown.resolve(); + } + private _onPageCreated(page: playwright.Page) { const tab = new Tab(this, page, tab => this._onPageClosed(tab)); this._tabs.push(tab); @@ -293,6 +337,7 @@ export class Tab { fileChooser: chooser, }, this); }); + page.on('dialog', dialog => this.context.dialogShown(this, dialog)); page.setDefaultNavigationTimeout(60000); page.setDefaultTimeout(5000); } diff --git a/src/index.ts b/src/index.ts index 33cd7f4..f845544 100644 --- a/src/index.ts +++ b/src/index.ts @@ -21,6 +21,7 @@ import fs from 'fs'; import { createServerWithTools } from './server'; import common from './tools/common'; import console from './tools/console'; +import dialogs from './tools/dialogs'; import files from './tools/files'; import install from './tools/install'; import keyboard from './tools/keyboard'; @@ -37,6 +38,7 @@ import type { LaunchOptions } from 'playwright'; const snapshotTools: Tool[] = [ ...common(true), ...console, + ...dialogs(true), ...files(true), ...install, ...keyboard(true), @@ -49,6 +51,7 @@ const snapshotTools: Tool[] = [ const screenshotTools: Tool[] = [ ...common(false), ...console, + ...dialogs(false), ...files(false), ...install, ...keyboard(false), diff --git a/src/manualPromise.ts b/src/manualPromise.ts new file mode 100644 index 0000000..a5034e0 --- /dev/null +++ b/src/manualPromise.ts @@ -0,0 +1,127 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export class ManualPromise extends Promise { + private _resolve!: (t: T) => void; + private _reject!: (e: Error) => void; + private _isDone: boolean; + + constructor() { + let resolve: (t: T) => void; + let reject: (e: Error) => void; + super((f, r) => { + resolve = f; + reject = r; + }); + this._isDone = false; + this._resolve = resolve!; + this._reject = reject!; + } + + isDone() { + return this._isDone; + } + + resolve(t: T) { + this._isDone = true; + this._resolve(t); + } + + reject(e: Error) { + this._isDone = true; + this._reject(e); + } + + static override get [Symbol.species]() { + return Promise; + } + + override get [Symbol.toStringTag]() { + return 'ManualPromise'; + } +} + +export class LongStandingScope { + private _terminateError: Error | undefined; + private _closeError: Error | undefined; + private _terminatePromises = new Map, string[]>(); + private _isClosed = false; + + reject(error: Error) { + this._isClosed = true; + this._terminateError = error; + for (const p of this._terminatePromises.keys()) + p.resolve(error); + } + + close(error: Error) { + this._isClosed = true; + this._closeError = error; + for (const [p, frames] of this._terminatePromises) + p.resolve(cloneError(error, frames)); + } + + isClosed() { + return this._isClosed; + } + + static async raceMultiple(scopes: LongStandingScope[], promise: Promise): Promise { + return Promise.race(scopes.map(s => s.race(promise))); + } + + async race(promise: Promise | Promise[]): Promise { + return this._race(Array.isArray(promise) ? promise : [promise], false) as Promise; + } + + async safeRace(promise: Promise, defaultValue?: T): Promise { + return this._race([promise], true, defaultValue); + } + + private async _race(promises: Promise[], safe: boolean, defaultValue?: any): Promise { + const terminatePromise = new ManualPromise(); + const frames = captureRawStack(); + if (this._terminateError) + terminatePromise.resolve(this._terminateError); + if (this._closeError) + terminatePromise.resolve(cloneError(this._closeError, frames)); + this._terminatePromises.set(terminatePromise, frames); + try { + return await Promise.race([ + terminatePromise.then(e => safe ? defaultValue : Promise.reject(e)), + ...promises + ]); + } finally { + this._terminatePromises.delete(terminatePromise); + } + } +} + +function cloneError(error: Error, frames: string[]) { + const clone = new Error(); + clone.name = error.name; + clone.message = error.message; + clone.stack = [error.name + ':' + error.message, ...frames].join('\n'); + return clone; +} + +function captureRawStack(): string[] { + const stackTraceLimit = Error.stackTraceLimit; + Error.stackTraceLimit = 50; + const error = new Error(); + const stack = error.stack || ''; + Error.stackTraceLimit = stackTraceLimit; + return stack.split('\n'); +} diff --git a/src/tools/dialogs.ts b/src/tools/dialogs.ts new file mode 100644 index 0000000..4c08bad --- /dev/null +++ b/src/tools/dialogs.ts @@ -0,0 +1,65 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from 'zod'; +import { zodToJsonSchema } from 'zod-to-json-schema'; + +import type { ToolFactory } from './tool'; + +const handleDialogSchema = z.object({ + accept: z.boolean().describe('Whether to accept the dialog.'), + promptText: z.string().optional().describe('The text of the prompt in case of a prompt dialog.'), +}); + +const handleDialog: ToolFactory = captureSnapshot => ({ + capability: 'core', + + schema: { + name: 'browser_handle_dialog', + description: 'Handle a dialog', + inputSchema: zodToJsonSchema(handleDialogSchema), + }, + + handle: async (context, params) => { + const validatedParams = handleDialogSchema.parse(params); + const dialogState = context.modalStates().find(state => state.type === 'dialog'); + if (!dialogState) + throw new Error('No dialog visible'); + + if (validatedParams.accept) + await dialogState.dialog.accept(validatedParams.promptText); + else + await dialogState.dialog.dismiss(); + + context.clearModalState(dialogState); + + const code = [ + `// `, + ]; + + return { + code, + captureSnapshot, + waitForNetwork: false, + }; + }, + + clearsModalState: 'dialog', +}); + +export default (captureSnapshot: boolean) => [ + handleDialog(captureSnapshot), +]; diff --git a/src/tools/tool.ts b/src/tools/tool.ts index d80fddf..d161129 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -32,14 +32,22 @@ export type FileUploadModalState = { fileChooser: playwright.FileChooser; }; -export type ModalState = FileUploadModalState; +export type DialogModalState = { + type: 'dialog'; + description: string; + dialog: playwright.Dialog; +}; + +export type ModalState = FileUploadModalState | DialogModalState; + +export type ToolActionResult = { content?: (ImageContent | TextContent)[] } | undefined | void; export type ToolResult = { code: string[]; - action?: () => Promise<{ content?: (ImageContent | TextContent)[] } | undefined | void>; + action?: () => Promise; captureSnapshot: boolean; waitForNetwork: boolean; - resultOverride?: { content?: (ImageContent | TextContent)[] }; + resultOverride?: ToolActionResult; }; export type Tool = { diff --git a/src/tools/utils.ts b/src/tools/utils.ts index dd9790d..88150c3 100644 --- a/src/tools/utils.ts +++ b/src/tools/utils.ts @@ -15,8 +15,9 @@ */ import type * as playwright from 'playwright'; +import type { Context } from '../context'; -export async function waitForCompletion(page: playwright.Page, callback: () => Promise): Promise { +export async function waitForCompletion(context: Context, page: playwright.Page, callback: () => Promise): Promise { const requests = new Set(); let frameNavigated = false; let waitCallback: () => void = () => {}; @@ -62,7 +63,7 @@ export async function waitForCompletion(page: playwright.Page, callback: () = if (!requests.size && !frameNavigated) waitCallback(); await waitBarrier; - await page.evaluate(() => new Promise(f => setTimeout(f, 1000))); + await context.waitForTimeout(1000); return result; } finally { dispose(); diff --git a/tests/capabilities.spec.ts b/tests/capabilities.spec.ts index ea05526..fd8070b 100644 --- a/tests/capabilities.spec.ts +++ b/tests/capabilities.spec.ts @@ -23,6 +23,7 @@ test('test snapshot tool list', async ({ client }) => { 'browser_console_messages', 'browser_drag', 'browser_file_upload', + 'browser_handle_dialog', 'browser_hover', 'browser_select_option', 'browser_type', @@ -50,6 +51,7 @@ test('test vision tool list', async ({ visionClient }) => { 'browser_close', 'browser_console_messages', 'browser_file_upload', + 'browser_handle_dialog', 'browser_install', 'browser_navigate_back', 'browser_navigate_forward', diff --git a/tests/dialogs.spec.ts b/tests/dialogs.spec.ts new file mode 100644 index 0000000..ac9fefb --- /dev/null +++ b/tests/dialogs.spec.ts @@ -0,0 +1,192 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { test, expect } from './fixtures'; + +// https://github.com/microsoft/playwright/issues/35663 +test.skip(({ mcpBrowser, mcpHeadless }) => mcpBrowser === 'webkit' && mcpHeadless); + +test('alert dialog', async ({ client }) => { + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { + url: 'data:text/html,Title', + }, + })).toContainTextContent('- button "Button" [ref=s1e3]'); + + expect(await client.callTool({ + name: 'browser_click', + arguments: { + element: 'Button', + ref: 's1e3', + }, + })).toHaveTextContent(`- Ran Playwright code: +\`\`\`js +// Click Button +await page.getByRole('button', { name: 'Button' }).click(); +\`\`\` + +### Modal state +- ["alert" dialog with message "Alert"]: can be handled by the "browser_handle_dialog" tool`); + + const result = await client.callTool({ + name: 'browser_handle_dialog', + arguments: { + accept: true, + }, + }); + + expect(result).not.toContainTextContent('### Modal state'); + expect(result).toHaveTextContent(`- Ran Playwright code: +\`\`\`js +// +\`\`\` + +- Page URL: data:text/html,Title +- Page Title: Title +- Page Snapshot +\`\`\`yaml +- button "Button" [ref=s2e3] +\`\`\` +`); +}); + +test('two alert dialogs', async ({ client }) => { + test.fixme(true, 'Race between the dialog and ariaSnapshot'); + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { + url: 'data:text/html,Title', + }, + })).toContainTextContent('- button "Button" [ref=s1e3]'); + + expect(await client.callTool({ + name: 'browser_click', + arguments: { + element: 'Button', + ref: 's1e3', + }, + })).toHaveTextContent(`- Ran Playwright code: +\`\`\`js +// Click Button +await page.getByRole('button', { name: 'Button' }).click(); +\`\`\` + +### Modal state +- ["alert" dialog with message "Alert 1"]: can be handled by the "browser_handle_dialog" tool`); + + const result = await client.callTool({ + name: 'browser_handle_dialog', + arguments: { + accept: true, + }, + }); + + expect(result).not.toContainTextContent('### Modal state'); +}); + +test('confirm dialog (true)', async ({ client }) => { + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { + url: 'data:text/html,Title', + }, + })).toContainTextContent('- button "Button" [ref=s1e3]'); + + expect(await client.callTool({ + name: 'browser_click', + arguments: { + element: 'Button', + ref: 's1e3', + }, + })).toContainTextContent(`### Modal state +- ["confirm" dialog with message "Confirm"]: can be handled by the "browser_handle_dialog" tool`); + + const result = await client.callTool({ + name: 'browser_handle_dialog', + arguments: { + accept: true, + }, + }); + + expect(result).not.toContainTextContent('### Modal state'); + expect(result).toContainTextContent('// '); + expect(result).toContainTextContent(`- Page Snapshot +\`\`\`yaml +- text: "true" +\`\`\``); +}); + +test('confirm dialog (false)', async ({ client }) => { + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { + url: 'data:text/html,Title', + }, + })).toContainTextContent('- button "Button" [ref=s1e3]'); + + expect(await client.callTool({ + name: 'browser_click', + arguments: { + element: 'Button', + ref: 's1e3', + }, + })).toContainTextContent(`### Modal state +- ["confirm" dialog with message "Confirm"]: can be handled by the "browser_handle_dialog" tool`); + + const result = await client.callTool({ + name: 'browser_handle_dialog', + arguments: { + accept: false, + }, + }); + + expect(result).toContainTextContent(`- Page Snapshot +\`\`\`yaml +- text: "false" +\`\`\``); +}); + +test('prompt dialog', async ({ client }) => { + expect(await client.callTool({ + name: 'browser_navigate', + arguments: { + url: 'data:text/html,Title', + }, + })).toContainTextContent('- button "Button" [ref=s1e3]'); + + expect(await client.callTool({ + name: 'browser_click', + arguments: { + element: 'Button', + ref: 's1e3', + }, + })).toContainTextContent(`### Modal state +- ["prompt" dialog with message "Prompt"]: can be handled by the "browser_handle_dialog" tool`); + + const result = await client.callTool({ + name: 'browser_handle_dialog', + arguments: { + accept: true, + promptText: 'Answer', + }, + }); + + expect(result).toContainTextContent(`- Page Snapshot +\`\`\`yaml +- text: Answer +\`\`\``); +}); diff --git a/tests/fixtures.ts b/tests/fixtures.ts index 4601271..cd7b4fd 100644 --- a/tests/fixtures.ts +++ b/tests/fixtures.ts @@ -22,19 +22,20 @@ import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { spawn } from 'child_process'; -type Fixtures = { +type TestFixtures = { client: Client; visionClient: Client; startClient: (options?: { args?: string[] }) => Promise; wsEndpoint: string; cdpEndpoint: string; +}; - // Cli options. +type WorkerFixtures = { mcpHeadless: boolean; mcpBrowser: string | undefined; }; -export const test = baseTest.extend({ +export const test = baseTest.extend({ client: async ({ startClient }, use) => { await use(await startClient()); @@ -98,11 +99,11 @@ export const test = baseTest.extend({ browserProcess.kill(); }, - mcpHeadless: async ({ headless }, use) => { + mcpHeadless: [async ({ headless }, use) => { await use(headless); - }, + }, { scope: 'worker' }], - mcpBrowser: ['chromium', { option: true }], + mcpBrowser: ['chromium', { option: true, scope: 'worker' }], }); type Response = Awaited>;