diff --git a/src/program.ts b/src/program.ts index ed9cf50..91d1004 100644 --- a/src/program.ts +++ b/src/program.ts @@ -25,8 +25,8 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; import { createServer } from './index'; +import { ServerList } from './server'; -import type { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { LaunchOptions } from 'playwright'; import assert from 'assert'; @@ -44,86 +44,33 @@ program headless: !!options.headless, channel: 'chrome', }; - const server = createServer({ - userDataDir: options.userDataDir ?? await userDataDir(), + const userDataDir = options.userDataDir ?? await createUserDataDir(); + const serverList = new ServerList(() => createServer({ + userDataDir, launchOptions, vision: !!options.vision, - }); - setupExitWatchdog(server); + })); + setupExitWatchdog(serverList); if (options.port) { - const sessions = new Map(); - const httpServer = http.createServer(async (req, res) => { - if (req.method === 'POST') { - const host = req.headers.host ?? 'http://unknown'; - const sessionId = new URL(host + req.url!).searchParams.get('sessionId'); - if (!sessionId) { - res.statusCode = 400; - res.end('Missing sessionId'); - return; - } - const transport = sessions.get(sessionId); - if (!transport) { - res.statusCode = 404; - res.end('Session not found'); - return; - } - - await transport.handlePostMessage(req, res); - return; - } else if (req.method === 'GET') { - const transport = new SSEServerTransport('/sse', res); - sessions.set(transport.sessionId, transport); - res.on('close', () => { - sessions.delete(transport.sessionId); - }); - await server.connect(transport); - return; - } else { - res.statusCode = 405; - res.end('Method not allowed'); - } - }); - httpServer.listen(+options.port, () => { - const address = httpServer.address(); - assert(address, 'Could not bind server socket'); - let urlPrefixHumanReadable: string; - if (typeof address === 'string') { - urlPrefixHumanReadable = address; - } else { - const port = address.port; - let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; - if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') - resolvedHost = 'localhost'; - urlPrefixHumanReadable = `http://${resolvedHost}:${port}`; - } - console.log(`Listening on ${urlPrefixHumanReadable}`); - console.log('Put this in your client config:'); - console.log(JSON.stringify({ - 'mcpServers': { - 'playwright': { - 'url': `${urlPrefixHumanReadable}/sse` - } - } - }, undefined, 2)); - }); + startSSEServer(+options.port, serverList); } else { - const transport = new StdioServerTransport(); - await server.connect(transport); + const server = await serverList.create(); + await server.connect(new StdioServerTransport()); } }); -function setupExitWatchdog(server: Server) { +function setupExitWatchdog(serverList: ServerList) { process.stdin.on('close', async () => { setTimeout(() => process.exit(0), 15000); - await server.close(); + await serverList.closeAll(); process.exit(0); }); } program.parse(process.argv); -async function userDataDir() { +async function createUserDataDir() { let cacheDirectory: string; if (process.platform === 'linux') cacheDirectory = process.env.XDG_CACHE_HOME || path.join(os.homedir(), '.cache'); @@ -137,3 +84,64 @@ async function userDataDir() { await fs.promises.mkdir(result, { recursive: true }); return result; } + +async function startSSEServer(port: number, serverList: ServerList) { + const sessions = new Map(); + const httpServer = http.createServer(async (req, res) => { + if (req.method === 'POST') { + const host = req.headers.host ?? 'http://unknown'; + const sessionId = new URL(host + req.url!).searchParams.get('sessionId'); + if (!sessionId) { + res.statusCode = 400; + res.end('Missing sessionId'); + return; + } + const transport = sessions.get(sessionId); + if (!transport) { + res.statusCode = 404; + res.end('Session not found'); + return; + } + + await transport.handlePostMessage(req, res); + return; + } else if (req.method === 'GET') { + const transport = new SSEServerTransport('/sse', res); + sessions.set(transport.sessionId, transport); + const server = await serverList.create(); + res.on('close', () => { + sessions.delete(transport.sessionId); + serverList.close(server).catch(e => console.error(e)); + }); + await server.connect(transport); + return; + } else { + res.statusCode = 405; + res.end('Method not allowed'); + } + }); + + httpServer.listen(port, () => { + const address = httpServer.address(); + assert(address, 'Could not bind server socket'); + let url: string; + if (typeof address === 'string') { + url = address; + } else { + const resolvedPort = address.port; + let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; + if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') + resolvedHost = 'localhost'; + url = `http://${resolvedHost}:${resolvedPort}`; + } + console.log(`Listening on ${url}`); + console.log('Put this in your client config:'); + console.log(JSON.stringify({ + 'mcpServers': { + 'playwright': { + 'url': `${url}/sse` + } + } + }, undefined, 2)); + }); +} diff --git a/src/server.ts b/src/server.ts index 0a807a2..fb163c8 100644 --- a/src/server.ts +++ b/src/server.ts @@ -88,3 +88,29 @@ export function createServerWithTools(options: Options): Server { return server; } + +export class ServerList { + private _servers: Server[] = []; + private _serverFactory: () => Server; + + constructor(serverFactory: () => Server) { + this._serverFactory = serverFactory; + } + + async create() { + const server = this._serverFactory(); + this._servers.push(server); + return server; + } + + async close(server: Server) { + const index = this._servers.indexOf(server); + if (index !== -1) + this._servers.splice(index, 1); + await server.close(); + } + + async closeAll() { + await Promise.all(this._servers.map(server => server.close())); + } +}