From bf7dbabca445529841edaebe76861f81baf0eb5c Mon Sep 17 00:00:00 2001 From: Simon Knott Date: Mon, 28 Apr 2025 11:11:31 +0200 Subject: [PATCH] feat: support streamable http transport (#243) Adds support for the new StreamableHttp transport. I'm not aware of any clients that implement it, but somebody's gotta make the start! Once some clients support it, we can also advertise it in the README. --- package-lock.json | 23 ++++----- package.json | 2 +- src/program.ts | 78 ++-------------------------- src/transport.ts | 127 ++++++++++++++++++++++++++++++++++++++++++++++ tests/sse.spec.ts | 60 ++++++++++++++-------- 5 files changed, 181 insertions(+), 109 deletions(-) create mode 100644 src/transport.ts diff --git a/package-lock.json b/package-lock.json index e9c4511..3f1fc41 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,7 +9,7 @@ "version": "0.0.15", "license": "Apache-2.0", "dependencies": { - "@modelcontextprotocol/sdk": "^1.6.1", + "@modelcontextprotocol/sdk": "^1.10.1", "commander": "^13.1.0", "playwright": "1.53.0-alpha-1745357020000", "yaml": "^2.7.1", @@ -228,17 +228,18 @@ } }, "node_modules/@modelcontextprotocol/sdk": { - "version": "1.7.0", - "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.7.0.tgz", - "integrity": "sha512-IYPe/FLpvF3IZrd/f5p5ffmWhMc3aEMuM2wGJASDqC2Ge7qatVCdbfPx3n/5xFeb19xN0j/911M2AaFuircsWA==", + "version": "1.10.1", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.10.1.tgz", + "integrity": "sha512-xNYdFdkJqEfIaTVP1gPKoEvluACHZsHZegIoICX8DM1o6Qf3G5u2BQJHmgd0n4YgRPqqK/u1ujQvrgAxxSJT9w==", "license": "MIT", "dependencies": { "content-type": "^1.0.5", "cors": "^2.8.5", + "cross-spawn": "^7.0.3", "eventsource": "^3.0.2", "express": "^5.0.1", "express-rate-limit": "^7.5.0", - "pkce-challenge": "^4.1.0", + "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -1091,7 +1092,6 @@ "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", - "dev": true, "license": "MIT", "dependencies": { "path-key": "^3.1.0", @@ -2786,7 +2786,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", - "dev": true, "license": "ISC" }, "node_modules/js-yaml": { @@ -3256,7 +3255,6 @@ "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -3292,9 +3290,9 @@ } }, "node_modules/pkce-challenge": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-4.1.0.tgz", - "integrity": "sha512-ZBmhE1C9LcPoH9XZSdwiPtbPHZROwAnMy+kIFQVrnMCxY4Cudlz3gBOpzilgc0jOgRaiT3sIWfpMomW2ar2orQ==", + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-5.0.0.tgz", + "integrity": "sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==", "license": "MIT", "engines": { "node": ">=16.20.0" @@ -3796,7 +3794,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", - "dev": true, "license": "MIT", "dependencies": { "shebang-regex": "^3.0.0" @@ -3809,7 +3806,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -4238,7 +4234,6 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, "license": "ISC", "dependencies": { "isexe": "^2.0.0" diff --git a/package.json b/package.json index a27c613..fa8144e 100644 --- a/package.json +++ b/package.json @@ -34,7 +34,7 @@ } }, "dependencies": { - "@modelcontextprotocol/sdk": "^1.6.1", + "@modelcontextprotocol/sdk": "^1.10.1", "commander": "^13.1.0", "playwright": "1.53.0-alpha-1745357020000", "yaml": "^2.7.1", diff --git a/src/program.ts b/src/program.ts index c678c37..4c5f563 100644 --- a/src/program.ts +++ b/src/program.ts @@ -14,18 +14,13 @@ * limitations under the License. */ -import http from 'http'; - import { program } from 'commander'; -import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; -import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; - import { createServer } from './index'; import { ServerList } from './server'; -import assert from 'assert'; import { ToolCapability } from './tools/tool'; +import { startHttpTransport, startStdioTransport } from './transport'; const packageJSON = require('../package.json'); @@ -53,12 +48,10 @@ program })); setupExitWatchdog(serverList); - if (options.port) { - startSSEServer(+options.port, options.host || 'localhost', serverList); - } else { - const server = await serverList.create(); - await server.connect(new StdioServerTransport()); - } + if (options.port) + startHttpTransport(+options.port, options.host, serverList); + else + await startStdioTransport(serverList); }); function setupExitWatchdog(serverList: ServerList) { @@ -74,64 +67,3 @@ function setupExitWatchdog(serverList: ServerList) { } program.parse(process.argv); - -function startSSEServer(port: number, host: string, serverList: ServerList) { - const sessions = new Map(); - const httpServer = http.createServer(async (req, res) => { - if (req.method === 'POST') { - const searchParams = new URL(`http://localhost${req.url}`).searchParams; - const sessionId = searchParams.get('sessionId'); - if (!sessionId) { - res.statusCode = 400; - res.end('Missing sessionId'); - return; - } - const transport = sessions.get(sessionId); - if (!transport) { - res.statusCode = 404; - res.end('Session not found'); - return; - } - - await transport.handlePostMessage(req, res); - return; - } else if (req.method === 'GET') { - const transport = new SSEServerTransport('/sse', res); - sessions.set(transport.sessionId, transport); - const server = await serverList.create(); - res.on('close', () => { - sessions.delete(transport.sessionId); - serverList.close(server).catch(e => console.error(e)); - }); - await server.connect(transport); - return; - } else { - res.statusCode = 405; - res.end('Method not allowed'); - } - }); - - httpServer.listen(port, host, () => { - const address = httpServer.address(); - assert(address, 'Could not bind server socket'); - let url: string; - if (typeof address === 'string') { - url = address; - } else { - const resolvedPort = address.port; - let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; - if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') - resolvedHost = host === 'localhost' ? 'localhost' : resolvedHost; - url = `http://${resolvedHost}:${resolvedPort}`; - } - console.log(`Listening on ${url}`); - console.log('Put this in your client config:'); - console.log(JSON.stringify({ - 'mcpServers': { - 'playwright': { - 'url': `${url}/sse` - } - } - }, undefined, 2)); - }); -} diff --git a/src/transport.ts b/src/transport.ts new file mode 100644 index 0000000..b21db06 --- /dev/null +++ b/src/transport.ts @@ -0,0 +1,127 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import http from 'node:http'; +import assert from 'node:assert'; +import crypto from 'node:crypto'; + +import { ServerList } from './server'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; + +export async function startStdioTransport(serverList: ServerList) { + const server = await serverList.create(); + await server.connect(new StdioServerTransport()); +} + +async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, url: URL, serverList: ServerList, sessions: Map) { + if (req.method === 'POST') { + const sessionId = url.searchParams.get('sessionId'); + if (!sessionId) { + res.statusCode = 400; + return res.end('Missing sessionId'); + } + + const transport = sessions.get(sessionId); + if (!transport) { + res.statusCode = 404; + return res.end('Session not found'); + } + + return await transport.handlePostMessage(req, res); + } else if (req.method === 'GET') { + const transport = new SSEServerTransport('/sse', res); + sessions.set(transport.sessionId, transport); + const server = await serverList.create(); + res.on('close', () => { + sessions.delete(transport.sessionId); + serverList.close(server).catch(e => console.error(e)); + }); + return await server.connect(transport); + } + + res.statusCode = 405; + res.end('Method not allowed'); +} + +async function handleStreamable(req: http.IncomingMessage, res: http.ServerResponse, serverList: ServerList, sessions: Map) { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) { + const transport = sessions.get(sessionId); + if (!transport) { + res.statusCode = 404; + res.end('Session not found'); + return; + } + return await transport.handleRequest(req, res); + } + + if (req.method === 'POST') { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => crypto.randomUUID(), + onsessioninitialized: sessionId => { + sessions.set(sessionId, transport); + } + }); + transport.onclose = () => { + if (transport.sessionId) + sessions.delete(transport.sessionId); + }; + const server = await serverList.create(); + await server.connect(transport); + return await transport.handleRequest(req, res); + } + + res.statusCode = 400; + res.end('Invalid request'); +} + +export function startHttpTransport(port: number, hostname: string | undefined, serverList: ServerList) { + const sseSessions = new Map(); + const streamableSessions = new Map(); + const httpServer = http.createServer(async (req, res) => { + const url = new URL(`http://localhost${req.url}`); + if (url.pathname.startsWith('/mcp')) + await handleStreamable(req, res, serverList, streamableSessions); + else + await handleSSE(req, res, url, serverList, sseSessions); + }); + httpServer.listen(port, hostname, () => { + const address = httpServer.address(); + assert(address, 'Could not bind server socket'); + let url: string; + if (typeof address === 'string') { + url = address; + } else { + const resolvedPort = address.port; + let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; + if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') + resolvedHost = 'localhost'; + url = `http://${resolvedHost}:${resolvedPort}`; + } + console.log(`Listening on ${url}`); + console.log('Put this in your client config:'); + console.log(JSON.stringify({ + 'mcpServers': { + 'playwright': { + 'url': `${url}/sse` + } + } + }, undefined, 2)); + console.log('If your client supports streamable HTTP, you can use the /mcp endpoint instead.'); + }); +} diff --git a/tests/sse.spec.ts b/tests/sse.spec.ts index ad627ef..eaacd15 100644 --- a/tests/sse.spec.ts +++ b/tests/sse.spec.ts @@ -16,27 +16,45 @@ import { spawn } from 'node:child_process'; import path from 'node:path'; -import { test } from './fixtures'; +import { test as baseTest } from './fixtures'; +import { expect } from 'playwright/test'; -test('sse transport', async () => { - const cp = spawn('node', [path.join(__dirname, '../cli.js'), '--port', '0'], { stdio: 'pipe' }); - try { - let stdout = ''; - const url = await new Promise(resolve => cp.stdout?.on('data', data => { - stdout += data.toString(); - const match = stdout.match(/Listening on (http:\/\/.*)/); - if (match) - resolve(match[1]); - })); +const test = baseTest.extend<{ serverEndpoint: string }>({ + serverEndpoint: async ({}, use) => { + const cp = spawn('node', [path.join(__dirname, '../cli.js'), '--port', '0'], { stdio: 'pipe' }); + try { + let stdout = ''; + const url = await new Promise(resolve => cp.stdout?.on('data', data => { + stdout += data.toString(); + const match = stdout.match(/Listening on (http:\/\/.*)/); + if (match) + resolve(match[1]); + })); - // need dynamic import b/c of some ESM nonsense - const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js'); - const { Client } = await import('@modelcontextprotocol/sdk/client/index.js'); - const transport = new SSEClientTransport(new URL(url)); - const client = new Client({ name: 'test', version: '1.0.0' }); - await client.connect(transport); - await client.ping(); - } finally { - cp.kill(); - } + await use(url); + } finally { + cp.kill(); + } + }, +}); + +test('sse transport', async ({ serverEndpoint }) => { + // need dynamic import b/c of some ESM nonsense + const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js'); + const { Client } = await import('@modelcontextprotocol/sdk/client/index.js'); + const transport = new SSEClientTransport(new URL(serverEndpoint)); + const client = new Client({ name: 'test', version: '1.0.0' }); + await client.connect(transport); + await client.ping(); +}); + +test('streamable http transport', async ({ serverEndpoint }) => { + // need dynamic import b/c of some ESM nonsense + const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js'); + const { Client } = await import('@modelcontextprotocol/sdk/client/index.js'); + const transport = new StreamableHTTPClientTransport(new URL('/mcp', serverEndpoint)); + const client = new Client({ name: 'test', version: '1.0.0' }); + await client.connect(transport); + await client.ping(); + expect(transport.sessionId, 'has session support').toBeDefined(); });