mirror of
https://github.com/microsoft/playwright-mcp.git
synced 2025-07-27 00:52:27 +08:00
chore: refactor server, prepare for browser reuse (#490)
This commit is contained in:
parent
3cd74a824a
commit
54ed7c3200
@ -14,22 +14,22 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js';
|
||||||
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
|
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||||
|
|
||||||
import { Context, packageJSON } from './context.js';
|
import { Context } from './context.js';
|
||||||
import { snapshotTools, visionTools } from './tools.js';
|
import { snapshotTools, visionTools } from './tools.js';
|
||||||
|
import { packageJSON } from './package.js';
|
||||||
|
|
||||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
|
||||||
import { FullConfig } from './config.js';
|
import { FullConfig } from './config.js';
|
||||||
|
|
||||||
export async function createConnection(config: FullConfig): Promise<Connection> {
|
export function createConnection(config: FullConfig): Connection {
|
||||||
const allTools = config.vision ? visionTools : snapshotTools;
|
const allTools = config.vision ? visionTools : snapshotTools;
|
||||||
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
|
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));
|
||||||
|
|
||||||
const context = new Context(tools, config);
|
const context = new Context(tools, config);
|
||||||
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
|
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, {
|
||||||
capabilities: {
|
capabilities: {
|
||||||
tools: {},
|
tools: {},
|
||||||
}
|
}
|
||||||
@ -74,25 +74,19 @@ export async function createConnection(config: FullConfig): Promise<Connection>
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
const connection = new Connection(server, context);
|
return new Connection(server, context);
|
||||||
return connection;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Connection {
|
export class Connection {
|
||||||
readonly server: Server;
|
readonly server: McpServer;
|
||||||
readonly context: Context;
|
readonly context: Context;
|
||||||
|
|
||||||
constructor(server: Server, context: Context) {
|
constructor(server: McpServer, context: Context) {
|
||||||
this.server = server;
|
this.server = server;
|
||||||
this.context = context;
|
this.context = context;
|
||||||
}
|
this.server.oninitialized = () => {
|
||||||
|
|
||||||
async connect(transport: Transport) {
|
|
||||||
await this.server.connect(transport);
|
|
||||||
await new Promise<void>(resolve => {
|
|
||||||
this.server.oninitialized = () => resolve();
|
|
||||||
});
|
|
||||||
this.context.clientVersion = this.server.getClientVersion();
|
this.context.clientVersion = this.server.getClientVersion();
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async close() {
|
async close() {
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import fs from 'node:fs';
|
import fs from 'node:fs';
|
||||||
import url from 'node:url';
|
|
||||||
import os from 'node:os';
|
import os from 'node:os';
|
||||||
import path from 'node:path';
|
import path from 'node:path';
|
||||||
|
|
||||||
@ -416,6 +415,3 @@ async function createUserDataDir(browserConfig: FullConfig['browser']) {
|
|||||||
await fs.promises.mkdir(result, { recursive: true });
|
await fs.promises.mkdir(result, { recursive: true });
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
const __filename = url.fileURLToPath(import.meta.url);
|
|
||||||
export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8'));
|
|
||||||
|
22
src/package.ts
Normal file
22
src/package.ts
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) Microsoft Corporation.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import fs from 'node:fs';
|
||||||
|
import url from 'node:url';
|
||||||
|
import path from 'node:path';
|
||||||
|
|
||||||
|
const __filename = url.fileURLToPath(import.meta.url);
|
||||||
|
export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8'));
|
@ -15,14 +15,13 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { program } from 'commander';
|
import { program } from 'commander';
|
||||||
|
|
||||||
import { startHttpTransport, startStdioTransport } from './transport.js';
|
|
||||||
import { resolveCLIConfig } from './config.js';
|
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
import { startTraceViewerServer } from 'playwright-core/lib/server';
|
import { startTraceViewerServer } from 'playwright-core/lib/server';
|
||||||
|
|
||||||
import type { Connection } from './connection.js';
|
import { startHttpTransport, startStdioTransport } from './transport.js';
|
||||||
import { packageJSON } from './context.js';
|
import { resolveCLIConfig } from './config.js';
|
||||||
|
import { Server } from './server.js';
|
||||||
|
import { packageJSON } from './package.js';
|
||||||
|
|
||||||
program
|
program
|
||||||
.version('Version ' + packageJSON.version)
|
.version('Version ' + packageJSON.version)
|
||||||
@ -54,13 +53,13 @@ program
|
|||||||
.option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)')
|
.option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)')
|
||||||
.action(async options => {
|
.action(async options => {
|
||||||
const config = await resolveCLIConfig(options);
|
const config = await resolveCLIConfig(options);
|
||||||
const connectionList: Connection[] = [];
|
const server = new Server(config);
|
||||||
setupExitWatchdog(connectionList);
|
server.setupExitWatchdog();
|
||||||
|
|
||||||
if (options.port)
|
if (options.port)
|
||||||
startHttpTransport(config, +options.port, options.host, connectionList);
|
startHttpTransport(server, +options.port, options.host);
|
||||||
else
|
else
|
||||||
await startStdioTransport(config, connectionList);
|
await startStdioTransport(server);
|
||||||
|
|
||||||
if (config.saveTrace) {
|
if (config.saveTrace) {
|
||||||
const server = await startTraceViewerServer();
|
const server = await startTraceViewerServer();
|
||||||
@ -71,21 +70,8 @@ program
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
function setupExitWatchdog(connectionList: Connection[]) {
|
|
||||||
const handleExit = async () => {
|
|
||||||
setTimeout(() => process.exit(0), 15000);
|
|
||||||
for (const connection of connectionList)
|
|
||||||
await connection.close();
|
|
||||||
process.exit(0);
|
|
||||||
};
|
|
||||||
|
|
||||||
process.stdin.on('close', handleExit);
|
|
||||||
process.on('SIGINT', handleExit);
|
|
||||||
process.on('SIGTERM', handleExit);
|
|
||||||
}
|
|
||||||
|
|
||||||
function semicolonSeparatedList(value: string): string[] {
|
function semicolonSeparatedList(value: string): string[] {
|
||||||
return value.split(';').map(v => v.trim());
|
return value.split(';').map(v => v.trim());
|
||||||
}
|
}
|
||||||
|
|
||||||
program.parse(process.argv);
|
void program.parseAsync(process.argv);
|
||||||
|
49
src/server.ts
Normal file
49
src/server.ts
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) Microsoft Corporation.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createConnection } from './connection.js';
|
||||||
|
|
||||||
|
import type { FullConfig } from './config.js';
|
||||||
|
import type { Connection } from './connection.js';
|
||||||
|
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||||
|
|
||||||
|
export class Server {
|
||||||
|
readonly config: FullConfig;
|
||||||
|
private _connectionList: Connection[] = [];
|
||||||
|
|
||||||
|
constructor(config: FullConfig) {
|
||||||
|
this.config = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
async createConnection(transport: Transport): Promise<Connection> {
|
||||||
|
const connection = createConnection(this.config);
|
||||||
|
this._connectionList.push(connection);
|
||||||
|
await connection.server.connect(transport);
|
||||||
|
return connection;
|
||||||
|
}
|
||||||
|
|
||||||
|
setupExitWatchdog() {
|
||||||
|
const handleExit = async () => {
|
||||||
|
setTimeout(() => process.exit(0), 15000);
|
||||||
|
await Promise.all(this._connectionList.map(connection => connection.close()));
|
||||||
|
process.exit(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
process.stdin.on('close', handleExit);
|
||||||
|
process.on('SIGINT', handleExit);
|
||||||
|
process.on('SIGTERM', handleExit);
|
||||||
|
}
|
||||||
|
}
|
@ -22,18 +22,13 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
|
|||||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
||||||
|
|
||||||
import { createConnection } from './connection.js';
|
import type { Server } from './server.js';
|
||||||
|
|
||||||
import type { Connection } from './connection.js';
|
export async function startStdioTransport(server: Server) {
|
||||||
import type { FullConfig } from './config.js';
|
await server.createConnection(new StdioServerTransport());
|
||||||
|
|
||||||
export async function startStdioTransport(config: FullConfig, connectionList: Connection[]) {
|
|
||||||
const connection = await createConnection(config);
|
|
||||||
await connection.connect(new StdioServerTransport());
|
|
||||||
connectionList.push(connection);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>, connectionList: Connection[]) {
|
async function handleSSE(server: Server, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>) {
|
||||||
if (req.method === 'POST') {
|
if (req.method === 'POST') {
|
||||||
const sessionId = url.searchParams.get('sessionId');
|
const sessionId = url.searchParams.get('sessionId');
|
||||||
if (!sessionId) {
|
if (!sessionId) {
|
||||||
@ -51,15 +46,11 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt
|
|||||||
} else if (req.method === 'GET') {
|
} else if (req.method === 'GET') {
|
||||||
const transport = new SSEServerTransport('/sse', res);
|
const transport = new SSEServerTransport('/sse', res);
|
||||||
sessions.set(transport.sessionId, transport);
|
sessions.set(transport.sessionId, transport);
|
||||||
const connection = await createConnection(config);
|
const connection = await server.createConnection(transport);
|
||||||
await connection.connect(transport);
|
|
||||||
connectionList.push(connection);
|
|
||||||
res.on('close', () => {
|
res.on('close', () => {
|
||||||
sessions.delete(transport.sessionId);
|
sessions.delete(transport.sessionId);
|
||||||
connection.close().catch(e => {
|
|
||||||
// eslint-disable-next-line no-console
|
// eslint-disable-next-line no-console
|
||||||
console.error(e);
|
void connection.close().catch(e => console.error(e));
|
||||||
});
|
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -68,7 +59,7 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt
|
|||||||
res.end('Method not allowed');
|
res.end('Method not allowed');
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleStreamable(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>, connectionList: Connection[]) {
|
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
|
||||||
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
||||||
if (sessionId) {
|
if (sessionId) {
|
||||||
const transport = sessions.get(sessionId);
|
const transport = sessions.get(sessionId);
|
||||||
@ -91,12 +82,8 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r
|
|||||||
if (transport.sessionId)
|
if (transport.sessionId)
|
||||||
sessions.delete(transport.sessionId);
|
sessions.delete(transport.sessionId);
|
||||||
};
|
};
|
||||||
const connection = await createConnection(config);
|
await server.createConnection(transport);
|
||||||
connectionList.push(connection);
|
await transport.handleRequest(req, res);
|
||||||
await Promise.all([
|
|
||||||
connection.connect(transport),
|
|
||||||
transport.handleRequest(req, res),
|
|
||||||
]);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,15 +91,15 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r
|
|||||||
res.end('Invalid request');
|
res.end('Invalid request');
|
||||||
}
|
}
|
||||||
|
|
||||||
export function startHttpTransport(config: FullConfig, port: number, hostname: string | undefined, connectionList: Connection[]) {
|
export function startHttpTransport(server: Server, port: number, hostname: string | undefined) {
|
||||||
const sseSessions = new Map<string, SSEServerTransport>();
|
const sseSessions = new Map<string, SSEServerTransport>();
|
||||||
const streamableSessions = new Map<string, StreamableHTTPServerTransport>();
|
const streamableSessions = new Map<string, StreamableHTTPServerTransport>();
|
||||||
const httpServer = http.createServer(async (req, res) => {
|
const httpServer = http.createServer(async (req, res) => {
|
||||||
const url = new URL(`http://localhost${req.url}`);
|
const url = new URL(`http://localhost${req.url}`);
|
||||||
if (url.pathname.startsWith('/mcp'))
|
if (url.pathname.startsWith('/mcp'))
|
||||||
await handleStreamable(config, req, res, streamableSessions, connectionList);
|
await handleStreamable(server, req, res, streamableSessions);
|
||||||
else
|
else
|
||||||
await handleSSE(config, req, res, url, sseSessions, connectionList);
|
await handleSSE(server, req, res, url, sseSessions);
|
||||||
});
|
});
|
||||||
httpServer.listen(port, hostname, () => {
|
httpServer.listen(port, hostname, () => {
|
||||||
const address = httpServer.address();
|
const address = httpServer.address();
|
||||||
|
@ -15,17 +15,12 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import url from 'node:url';
|
import url from 'node:url';
|
||||||
import http from 'node:http';
|
|
||||||
import { spawn } from 'node:child_process';
|
import { spawn } from 'node:child_process';
|
||||||
import path from 'node:path';
|
import path from 'node:path';
|
||||||
import type { AddressInfo } from 'node:net';
|
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
|
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||||
|
|
||||||
import { createConnection } from '@playwright/mcp';
|
|
||||||
|
|
||||||
import { test as baseTest, expect } from './fixtures.js';
|
import { test as baseTest, expect } from './fixtures.js';
|
||||||
|
|
||||||
// NOTE: Can be removed when we drop Node.js 18 support and changed to import.meta.filename.
|
// NOTE: Can be removed when we drop Node.js 18 support and changed to import.meta.filename.
|
||||||
@ -55,6 +50,7 @@ test('sse transport', async ({ serverEndpoint }) => {
|
|||||||
const client = new Client({ name: 'test', version: '1.0.0' });
|
const client = new Client({ name: 'test', version: '1.0.0' });
|
||||||
await client.connect(transport);
|
await client.connect(transport);
|
||||||
await client.ping();
|
await client.ping();
|
||||||
|
await client.close();
|
||||||
});
|
});
|
||||||
|
|
||||||
test('streamable http transport', async ({ serverEndpoint }) => {
|
test('streamable http transport', async ({ serverEndpoint }) => {
|
||||||
@ -64,46 +60,3 @@ test('streamable http transport', async ({ serverEndpoint }) => {
|
|||||||
await client.ping();
|
await client.ping();
|
||||||
expect(transport.sessionId, 'has session support').toBeDefined();
|
expect(transport.sessionId, 'has session support').toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
test('sse transport via public API', async ({ server }, testInfo) => {
|
|
||||||
const userDataDir = testInfo.outputPath('user-data-dir');
|
|
||||||
const sessions = new Map<string, SSEServerTransport>();
|
|
||||||
const mcpServer = http.createServer(async (req, res) => {
|
|
||||||
if (req.method === 'GET') {
|
|
||||||
const connection = await createConnection({
|
|
||||||
browser: {
|
|
||||||
userDataDir,
|
|
||||||
launchOptions: { headless: true }
|
|
||||||
},
|
|
||||||
});
|
|
||||||
const transport = new SSEServerTransport('/sse', res);
|
|
||||||
sessions.set(transport.sessionId, transport);
|
|
||||||
await connection.connect(transport);
|
|
||||||
} else if (req.method === 'POST') {
|
|
||||||
const url = new URL(`http://localhost${req.url}`);
|
|
||||||
const sessionId = url.searchParams.get('sessionId');
|
|
||||||
if (!sessionId) {
|
|
||||||
res.statusCode = 400;
|
|
||||||
return res.end('Missing sessionId');
|
|
||||||
}
|
|
||||||
const transport = sessions.get(sessionId);
|
|
||||||
if (!transport) {
|
|
||||||
res.statusCode = 404;
|
|
||||||
return res.end('Session not found');
|
|
||||||
}
|
|
||||||
void transport.handlePostMessage(req, res);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
await new Promise<void>(resolve => mcpServer.listen(0, () => resolve()));
|
|
||||||
const serverUrl = `http://localhost:${(mcpServer.address() as AddressInfo).port}/sse`;
|
|
||||||
const transport = new SSEClientTransport(new URL(serverUrl));
|
|
||||||
const client = new Client({ name: 'test', version: '1.0.0' });
|
|
||||||
await client.connect(transport);
|
|
||||||
await client.ping();
|
|
||||||
expect(await client.callTool({
|
|
||||||
name: 'browser_navigate',
|
|
||||||
arguments: { url: server.HELLO_WORLD },
|
|
||||||
})).toContainTextContent(`- generic [ref=e1]: Hello, world!`);
|
|
||||||
await client.close();
|
|
||||||
mcpServer.close();
|
|
||||||
});
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user