From 89627fd23a90e1868f3e55c78324f0f1ced3623b Mon Sep 17 00:00:00 2001 From: Pavel Feldman Date: Wed, 2 Apr 2025 11:42:39 -0700 Subject: [PATCH] chore: extract page snapshot, prep for multipage (#120) --- src/context.ts | 121 ++++++++++++++++++++++++++++++++++------ src/index.ts | 15 ++--- src/tools/common.ts | 58 +++++++++++-------- src/tools/screenshot.ts | 14 +++-- src/tools/snapshot.ts | 44 ++++++++++----- src/tools/utils.ts | 40 +------------ tests/basic.spec.ts | 16 +++--- 7 files changed, 194 insertions(+), 114 deletions(-) diff --git a/src/context.ts b/src/context.ts index 2467955..54c7994 100644 --- a/src/context.ts +++ b/src/context.ts @@ -20,6 +20,9 @@ import path from 'path'; import * as playwright from 'playwright'; import yaml from 'yaml'; +import { waitForCompletion } from './tools/utils'; +import { ToolResult } from './tools/tool'; + export type ContextOptions = { browserName?: 'chromium' | 'firefox' | 'webkit'; userDataDir: string; @@ -28,6 +31,15 @@ export type ContextOptions = { remoteEndpoint?: string; }; +type PageOrFrameLocator = playwright.Page | playwright.FrameLocator; + +type RunOptions = { + captureSnapshot?: boolean; + waitForCompletion?: boolean; + status?: string; + noClearFileChooser?: boolean; +}; + export class Context { private _options: ContextOptions; private _browser: playwright.Browser | undefined; @@ -35,7 +47,7 @@ export class Context { private _console: playwright.ConsoleMessage[] = []; private _createPagePromise: Promise | undefined; private _fileChooser: playwright.FileChooser | undefined; - private _lastSnapshotFrames: (playwright.Page | playwright.FrameLocator)[] = []; + private _snapshot: PageSnapshot | undefined; constructor(options: ContextOptions) { this._options = options; @@ -99,6 +111,48 @@ export class Context { return this._page; } + async run(callback: (page: playwright.Page) => Promise, options?: RunOptions): Promise { + const page = this.existingPage(); + try { + if (!options?.noClearFileChooser) + this._fileChooser = undefined; + if (options?.waitForCompletion) + await waitForCompletion(page, () => callback(page)); + else + await callback(page); + } finally { + if (options?.captureSnapshot) + this._snapshot = await PageSnapshot.create(page); + } + return { + content: [{ + type: 'text', + text: this._snapshot?.text({ status: options?.status, hasFileChooser: !!this._fileChooser }) ?? options?.status ?? '', + }], + }; + } + + async runAndWait(callback: (page: playwright.Page) => Promise, options?: RunOptions): Promise { + return await this.run(callback, { + waitForCompletion: true, + ...options, + }); + } + + async runAndWaitWithSnapshot(callback: (page: playwright.Page) => Promise, options?: RunOptions): Promise { + return await this.run(callback, { + captureSnapshot: true, + waitForCompletion: true, + ...options, + }); + } + + lastSnapshot(): PageSnapshot { + if (!this._snapshot) + throw new Error('No snapshot available'); + return this._snapshot; + } + async console(): Promise { return this._console; } @@ -116,14 +170,6 @@ export class Context { this._fileChooser = undefined; } - hasFileChooser() { - return !!this._fileChooser; - } - - clearFileChooser() { - this._fileChooser = undefined; - } - private async _createPage(): Promise<{ browser?: playwright.Browser, page: playwright.Page }> { if (this._options.remoteEndpoint) { const url = new URL(this._options.remoteEndpoint); @@ -160,15 +206,54 @@ export class Context { throw error; } } +} - async allFramesSnapshot() { - this._lastSnapshotFrames = []; - const yaml = await this._allFramesSnapshot(this.existingPage()); - return yaml.toString().trim(); +class PageSnapshot { + private _frameLocators: PageOrFrameLocator[] = []; + private _text!: string; + + constructor() { } - private async _allFramesSnapshot(frame: playwright.Page | playwright.FrameLocator): Promise { - const frameIndex = this._lastSnapshotFrames.push(frame) - 1; + static async create(page: playwright.Page): Promise { + const snapshot = new PageSnapshot(); + await snapshot._build(page); + return snapshot; + } + + text(options?: { status?: string, hasFileChooser?: boolean }): string { + const results: string[] = []; + if (options?.status) { + results.push(options.status); + results.push(''); + } + if (options?.hasFileChooser) { + results.push('- There is a file chooser visible that requires browser_choose_file to be called'); + results.push(''); + } + results.push(this._text); + return results.join('\n'); + } + + private async _build(page: playwright.Page) { + const yamlDocument = await this._snapshotFrame(page); + const lines = []; + lines.push( + `- Page URL: ${page.url()}`, + `- Page Title: ${await page.title()}` + ); + lines.push( + `- Page Snapshot`, + '```yaml', + yamlDocument.toString().trim(), + '```', + '' + ); + this._text = lines.join('\n'); + } + + private async _snapshotFrame(frame: playwright.Page | playwright.FrameLocator) { + const frameIndex = this._frameLocators.push(frame) - 1; const snapshotString = await frame.locator('body').ariaSnapshot({ ref: true }); const snapshot = yaml.parseDocument(snapshotString); @@ -189,7 +274,7 @@ export class Context { const ref = value.match(/\[ref=(.*)\]/)?.[1]; if (ref) { try { - const childSnapshot = await this._allFramesSnapshot(frame.frameLocator(`aria-ref=${ref}`)); + const childSnapshot = await this._snapshotFrame(frame.frameLocator(`aria-ref=${ref}`)); return snapshot.createPair(node.value, childSnapshot); } catch (error) { return snapshot.createPair(node.value, ''); @@ -206,11 +291,11 @@ export class Context { } refLocator(ref: string): playwright.Locator { - let frame = this._lastSnapshotFrames[0]; + let frame = this._frameLocators[0]; const match = ref.match(/^f(\d+)(.*)/); if (match) { const frameIndex = parseInt(match[1], 10); - frame = this._lastSnapshotFrames[frameIndex]; + frame = this._frameLocators[frameIndex]; ref = match[2]; } diff --git a/src/index.ts b/src/index.ts index a54a674..c5969d3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -26,7 +26,6 @@ import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { LaunchOptions } from 'playwright'; const commonTools: Tool[] = [ - common.pressKey, common.wait, common.pdf, common.close, @@ -35,28 +34,30 @@ const commonTools: Tool[] = [ const snapshotTools: Tool[] = [ common.navigate(true), - common.goBack(true), - common.goForward(true), - common.chooseFile(true), snapshot.snapshot, snapshot.click, snapshot.hover, snapshot.type, snapshot.selectOption, snapshot.screenshot, + common.goBack(true), + common.goForward(true), + common.chooseFile(true), + common.pressKey(true), ...commonTools, ]; const screenshotTools: Tool[] = [ common.navigate(false), - common.goBack(false), - common.goForward(false), - common.chooseFile(false), screenshot.screenshot, screenshot.moveMouse, screenshot.click, screenshot.drag, screenshot.type, + common.goBack(false), + common.goForward(false), + common.chooseFile(false), + common.pressKey(false), ...commonTools, ]; diff --git a/src/tools/common.ts b/src/tools/common.ts index 6ef52e2..c91d111 100644 --- a/src/tools/common.ts +++ b/src/tools/common.ts @@ -20,7 +20,7 @@ import path from 'path'; import { z } from 'zod'; import { zodToJsonSchema } from 'zod-to-json-schema'; -import { captureAriaSnapshot, runAndWait, sanitizeForFilePath } from './utils'; +import { sanitizeForFilePath } from './utils'; import type { ToolFactory, Tool } from './tool'; @@ -28,7 +28,7 @@ const navigateSchema = z.object({ url: z.string().describe('The URL to navigate to'), }); -export const navigate: ToolFactory = snapshot => ({ +export const navigate: ToolFactory = captureSnapshot => ({ schema: { name: 'browser_navigate', description: 'Navigate to a URL', @@ -36,18 +36,15 @@ export const navigate: ToolFactory = snapshot => ({ }, handle: async (context, params) => { const validatedParams = navigateSchema.parse(params); - const page = await context.createPage(); - await page.goto(validatedParams.url, { waitUntil: 'domcontentloaded' }); - // Cap load event to 5 seconds, the page is operational at this point. - await page.waitForLoadState('load', { timeout: 5000 }).catch(() => {}); - if (snapshot) - return captureAriaSnapshot(context); - return { - content: [{ - type: 'text', - text: `Navigated to ${validatedParams.url}`, - }], - }; + await context.createPage(); + return await context.run(async page => { + await page.goto(validatedParams.url, { waitUntil: 'domcontentloaded' }); + // Cap load event to 5 seconds, the page is operational at this point. + await page.waitForLoadState('load', { timeout: 5000 }).catch(() => {}); + }, { + status: `Navigated to ${validatedParams.url}`, + captureSnapshot, + }); }, }); @@ -60,7 +57,12 @@ export const goBack: ToolFactory = snapshot => ({ inputSchema: zodToJsonSchema(goBackSchema), }, handle: async context => { - return await runAndWait(context, 'Navigated back', async page => page.goBack(), snapshot); + return await context.runAndWait(async page => { + await page.goBack(); + }, { + status: 'Navigated back', + captureSnapshot: snapshot, + }); }, }); @@ -73,7 +75,12 @@ export const goForward: ToolFactory = snapshot => ({ inputSchema: zodToJsonSchema(goForwardSchema), }, handle: async context => { - return await runAndWait(context, 'Navigated forward', async page => page.goForward(), snapshot); + return await context.runAndWait(async page => { + await page.goForward(); + }, { + status: 'Navigated forward', + captureSnapshot: snapshot, + }); }, }); @@ -103,7 +110,7 @@ const pressKeySchema = z.object({ key: z.string().describe('Name of the key to press or a character to generate, such as `ArrowLeft` or `a`'), }); -export const pressKey: Tool = { +export const pressKey: (captureSnapshot: boolean) => Tool = captureSnapshot => ({ schema: { name: 'browser_press_key', description: 'Press a key on the keyboard', @@ -111,11 +118,14 @@ export const pressKey: Tool = { }, handle: async (context, params) => { const validatedParams = pressKeySchema.parse(params); - return await runAndWait(context, `Pressed key ${validatedParams.key}`, async page => { + return await context.runAndWait(async page => { await page.keyboard.press(validatedParams.key); + }, { + status: `Pressed key ${validatedParams.key}`, + captureSnapshot, }); }, -}; +}); const pdfSchema = z.object({}); @@ -161,7 +171,7 @@ const chooseFileSchema = z.object({ paths: z.array(z.string()).describe('The absolute paths to the files to upload. Can be a single file or multiple files.'), }); -export const chooseFile: ToolFactory = snapshot => ({ +export const chooseFile: ToolFactory = captureSnapshot => ({ schema: { name: 'browser_choose_file', description: 'Choose one or multiple files to upload', @@ -169,9 +179,13 @@ export const chooseFile: ToolFactory = snapshot => ({ }, handle: async (context, params) => { const validatedParams = chooseFileSchema.parse(params); - return await runAndWait(context, `Chose files ${validatedParams.paths.join(', ')}`, async () => { + return await context.runAndWait(async () => { await context.submitFileChooser(validatedParams.paths); - }, snapshot); + }, { + status: `Chose files ${validatedParams.paths.join(', ')}`, + captureSnapshot, + noClearFileChooser: true, + }); }, }); diff --git a/src/tools/screenshot.ts b/src/tools/screenshot.ts index cef2ba1..03e75f9 100644 --- a/src/tools/screenshot.ts +++ b/src/tools/screenshot.ts @@ -17,8 +17,6 @@ import { z } from 'zod'; import { zodToJsonSchema } from 'zod-to-json-schema'; -import { runAndWait } from './utils'; - import type { Tool } from './tool'; export const screenshot: Tool = { @@ -76,11 +74,13 @@ export const click: Tool = { }, handle: async (context, params) => { - return await runAndWait(context, 'Clicked mouse', async page => { + return await context.runAndWait(async page => { const validatedParams = clickSchema.parse(params); await page.mouse.move(validatedParams.x, validatedParams.y); await page.mouse.down(); await page.mouse.up(); + }, { + status: 'Clicked mouse', }); }, }; @@ -101,11 +101,13 @@ export const drag: Tool = { handle: async (context, params) => { const validatedParams = dragSchema.parse(params); - return await runAndWait(context, `Dragged mouse from (${validatedParams.startX}, ${validatedParams.startY}) to (${validatedParams.endX}, ${validatedParams.endY})`, async page => { + return await context.runAndWait(async page => { await page.mouse.move(validatedParams.startX, validatedParams.startY); await page.mouse.down(); await page.mouse.move(validatedParams.endX, validatedParams.endY); await page.mouse.up(); + }, { + status: `Dragged mouse from (${validatedParams.startX}, ${validatedParams.startY}) to (${validatedParams.endX}, ${validatedParams.endY})`, }); }, }; @@ -124,10 +126,12 @@ export const type: Tool = { handle: async (context, params) => { const validatedParams = typeSchema.parse(params); - return await runAndWait(context, `Typed text "${validatedParams.text}"`, async page => { + return await context.runAndWait(async page => { await page.keyboard.type(validatedParams.text); if (validatedParams.submit) await page.keyboard.press('Enter'); + }, { + status: `Typed text "${validatedParams.text}"`, }); }, }; diff --git a/src/tools/snapshot.ts b/src/tools/snapshot.ts index 4e75805..b0b9e12 100644 --- a/src/tools/snapshot.ts +++ b/src/tools/snapshot.ts @@ -17,8 +17,6 @@ import { z } from 'zod'; import zodToJsonSchema from 'zod-to-json-schema'; -import { captureAriaSnapshot, runAndWait } from './utils'; - import type * as playwright from 'playwright'; import type { Tool } from './tool'; @@ -30,7 +28,7 @@ export const snapshot: Tool = { }, handle: async context => { - return await captureAriaSnapshot(context); + return await context.run(async () => {}, { captureSnapshot: true }); }, }; @@ -48,7 +46,12 @@ export const click: Tool = { handle: async (context, params) => { const validatedParams = elementSchema.parse(params); - return runAndWait(context, `"${validatedParams.element}" clicked`, () => context.refLocator(validatedParams.ref).click(), true); + return await context.runAndWaitWithSnapshot(async () => { + const locator = context.lastSnapshot().refLocator(validatedParams.ref); + await locator.click(); + }, { + status: `Clicked "${validatedParams.element}"`, + }); }, }; @@ -68,11 +71,13 @@ export const drag: Tool = { handle: async (context, params) => { const validatedParams = dragSchema.parse(params); - return runAndWait(context, `Dragged "${validatedParams.startElement}" to "${validatedParams.endElement}"`, async () => { - const startLocator = context.refLocator(validatedParams.startRef); - const endLocator = context.refLocator(validatedParams.endRef); + return await context.runAndWaitWithSnapshot(async () => { + const startLocator = context.lastSnapshot().refLocator(validatedParams.startRef); + const endLocator = context.lastSnapshot().refLocator(validatedParams.endRef); await startLocator.dragTo(endLocator); - }, true); + }, { + status: `Dragged "${validatedParams.startElement}" to "${validatedParams.endElement}"`, + }); }, }; @@ -85,7 +90,12 @@ export const hover: Tool = { handle: async (context, params) => { const validatedParams = elementSchema.parse(params); - return runAndWait(context, `Hovered over "${validatedParams.element}"`, () => context.refLocator(validatedParams.ref).hover(), true); + return context.runAndWaitWithSnapshot(async () => { + const locator = context.lastSnapshot().refLocator(validatedParams.ref); + await locator.hover(); + }, { + status: `Hovered over "${validatedParams.element}"`, + }); }, }; @@ -103,12 +113,14 @@ export const type: Tool = { handle: async (context, params) => { const validatedParams = typeSchema.parse(params); - return await runAndWait(context, `Typed "${validatedParams.text}" into "${validatedParams.element}"`, async () => { - const locator = context.refLocator(validatedParams.ref); + return await context.runAndWaitWithSnapshot(async () => { + const locator = context.lastSnapshot().refLocator(validatedParams.ref); await locator.fill(validatedParams.text); if (validatedParams.submit) await locator.press('Enter'); - }, true); + }, { + status: `Typed "${validatedParams.text}" into "${validatedParams.element}"`, + }); }, }; @@ -125,10 +137,12 @@ export const selectOption: Tool = { handle: async (context, params) => { const validatedParams = selectOptionSchema.parse(params); - return await runAndWait(context, `Selected option in "${validatedParams.element}"`, async () => { - const locator = context.refLocator(validatedParams.ref); + return await context.runAndWaitWithSnapshot(async () => { + const locator = context.lastSnapshot().refLocator(validatedParams.ref); await locator.selectOption(validatedParams.values); - }, true); + }, { + status: `Selected option in "${validatedParams.element}"`, + }); }, }; diff --git a/src/tools/utils.ts b/src/tools/utils.ts index e9c984e..dd9790d 100644 --- a/src/tools/utils.ts +++ b/src/tools/utils.ts @@ -15,10 +15,8 @@ */ import type * as playwright from 'playwright'; -import type { ToolResult } from './tool'; -import type { Context } from '../context'; -async function waitForCompletion(page: playwright.Page, callback: () => Promise): Promise { +export async function waitForCompletion(page: playwright.Page, callback: () => Promise): Promise { const requests = new Set(); let frameNavigated = false; let waitCallback: () => void = () => {}; @@ -71,42 +69,6 @@ async function waitForCompletion(page: playwright.Page, callback: () => Promi } } -export async function runAndWait(context: Context, status: string, callback: (page: playwright.Page) => Promise, snapshot: boolean = false): Promise { - const page = context.existingPage(); - const dismissFileChooser = context.hasFileChooser(); - await waitForCompletion(page, () => callback(page)); - if (dismissFileChooser) - context.clearFileChooser(); - const result: ToolResult = snapshot ? await captureAriaSnapshot(context, status) : { - content: [{ type: 'text', text: status }], - }; - return result; -} - -export async function captureAriaSnapshot(context: Context, status: string = ''): Promise { - const page = context.existingPage(); - const lines = []; - if (status) - lines.push(`${status}`); - lines.push( - '', - `- Page URL: ${page.url()}`, - `- Page Title: ${await page.title()}` - ); - if (context.hasFileChooser()) - lines.push(`- There is a file chooser visible that requires browser_choose_file to be called`); - lines.push( - `- Page Snapshot`, - '```yaml', - await context.allFramesSnapshot(), - '```', - '' - ); - return { - content: [{ type: 'text', text: lines.join('\n') }], - }; -} - export function sanitizeForFilePath(s: string) { return s.replace(/[\x00-\x2C\x2E-\x2F\x3A-\x40\x5B-\x60\x7B-\x7F]+/g, '-'); } diff --git a/tests/basic.spec.ts b/tests/basic.spec.ts index 6d4640c..663a583 100644 --- a/tests/basic.spec.ts +++ b/tests/basic.spec.ts @@ -23,15 +23,15 @@ test('test tool list', async ({ client, visionClient }) => { const { tools } = await client.listTools(); expect(tools.map(t => t.name)).toEqual([ 'browser_navigate', - 'browser_go_back', - 'browser_go_forward', - 'browser_choose_file', 'browser_snapshot', 'browser_click', 'browser_hover', 'browser_type', 'browser_select_option', 'browser_take_screenshot', + 'browser_go_back', + 'browser_go_forward', + 'browser_choose_file', 'browser_press_key', 'browser_wait', 'browser_save_as_pdf', @@ -42,14 +42,14 @@ test('test tool list', async ({ client, visionClient }) => { const { tools: visionTools } = await visionClient.listTools(); expect(visionTools.map(t => t.name)).toEqual([ 'browser_navigate', - 'browser_go_back', - 'browser_go_forward', - 'browser_choose_file', 'browser_screenshot', 'browser_move_mouse', 'browser_click', 'browser_drag', 'browser_type', + 'browser_go_back', + 'browser_go_forward', + 'browser_choose_file', 'browser_press_key', 'browser_wait', 'browser_save_as_pdf', @@ -99,7 +99,7 @@ test('test browser_click', async ({ client }) => { element: 'Submit button', ref: 's1e3', }, - })).toHaveTextContent(`"Submit button" clicked + })).toHaveTextContent(`Clicked "Submit button" - Page URL: data:text/html,Title - Page Title: Title @@ -235,7 +235,7 @@ test('stitched aria frames', async ({ client }) => { element: 'World', ref: 'f1s1e3', }, - })).toContainTextContent('"World" clicked'); + })).toContainTextContent('Clicked "World"'); }); test('browser_choose_file', async ({ client }) => {