diff --git a/src/eval/loop.ts b/src/eval/loop.ts new file mode 100644 index 0000000..48425e4 --- /dev/null +++ b/src/eval/loop.ts @@ -0,0 +1,107 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import debug from 'debug'; +import type { Tool, ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js'; +import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; + +export type LLMToolCall = { + name: string; + arguments: any; + id: string; +}; + +export type LLMTool = { + name: string; + description: string; + inputSchema: any; +}; + +export type LLMMessage = + | { role: 'user'; content: string } + | { role: 'assistant'; content: string; toolCalls?: LLMToolCall[] } + | { role: 'tool'; toolCallId: string; content: string; isError?: boolean }; + +export type LLMConversation = { + messages: LLMMessage[]; + tools: LLMTool[]; +}; + +export interface LLMDelegate { + createConversation(task: string, tools: Tool[]): LLMConversation; + makeApiCall(conversation: LLMConversation): Promise; + addToolResults(conversation: LLMConversation, results: Array<{ toolCallId: string; content: string; isError?: boolean }>): void; + checkDoneToolCall(toolCall: LLMToolCall): string | null; +} + +export async function runTask(delegate: LLMDelegate, client: Client, task: string): Promise { + const { tools } = await client.listTools(); + const conversation = delegate.createConversation(task, tools); + + for (let iteration = 0; iteration < 5; ++iteration) { + debug('history')('Making API call for iteration', iteration); + const toolCalls = await delegate.makeApiCall(conversation); + if (toolCalls.length === 0) + throw new Error('Call the "done" tool when the task is complete.'); + + const toolResults: Array<{ toolCallId: string; content: string; isError?: boolean }> = []; + for (const toolCall of toolCalls) { + // Check if this is the "done" tool + const doneResult = delegate.checkDoneToolCall(toolCall); + if (doneResult !== null) + return doneResult; + + const { name, arguments: args, id } = toolCall; + try { + debug('tool')(name, args); + const response = await client.callTool({ + name, + arguments: args, + }); + const responseContent = (response.content || []) as (TextContent | ImageContent)[]; + debug('tool')(responseContent); + const text = responseContent.filter(part => part.type === 'text').map(part => part.text).join('\n'); + + toolResults.push({ + toolCallId: id, + content: text, + }); + } catch (error) { + debug('tool')(error); + toolResults.push({ + toolCallId: id, + content: `Error while executing tool "${name}": ${error instanceof Error ? error.message : String(error)}\n\nPlease try to recover and complete the task.`, + isError: true, + }); + + // Skip remaining tool calls for this iteration + for (const remainingToolCall of toolCalls.slice(toolCalls.indexOf(toolCall) + 1)) { + toolResults.push({ + toolCallId: remainingToolCall.id, + content: `This tool call is skipped due to previous error.`, + isError: true, + }); + } + break; + } + } + + // Add tool results to conversation + delegate.addToolResults(conversation, toolResults); + } + + throw new Error('Failed to perform step, max attempts reached'); +} diff --git a/src/eval/loopClaude.ts b/src/eval/loopClaude.ts index 77bd233..c05e972 100644 --- a/src/eval/loopClaude.ts +++ b/src/eval/loopClaude.ts @@ -15,105 +15,155 @@ */ import Anthropic from '@anthropic-ai/sdk'; -import debug from 'debug'; - -import type { Tool, ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js'; -import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import type { LLMDelegate, LLMConversation, LLMToolCall, LLMTool } from './loop.js'; +import type { Tool } from '@modelcontextprotocol/sdk/types.js'; const model = 'claude-sonnet-4-20250514'; -export async function runTask(client: Client, task: string): Promise { - const anthropic = new Anthropic(); - const messages: Anthropic.Messages.MessageParam[] = []; +export class ClaudeDelegate implements LLMDelegate { + private anthropic = new Anthropic(); - const { tools } = await client.listTools(); - const claudeTools = tools.map(tool => asClaudeDeclaration(tool)); + createConversation(task: string, tools: Tool[]): LLMConversation { + const llmTools: LLMTool[] = tools.map(tool => ({ + name: tool.name, + description: tool.description || '', + inputSchema: tool.inputSchema, + })); - // Add initial user message - messages.push({ - role: 'user', - content: `Perform following task: ${task}.` - }); - - for (let iteration = 0; iteration < 5; ++iteration) { - debug('history')(messages); - - const response = await anthropic.messages.create({ - model, - max_tokens: 10000, - messages, - tools: claudeTools, + // Add the "done" tool + llmTools.push({ + name: 'done', + description: 'Call this tool when the task is complete.', + inputSchema: { + type: 'object', + properties: { + result: { type: 'string', description: 'The result of the task.' }, + }, + }, }); - const content = response.content; + return { + messages: [{ + role: 'user', + content: `Perform following task: ${task}. Once the task is complete, call the "done" tool.` + }], + tools: llmTools, + }; + } - const toolUseBlocks = content.filter(block => block.type === 'tool_use'); - const textBlocks = content.filter(block => block.type === 'text'); + async makeApiCall(conversation: LLMConversation): Promise { + // Convert generic messages to Claude format + const claudeMessages: Anthropic.Messages.MessageParam[] = []; - messages.push({ - role: 'assistant', - content: content - }); - - if (toolUseBlocks.length === 0) - return textBlocks.map(block => block.text).join('\n'); - - const toolResults: Anthropic.Messages.ToolResultBlockParam[] = []; - - for (const toolUse of toolUseBlocks) { - if (toolUse.name === 'done') - return JSON.stringify(toolUse.input, null, 2); - - try { - debug('tool')(toolUse.name, toolUse.input); - const response = await client.callTool({ - name: toolUse.name, - arguments: toolUse.input as any, + for (const message of conversation.messages) { + if (message.role === 'user') { + claudeMessages.push({ + role: 'user', + content: message.content }); - const responseContent = (response.content || []) as (TextContent | ImageContent)[]; - debug('tool')(responseContent); - const text = responseContent.filter(part => part.type === 'text').map(part => part.text).join('\n'); + } else if (message.role === 'assistant') { + const content: Anthropic.Messages.ContentBlock[] = []; - toolResults.push({ - type: 'tool_result', - tool_use_id: toolUse.id, - content: text, - }); - } catch (error) { - debug('tool')(error); - toolResults.push({ - type: 'tool_result', - tool_use_id: toolUse.id, - content: `Error while executing tool "${toolUse.name}": ${error instanceof Error ? error.message : String(error)}\n\nPlease try to recover and complete the task.`, - is_error: true, - }); - // Skip remaining tool calls for this iteration - for (const remainingToolUse of toolUseBlocks.slice(toolUseBlocks.indexOf(toolUse) + 1)) { - toolResults.push({ - type: 'tool_result', - tool_use_id: remainingToolUse.id, - content: `This tool call is skipped due to previous error.`, - is_error: true, + // Add text content + if (message.content) { + content.push({ + type: 'text', + text: message.content, + citations: [] + }); + } + + // Add tool calls + if (message.toolCalls) { + for (const toolCall of message.toolCalls) { + content.push({ + type: 'tool_use', + id: toolCall.id, + name: toolCall.name, + input: toolCall.arguments + }); + } + } + + claudeMessages.push({ + role: 'assistant', + content + }); + } else if (message.role === 'tool') { + // Tool results are added differently - we need to find if there's already a user message with tool results + const lastMessage = claudeMessages[claudeMessages.length - 1]; + const toolResult: Anthropic.Messages.ToolResultBlockParam = { + type: 'tool_result', + tool_use_id: message.toolCallId, + content: message.content, + is_error: message.isError, + }; + + if (lastMessage && lastMessage.role === 'user' && Array.isArray(lastMessage.content)) { + // Add to existing tool results message + (lastMessage.content as Anthropic.Messages.ToolResultBlockParam[]).push(toolResult); + } else { + // Create new tool results message + claudeMessages.push({ + role: 'user', + content: [toolResult] }); } - break; } } - // Add tool results as user message - messages.push({ - role: 'user', - content: toolResults + // Convert generic tools to Claude format + const claudeTools: Anthropic.Messages.Tool[] = conversation.tools.map(tool => ({ + name: tool.name, + description: tool.description, + input_schema: tool.inputSchema, + })); + + const response = await this.anthropic.messages.create({ + model, + max_tokens: 10000, + messages: claudeMessages, + tools: claudeTools, }); + + // Extract tool calls and add assistant message to generic conversation + const toolCalls = response.content.filter(block => block.type === 'tool_use') as Anthropic.Messages.ToolUseBlock[]; + const textContent = response.content.filter(block => block.type === 'text').map(block => (block as Anthropic.Messages.TextBlock).text).join(''); + + const llmToolCalls: LLMToolCall[] = toolCalls.map(toolCall => ({ + name: toolCall.name, + arguments: toolCall.input as any, + id: toolCall.id, + })); + + // Add assistant message to generic conversation + conversation.messages.push({ + role: 'assistant', + content: textContent, + toolCalls: llmToolCalls.length > 0 ? llmToolCalls : undefined + }); + + return llmToolCalls; } - throw new Error('Failed to perform step, max attempts reached'); -} + addToolResults( + conversation: LLMConversation, + results: Array<{ toolCallId: string; content: string; isError?: boolean }> + ): void { + for (const result of results) { + conversation.messages.push({ + role: 'tool', + toolCallId: result.toolCallId, + content: result.content, + isError: result.isError, + }); + } + } -function asClaudeDeclaration(tool: Tool): Anthropic.Messages.Tool { - return { - name: tool.name, - description: tool.description, - input_schema: tool.inputSchema, - }; + checkDoneToolCall(toolCall: LLMToolCall): string | null { + if (toolCall.name === 'done') + return (toolCall.arguments as { result: string }).result; + + return null; + } } diff --git a/src/eval/loopOpenAI.ts b/src/eval/loopOpenAI.ts index 4408b53..59f1011 100644 --- a/src/eval/loopOpenAI.ts +++ b/src/eval/loopOpenAI.ts @@ -15,91 +15,147 @@ */ import OpenAI from 'openai'; -import debug from 'debug'; - -import type { Tool, ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js'; -import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import type { LLMDelegate, LLMConversation, LLMToolCall, LLMTool } from './loop.js'; +import type { Tool } from '@modelcontextprotocol/sdk/types.js'; const model = 'gpt-4.1'; -export async function runTask(client: Client, task: string): Promise { - const openai = new OpenAI(); - const messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [ - { - role: 'user', - content: `Peform following task: ${task}. Once the task is complete, call the "done" tool.` +export class OpenAIDelegate implements LLMDelegate { + private openai = new OpenAI(); + + createConversation(task: string, tools: Tool[]): LLMConversation { + const genericTools: LLMTool[] = tools.map(tool => ({ + name: tool.name, + description: tool.description || '', + inputSchema: tool.inputSchema, + })); + + // Add the "done" tool + genericTools.push({ + name: 'done', + description: 'Call this tool when the task is complete.', + inputSchema: { + type: 'object', + properties: { + result: { type: 'string', description: 'The result of the task.' }, + }, + required: ['result'], + }, + }); + + return { + messages: [{ + role: 'user', + content: `Peform following task: ${task}. Once the task is complete, call the "done" tool.` + }], + tools: genericTools, + }; + } + + async makeApiCall(conversation: LLMConversation): Promise { + // Convert generic messages to OpenAI format + const openaiMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []; + + for (const message of conversation.messages) { + if (message.role === 'user') { + openaiMessages.push({ + role: 'user', + content: message.content + }); + } else if (message.role === 'assistant') { + const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []; + + if (message.toolCalls) { + for (const toolCall of message.toolCalls) { + toolCalls.push({ + id: toolCall.id, + type: 'function', + function: { + name: toolCall.name, + arguments: JSON.stringify(toolCall.arguments) + } + }); + } + } + + const assistantMessage: OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam = { + role: 'assistant' + }; + + if (message.content) + assistantMessage.content = message.content; + + if (toolCalls.length > 0) + assistantMessage.tool_calls = toolCalls; + + openaiMessages.push(assistantMessage); + } else if (message.role === 'tool') { + openaiMessages.push({ + role: 'tool', + tool_call_id: message.toolCallId, + content: message.content, + }); + } } - ]; - const { tools } = await client.listTools(); + // Convert generic tools to OpenAI format + const openaiTools: OpenAI.Chat.Completions.ChatCompletionTool[] = conversation.tools.map(tool => ({ + type: 'function', + function: { + name: tool.name, + description: tool.description, + parameters: tool.inputSchema, + }, + })); - for (let iteration = 0; iteration < 5; ++iteration) { - debug('history')(messages); - - const response = await openai.chat.completions.create({ + const response = await this.openai.chat.completions.create({ model, - messages, - tools: tools.map(tool => asOpenAIDeclaration(tool)), + messages: openaiMessages, + tools: openaiTools, tool_choice: 'auto' }); const message = response.choices[0].message; - if (!message.tool_calls?.length) - return JSON.stringify(message.content, null, 2); - messages.push({ - role: 'assistant', - tool_calls: message.tool_calls + // Extract tool calls and add assistant message to generic conversation + const toolCalls = message.tool_calls || []; + const genericToolCalls: LLMToolCall[] = toolCalls.map(toolCall => { + const functionCall = toolCall.function; + return { + name: functionCall.name, + arguments: JSON.parse(functionCall.arguments), + id: toolCall.id, + }; }); - for (const toolCall of message.tool_calls) { - const functionCall = toolCall.function; + // Add assistant message to generic conversation + conversation.messages.push({ + role: 'assistant', + content: message.content || '', + toolCalls: genericToolCalls.length > 0 ? genericToolCalls : undefined + }); - if (functionCall.name === 'done') - return JSON.stringify(functionCall.arguments, null, 2); + return genericToolCalls; + } - try { - debug('tool')(functionCall.name, functionCall.arguments); - const response = await client.callTool({ - name: functionCall.name, - arguments: JSON.parse(functionCall.arguments) - }); - const content = (response.content || []) as (TextContent | ImageContent)[]; - debug('tool')(content); - const text = content.filter(part => part.type === 'text').map(part => part.text).join('\n'); - messages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: text, - }); - } catch (error) { - debug('tool')(error); - messages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: `Error while executing tool "${functionCall.name}": ${error instanceof Error ? error.message : String(error)}\n\nPlease try to recover and complete the task.`, - }); - for (const ignoredToolCall of message.tool_calls.slice(message.tool_calls.indexOf(toolCall) + 1)) { - messages.push({ - role: 'tool', - tool_call_id: ignoredToolCall.id, - content: `This tool call is skipped due to previous error.`, - }); - } - break; - } + addToolResults( + conversation: LLMConversation, + results: Array<{ toolCallId: string; content: string; isError?: boolean }> + ): void { + for (const result of results) { + conversation.messages.push({ + role: 'tool', + toolCallId: result.toolCallId, + content: result.content, + isError: result.isError, + }); } } - throw new Error('Failed to perform step, max attempts reached'); -} -function asOpenAIDeclaration(tool: Tool): OpenAI.Chat.Completions.ChatCompletionTool { - return { - type: 'function', - function: { - name: tool.name, - description: tool.description, - parameters: tool.inputSchema, - }, - }; + checkDoneToolCall(toolCall: LLMToolCall): string | null { + if (toolCall.name === 'done') + return toolCall.arguments.result; + + return null; + } } diff --git a/src/eval/main.ts b/src/eval/main.ts index 4def87e..8ad22be 100644 --- a/src/eval/main.ts +++ b/src/eval/main.ts @@ -23,14 +23,17 @@ import dotenv from 'dotenv'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { program } from 'commander'; -import { runTask as runTaskOpenAI } from './loopOpenAI.js'; -import { runTask as runTaskClaude } from './loopClaude.js'; +import { OpenAIDelegate } from './loopOpenAI.js'; +import { ClaudeDelegate } from './loopClaude.js'; +import { runTask } from './loop.js'; + +import type { LLMDelegate } from './loop.js'; dotenv.config(); const __filename = url.fileURLToPath(import.meta.url); -async function run(runTask: (client: Client, task: string) => Promise) { +async function run(delegate: LLMDelegate) { const transport = new StdioClientTransport({ command: 'node', args: [ @@ -48,7 +51,7 @@ async function run(runTask: (client: Client, task: string) => Promise', 'model to use') .action(async options => { if (options.model === 'claude') - await run(runTaskClaude); + await run(new ClaudeDelegate()); else - await run(runTaskOpenAI); + await run(new OpenAIDelegate()); }); void program.parseAsync(process.argv);