From c04e3b2c890d4125db6f7894d953bccea999b249 Mon Sep 17 00:00:00 2001
From: perf3ct
Date: Mon, 14 Apr 2025 19:39:29 +0000
Subject: [PATCH] okay openai tool calling response is close to working

---
 src/services/llm/ai_interface.ts             |  6 ++
 src/services/llm/providers/openai_service.ts | 87 ++++++++++++++++----
 2 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/src/services/llm/ai_interface.ts b/src/services/llm/ai_interface.ts
index 69eab9c3f..354656efb 100644
--- a/src/services/llm/ai_interface.ts
+++ b/src/services/llm/ai_interface.ts
@@ -42,6 +42,12 @@ export interface StreamChunk {
      * This can include thinking state, tool execution info, etc.
      */
     raw?: any;
+
+    /**
+     * Tool calls from the LLM (if any)
+     * These may be accumulated over multiple chunks during streaming
+     */
+    tool_calls?: ToolCall[] | any[];
 }
 
 /**
diff --git a/src/services/llm/providers/openai_service.ts b/src/services/llm/providers/openai_service.ts
index f537dfd8d..3933deb2f 100644
--- a/src/services/llm/providers/openai_service.ts
+++ b/src/services/llm/providers/openai_service.ts
@@ -32,7 +32,7 @@ export class OpenAIService extends BaseAIService {
 
         // Get provider-specific options from the central provider manager
         const providerOptions = getOpenAIOptions(opts);
-        
+
         // Initialize the OpenAI client
         const client = this.getClient(providerOptions.apiKey, providerOptions.baseUrl);
 
@@ -69,36 +69,79 @@ export class OpenAIService extends BaseAIService {
             // If streaming is requested
             if (providerOptions.stream) {
                 params.stream = true;
-                
+
                 // Get stream from OpenAI SDK
                 const stream = await client.chat.completions.create(params);
-                
+
+                // Create a closure to hold accumulated tool calls
+                let accumulatedToolCalls: any[] = [];
+
                 // Return a response with the stream handler
-                return {
+                const response: ChatResponse = {
                     text: '', // Initial empty text, will be populated during streaming
                     model: params.model,
                     provider: this.getName(),
+                    // Add tool_calls property that will be populated during streaming
+                    tool_calls: [],
                     stream: async (callback) => {
                         let completeText = '';
-                        
+
                         try {
                             // Process the stream
                             if (Symbol.asyncIterator in stream) {
                                 for await (const chunk of stream as AsyncIterable<any>) {
                                     const content = chunk.choices[0]?.delta?.content || '';
                                     const isDone = !!chunk.choices[0]?.finish_reason;
-                                    
+
+                                    // Check for tool calls in the delta
+                                    const deltaToolCalls = chunk.choices[0]?.delta?.tool_calls;
+
+                                    if (deltaToolCalls) {
+                                        // Process and accumulate tool calls from this chunk
+                                        for (const deltaToolCall of deltaToolCalls) {
+                                            const toolCallId = deltaToolCall.index;
+
+                                            // Initialize or update the accumulated tool call
+                                            if (!accumulatedToolCalls[toolCallId]) {
+                                                accumulatedToolCalls[toolCallId] = {
+                                                    id: deltaToolCall.id || `call_${toolCallId}`,
+                                                    type: deltaToolCall.type || 'function',
+                                                    function: {
+                                                        name: '',
+                                                        arguments: ''
+                                                    }
+                                                };
+                                            }
+
+                                            // Update function name if present
+                                            if (deltaToolCall.function?.name) {
+                                                accumulatedToolCalls[toolCallId].function.name =
+                                                    deltaToolCall.function.name;
+                                            }
+
+                                            // Append to function arguments if present
+                                            if (deltaToolCall.function?.arguments) {
+                                                accumulatedToolCalls[toolCallId].function.arguments +=
+                                                    deltaToolCall.function.arguments;
+                                            }
+                                        }
+
+                                        // Important: Update the response's tool_calls with accumulated tool calls
+                                        response.tool_calls = accumulatedToolCalls.filter(Boolean);
+                                    }
+
                                     if (content) {
                                         completeText += content;
                                     }
-                                    
-                                    // Send the chunk to the caller with raw data
+
+                                    // Send the chunk to the caller with raw data and any accumulated tool calls
                                     await callback({
                                         text: content,
                                         done: isDone,
-                                        raw: chunk // Include the raw chunk for advanced processing
+                                        raw: chunk,
+                                        tool_calls: accumulatedToolCalls.length > 0 ? accumulatedToolCalls.filter(Boolean) : undefined
                                     });
-                                    
+
                                     if (isDone) {
                                         break;
                                     }
@@ -106,14 +149,22 @@
                             } else {
                                 // Fallback for non-iterable response
                                 console.warn('Stream is not iterable, falling back to non-streaming response');
-                                
+
                                 if ('choices' in stream) {
                                     const content = stream.choices[0]?.message?.content || '';
                                     completeText = content;
+
+                                    // Check if there are tool calls in the non-stream response
+                                    const toolCalls = stream.choices[0]?.message?.tool_calls;
+                                    if (toolCalls) {
+                                        response.tool_calls = toolCalls;
+                                    }
+
                                     await callback({
                                         text: content,
                                         done: true,
-                                        raw: stream
+                                        raw: stream,
+                                        tool_calls: toolCalls
                                     });
                                 }
                             }
@@ -121,16 +172,22 @@
                             console.error('Error processing stream:', error);
                             throw error;
                         }
-                        
+
+                        // Update the response's text with the complete text
+                        response.text = completeText;
+
+                        // Return the complete text
                         return completeText;
                     }
                 };
+
+                return response;
             } else {
                 // Non-streaming response
                 params.stream = false;
-                
+
                 const completion = await client.chat.completions.create(params);
-                
+
                 if (!('choices' in completion)) {
                     throw new Error('Unexpected response format from OpenAI API');
                 }
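
Usage note (illustrative, not part of the patch): a minimal consumer-side sketch of the new streaming tool-call flow. It assumes the service exposes a `generateChatCompletion(messages, opts)` method returning the `ChatResponse` built above; the message/option shapes and constructor details are assumptions, not confirmed by this diff.

    import { OpenAIService } from './providers/openai_service';

    const service = new OpenAIService(); // constructor args, if any, elided

    // Hypothetical call shape; the real method name and options type may differ.
    const response = await service.generateChatCompletion(
        [{ role: 'user', content: 'What is the weather in Paris? Use the weather tool.' }],
        { stream: true }
    );

    let toolCalls: any[] = [];

    if (response.stream) {
        const fullText = await response.stream(async (chunk) => {
            if (chunk.text) process.stdout.write(chunk.text);
            // chunk.tool_calls carries the calls accumulated so far; function
            // arguments may still be partial JSON until chunk.done is true.
            if (chunk.tool_calls) toolCalls = chunk.tool_calls;
        });
        console.log('\nfinal text:', fullText);
        console.log('tool calls:', JSON.stringify(toolCalls, null, 2));
    }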