chore: unify loops (#745)

This commit is contained in:
Pavel Feldman 2025-07-23 17:42:53 -07:00 committed by GitHub
parent bc120baa78
commit 31a4fb3d07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 371 additions and 155 deletions

107
src/eval/loop.ts Normal file
View File

@ -0,0 +1,107 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import debug from 'debug';
import type { Tool, ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
/** A single tool invocation requested by the model. */
export type LLMToolCall = {
  // Tool name as listed by the MCP server (or the synthetic "done" tool).
  name: string;
  // Parsed JSON arguments for the tool; schema is provider/tool specific.
  arguments: any;
  // Provider-assigned id used to correlate the call with its result message.
  id: string;
};
/** Provider-agnostic tool declaration advertised to the model. */
export type LLMTool = {
  name: string;
  description: string;
  // JSON Schema describing the tool's input object.
  inputSchema: any;
};
/**
 * One turn of the provider-agnostic conversation. Discriminated on `role`:
 * - 'user': plain text prompt;
 * - 'assistant': model text plus any tool calls it issued;
 * - 'tool': result for one tool call, keyed by `toolCallId`.
 */
export type LLMMessage =
  | { role: 'user'; content: string }
  | { role: 'assistant'; content: string; toolCalls?: LLMToolCall[] }
  | { role: 'tool'; toolCallId: string; content: string; isError?: boolean };
/** Mutable conversation state shared between the loop and a delegate. */
export type LLMConversation = {
  messages: LLMMessage[];
  tools: LLMTool[];
};
/**
 * Adapter implemented per LLM provider (see ClaudeDelegate / OpenAIDelegate).
 * The generic loop in `runTask` drives these four operations; the delegate
 * owns all translation to and from the provider's wire format.
 */
export interface LLMDelegate {
  /** Builds the initial conversation for `task`, including the "done" tool. */
  createConversation(task: string, tools: Tool[]): LLMConversation;
  /**
   * Sends the conversation to the provider, appends the assistant reply to
   * `conversation.messages`, and returns the tool calls the model issued.
   */
  makeApiCall(conversation: LLMConversation): Promise<LLMToolCall[]>;
  /** Appends one 'tool' message per result to the conversation. */
  addToolResults(conversation: LLMConversation, results: Array<{ toolCallId: string; content: string; isError?: boolean }>): void;
  /** Returns the final result if `toolCall` is the "done" tool, else null. */
  checkDoneToolCall(toolCall: LLMToolCall): string | null;
}
/**
 * Drives the provider-agnostic agent loop: repeatedly asks the delegate for
 * tool calls, executes them against the MCP client, and feeds the results
 * back until the delegate reports that the "done" tool was called.
 *
 * @param delegate Provider-specific adapter owning the conversation format.
 * @param client Connected MCP client used to list and invoke tools.
 * @param task Natural-language task forwarded to the model.
 * @param maxIterations Maximum model round-trips before giving up (default 5,
 *   matching the previous hard-coded limit).
 * @returns The "done" tool's result string.
 * @throws Error when the model issues no tool calls in a turn, or when
 *   `maxIterations` is exhausted without the "done" tool being called.
 */
export async function runTask(delegate: LLMDelegate, client: Client, task: string, maxIterations: number = 5): Promise<string | undefined> {
  const { tools } = await client.listTools();
  const conversation = delegate.createConversation(task, tools);
  for (let iteration = 0; iteration < maxIterations; ++iteration) {
    debug('history')('Making API call for iteration', iteration);
    const toolCalls = await delegate.makeApiCall(conversation);
    // A turn with no tool calls cannot make progress: the delegates instruct
    // the model to finish via the "done" tool, so this is a protocol error.
    if (toolCalls.length === 0)
      throw new Error('Call the "done" tool when the task is complete.');
    const toolResults: Array<{ toolCallId: string; content: string; isError?: boolean }> = [];
    for (let index = 0; index < toolCalls.length; index++) {
      const toolCall = toolCalls[index];
      // The "done" tool short-circuits the loop with the final result.
      const doneResult = delegate.checkDoneToolCall(toolCall);
      if (doneResult !== null)
        return doneResult;
      const { name, arguments: args, id } = toolCall;
      try {
        debug('tool')(name, args);
        const response = await client.callTool({
          name,
          arguments: args,
        });
        const responseContent = (response.content || []) as (TextContent | ImageContent)[];
        debug('tool')(responseContent);
        // Only text parts are surfaced back to the model; image parts are dropped.
        const text = responseContent.filter(part => part.type === 'text').map(part => part.text).join('\n');
        toolResults.push({
          toolCallId: id,
          content: text,
        });
      } catch (error) {
        debug('tool')(error);
        toolResults.push({
          toolCallId: id,
          content: `Error while executing tool "${name}": ${error instanceof Error ? error.message : String(error)}\n\nPlease try to recover and complete the task.`,
          isError: true,
        });
        // Report every remaining call in this batch as skipped so the
        // provider still receives exactly one result per issued tool call id.
        for (const remainingToolCall of toolCalls.slice(index + 1)) {
          toolResults.push({
            toolCallId: remainingToolCall.id,
            content: `This tool call is skipped due to previous error.`,
            isError: true,
          });
        }
        break;
      }
    }
    // Feed results back so the next makeApiCall sees them.
    delegate.addToolResults(conversation, toolResults);
  }
  throw new Error('Failed to perform step, max attempts reached');
}

View File

@ -15,105 +15,155 @@
*/ */
import Anthropic from '@anthropic-ai/sdk'; import Anthropic from '@anthropic-ai/sdk';
import debug from 'debug'; import type { LLMDelegate, LLMConversation, LLMToolCall, LLMTool } from './loop.js';
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
import type { Tool, ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
const model = 'claude-sonnet-4-20250514'; const model = 'claude-sonnet-4-20250514';
export async function runTask(client: Client, task: string): Promise<string | undefined> { export class ClaudeDelegate implements LLMDelegate {
const anthropic = new Anthropic(); private anthropic = new Anthropic();
const messages: Anthropic.Messages.MessageParam[] = [];
const { tools } = await client.listTools(); createConversation(task: string, tools: Tool[]): LLMConversation {
const claudeTools = tools.map(tool => asClaudeDeclaration(tool)); const llmTools: LLMTool[] = tools.map(tool => ({
name: tool.name,
description: tool.description || '',
inputSchema: tool.inputSchema,
}));
// Add initial user message // Add the "done" tool
messages.push({ llmTools.push({
role: 'user', name: 'done',
content: `Perform following task: ${task}.` description: 'Call this tool when the task is complete.',
}); inputSchema: {
type: 'object',
for (let iteration = 0; iteration < 5; ++iteration) { properties: {
debug('history')(messages); result: { type: 'string', description: 'The result of the task.' },
},
const response = await anthropic.messages.create({ },
model,
max_tokens: 10000,
messages,
tools: claudeTools,
}); });
const content = response.content; return {
messages: [{
role: 'user',
content: `Perform following task: ${task}. Once the task is complete, call the "done" tool.`
}],
tools: llmTools,
};
}
const toolUseBlocks = content.filter(block => block.type === 'tool_use'); async makeApiCall(conversation: LLMConversation): Promise<LLMToolCall[]> {
const textBlocks = content.filter(block => block.type === 'text'); // Convert generic messages to Claude format
const claudeMessages: Anthropic.Messages.MessageParam[] = [];
messages.push({ for (const message of conversation.messages) {
role: 'assistant', if (message.role === 'user') {
content: content claudeMessages.push({
}); role: 'user',
content: message.content
if (toolUseBlocks.length === 0)
return textBlocks.map(block => block.text).join('\n');
const toolResults: Anthropic.Messages.ToolResultBlockParam[] = [];
for (const toolUse of toolUseBlocks) {
if (toolUse.name === 'done')
return JSON.stringify(toolUse.input, null, 2);
try {
debug('tool')(toolUse.name, toolUse.input);
const response = await client.callTool({
name: toolUse.name,
arguments: toolUse.input as any,
}); });
const responseContent = (response.content || []) as (TextContent | ImageContent)[]; } else if (message.role === 'assistant') {
debug('tool')(responseContent); const content: Anthropic.Messages.ContentBlock[] = [];
const text = responseContent.filter(part => part.type === 'text').map(part => part.text).join('\n');
toolResults.push({ // Add text content
type: 'tool_result', if (message.content) {
tool_use_id: toolUse.id, content.push({
content: text, type: 'text',
}); text: message.content,
} catch (error) { citations: []
debug('tool')(error); });
toolResults.push({ }
type: 'tool_result',
tool_use_id: toolUse.id, // Add tool calls
content: `Error while executing tool "${toolUse.name}": ${error instanceof Error ? error.message : String(error)}\n\nPlease try to recover and complete the task.`, if (message.toolCalls) {
is_error: true, for (const toolCall of message.toolCalls) {
}); content.push({
// Skip remaining tool calls for this iteration type: 'tool_use',
for (const remainingToolUse of toolUseBlocks.slice(toolUseBlocks.indexOf(toolUse) + 1)) { id: toolCall.id,
toolResults.push({ name: toolCall.name,
type: 'tool_result', input: toolCall.arguments
tool_use_id: remainingToolUse.id, });
content: `This tool call is skipped due to previous error.`, }
is_error: true, }
claudeMessages.push({
role: 'assistant',
content
});
} else if (message.role === 'tool') {
// Tool results are added differently - we need to find if there's already a user message with tool results
const lastMessage = claudeMessages[claudeMessages.length - 1];
const toolResult: Anthropic.Messages.ToolResultBlockParam = {
type: 'tool_result',
tool_use_id: message.toolCallId,
content: message.content,
is_error: message.isError,
};
if (lastMessage && lastMessage.role === 'user' && Array.isArray(lastMessage.content)) {
// Add to existing tool results message
(lastMessage.content as Anthropic.Messages.ToolResultBlockParam[]).push(toolResult);
} else {
// Create new tool results message
claudeMessages.push({
role: 'user',
content: [toolResult]
}); });
} }
break;
} }
} }
// Add tool results as user message // Convert generic tools to Claude format
messages.push({ const claudeTools: Anthropic.Messages.Tool[] = conversation.tools.map(tool => ({
role: 'user', name: tool.name,
content: toolResults description: tool.description,
input_schema: tool.inputSchema,
}));
const response = await this.anthropic.messages.create({
model,
max_tokens: 10000,
messages: claudeMessages,
tools: claudeTools,
}); });
// Extract tool calls and add assistant message to generic conversation
const toolCalls = response.content.filter(block => block.type === 'tool_use') as Anthropic.Messages.ToolUseBlock[];
const textContent = response.content.filter(block => block.type === 'text').map(block => (block as Anthropic.Messages.TextBlock).text).join('');
const llmToolCalls: LLMToolCall[] = toolCalls.map(toolCall => ({
name: toolCall.name,
arguments: toolCall.input as any,
id: toolCall.id,
}));
// Add assistant message to generic conversation
conversation.messages.push({
role: 'assistant',
content: textContent,
toolCalls: llmToolCalls.length > 0 ? llmToolCalls : undefined
});
return llmToolCalls;
} }
throw new Error('Failed to perform step, max attempts reached'); addToolResults(
} conversation: LLMConversation,
results: Array<{ toolCallId: string; content: string; isError?: boolean }>
): void {
for (const result of results) {
conversation.messages.push({
role: 'tool',
toolCallId: result.toolCallId,
content: result.content,
isError: result.isError,
});
}
}
function asClaudeDeclaration(tool: Tool): Anthropic.Messages.Tool { checkDoneToolCall(toolCall: LLMToolCall): string | null {
return { if (toolCall.name === 'done')
name: tool.name, return (toolCall.arguments as { result: string }).result;
description: tool.description,
input_schema: tool.inputSchema, return null;
}; }
} }

View File

@ -15,91 +15,147 @@
*/ */
import OpenAI from 'openai'; import OpenAI from 'openai';
import debug from 'debug'; import type { LLMDelegate, LLMConversation, LLMToolCall, LLMTool } from './loop.js';
import type { Tool } from '@modelcontextprotocol/sdk/types.js';
import type { Tool, ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
const model = 'gpt-4.1'; const model = 'gpt-4.1';
export async function runTask(client: Client, task: string): Promise<string | undefined> { export class OpenAIDelegate implements LLMDelegate {
const openai = new OpenAI(); private openai = new OpenAI();
const messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [
{ createConversation(task: string, tools: Tool[]): LLMConversation {
role: 'user', const genericTools: LLMTool[] = tools.map(tool => ({
content: `Peform following task: ${task}. Once the task is complete, call the "done" tool.` name: tool.name,
description: tool.description || '',
inputSchema: tool.inputSchema,
}));
// Add the "done" tool
genericTools.push({
name: 'done',
description: 'Call this tool when the task is complete.',
inputSchema: {
type: 'object',
properties: {
result: { type: 'string', description: 'The result of the task.' },
},
required: ['result'],
},
});
return {
messages: [{
role: 'user',
content: `Peform following task: ${task}. Once the task is complete, call the "done" tool.`
}],
tools: genericTools,
};
}
async makeApiCall(conversation: LLMConversation): Promise<LLMToolCall[]> {
// Convert generic messages to OpenAI format
const openaiMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [];
for (const message of conversation.messages) {
if (message.role === 'user') {
openaiMessages.push({
role: 'user',
content: message.content
});
} else if (message.role === 'assistant') {
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = [];
if (message.toolCalls) {
for (const toolCall of message.toolCalls) {
toolCalls.push({
id: toolCall.id,
type: 'function',
function: {
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments)
}
});
}
}
const assistantMessage: OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam = {
role: 'assistant'
};
if (message.content)
assistantMessage.content = message.content;
if (toolCalls.length > 0)
assistantMessage.tool_calls = toolCalls;
openaiMessages.push(assistantMessage);
} else if (message.role === 'tool') {
openaiMessages.push({
role: 'tool',
tool_call_id: message.toolCallId,
content: message.content,
});
}
} }
];
const { tools } = await client.listTools(); // Convert generic tools to OpenAI format
const openaiTools: OpenAI.Chat.Completions.ChatCompletionTool[] = conversation.tools.map(tool => ({
type: 'function',
function: {
name: tool.name,
description: tool.description,
parameters: tool.inputSchema,
},
}));
for (let iteration = 0; iteration < 5; ++iteration) { const response = await this.openai.chat.completions.create({
debug('history')(messages);
const response = await openai.chat.completions.create({
model, model,
messages, messages: openaiMessages,
tools: tools.map(tool => asOpenAIDeclaration(tool)), tools: openaiTools,
tool_choice: 'auto' tool_choice: 'auto'
}); });
const message = response.choices[0].message; const message = response.choices[0].message;
if (!message.tool_calls?.length)
return JSON.stringify(message.content, null, 2);
messages.push({ // Extract tool calls and add assistant message to generic conversation
role: 'assistant', const toolCalls = message.tool_calls || [];
tool_calls: message.tool_calls const genericToolCalls: LLMToolCall[] = toolCalls.map(toolCall => {
const functionCall = toolCall.function;
return {
name: functionCall.name,
arguments: JSON.parse(functionCall.arguments),
id: toolCall.id,
};
}); });
for (const toolCall of message.tool_calls) { // Add assistant message to generic conversation
const functionCall = toolCall.function; conversation.messages.push({
role: 'assistant',
content: message.content || '',
toolCalls: genericToolCalls.length > 0 ? genericToolCalls : undefined
});
if (functionCall.name === 'done') return genericToolCalls;
return JSON.stringify(functionCall.arguments, null, 2); }
try { addToolResults(
debug('tool')(functionCall.name, functionCall.arguments); conversation: LLMConversation,
const response = await client.callTool({ results: Array<{ toolCallId: string; content: string; isError?: boolean }>
name: functionCall.name, ): void {
arguments: JSON.parse(functionCall.arguments) for (const result of results) {
}); conversation.messages.push({
const content = (response.content || []) as (TextContent | ImageContent)[]; role: 'tool',
debug('tool')(content); toolCallId: result.toolCallId,
const text = content.filter(part => part.type === 'text').map(part => part.text).join('\n'); content: result.content,
messages.push({ isError: result.isError,
role: 'tool', });
tool_call_id: toolCall.id,
content: text,
});
} catch (error) {
debug('tool')(error);
messages.push({
role: 'tool',
tool_call_id: toolCall.id,
content: `Error while executing tool "${functionCall.name}": ${error instanceof Error ? error.message : String(error)}\n\nPlease try to recover and complete the task.`,
});
for (const ignoredToolCall of message.tool_calls.slice(message.tool_calls.indexOf(toolCall) + 1)) {
messages.push({
role: 'tool',
tool_call_id: ignoredToolCall.id,
content: `This tool call is skipped due to previous error.`,
});
}
break;
}
} }
} }
throw new Error('Failed to perform step, max attempts reached');
}
function asOpenAIDeclaration(tool: Tool): OpenAI.Chat.Completions.ChatCompletionTool { checkDoneToolCall(toolCall: LLMToolCall): string | null {
return { if (toolCall.name === 'done')
type: 'function', return toolCall.arguments.result;
function: {
name: tool.name, return null;
description: tool.description, }
parameters: tool.inputSchema,
},
};
} }

View File

@ -23,14 +23,17 @@ import dotenv from 'dotenv';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { program } from 'commander'; import { program } from 'commander';
import { runTask as runTaskOpenAI } from './loopOpenAI.js'; import { OpenAIDelegate } from './loopOpenAI.js';
import { runTask as runTaskClaude } from './loopClaude.js'; import { ClaudeDelegate } from './loopClaude.js';
import { runTask } from './loop.js';
import type { LLMDelegate } from './loop.js';
dotenv.config(); dotenv.config();
const __filename = url.fileURLToPath(import.meta.url); const __filename = url.fileURLToPath(import.meta.url);
async function run(runTask: (client: Client, task: string) => Promise<string | undefined>) { async function run(delegate: LLMDelegate) {
const transport = new StdioClientTransport({ const transport = new StdioClientTransport({
command: 'node', command: 'node',
args: [ args: [
@ -48,7 +51,7 @@ async function run(runTask: (client: Client, task: string) => Promise<string | u
let lastResult: string | undefined; let lastResult: string | undefined;
for (const task of tasks) for (const task of tasks)
lastResult = await runTask(client, task); lastResult = await runTask(delegate, client, task);
console.log(lastResult); console.log(lastResult);
await client.close(); await client.close();
} }
@ -61,8 +64,8 @@ program
.option('--model <model>', 'model to use') .option('--model <model>', 'model to use')
.action(async options => { .action(async options => {
if (options.model === 'claude') if (options.model === 'claude')
await run(runTaskClaude); await run(new ClaudeDelegate());
else else
await run(runTaskOpenAI); await run(new OpenAIDelegate());
}); });
void program.parseAsync(process.argv); void program.parseAsync(process.argv);