mirror of https://github.com/TriliumNext/Notes.git

commit c04e3b2c89 (parent 6750467edc)

    okay openai tool calling response is close to working
@@ -42,6 +42,12 @@ export interface StreamChunk {
      * This can include thinking state, tool execution info, etc.
      */
     raw?: any;
+
+    /**
+     * Tool calls from the LLM (if any)
+     * These may be accumulated over multiple chunks during streaming
+     */
+    tool_calls?: ToolCall[] | any[];
 }
 
 /**
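For context (not part of this diff): the new `tool_calls` field is meant to carry entries in OpenAI's function-call shape, whose `arguments` string is assembled across chunks and is only parseable once streaming finishes. A minimal sketch, assuming `ToolCall` mirrors that shape:

```typescript
// Sketch only: an assumed ToolCall shape mirroring OpenAI's
// chat.completions function-calling format.
interface ToolCall {
    id: string;
    type: 'function';
    function: {
        name: string;
        arguments: string; // JSON string, concatenated across stream chunks
    };
}

// A consumer would typically parse arguments only after the final chunk:
function parseToolArguments(call: ToolCall): unknown {
    return JSON.parse(call.function.arguments || '{}');
}
```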
@@ -32,7 +32,7 @@ export class OpenAIService extends BaseAIService {
 
         // Get provider-specific options from the central provider manager
         const providerOptions = getOpenAIOptions(opts);
 
         // Initialize the OpenAI client
         const client = this.getClient(providerOptions.apiKey, providerOptions.baseUrl);
 
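`getClient` itself is not shown in this diff. As a rough idea of what such a factory typically does (a hypothetical sketch, not the repository's implementation), it would cache one SDK instance per apiKey/baseUrl pair:

```typescript
import OpenAI from 'openai';

// Hypothetical sketch: reuse one client per apiKey/baseUrl combination
// so repeated chat calls don't construct a new SDK instance each time.
const clients = new Map<string, OpenAI>();

function getClient(apiKey: string, baseUrl?: string): OpenAI {
    const key = `${apiKey}:${baseUrl ?? 'default'}`;
    let client = clients.get(key);
    if (!client) {
        client = new OpenAI({ apiKey, baseURL: baseUrl });
        clients.set(key, client);
    }
    return client;
}
```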
@@ -69,36 +69,79 @@ export class OpenAIService extends BaseAIService {
         // If streaming is requested
         if (providerOptions.stream) {
             params.stream = true;
 
             // Get stream from OpenAI SDK
             const stream = await client.chat.completions.create(params);
 
+            // Create a closure to hold accumulated tool calls
+            let accumulatedToolCalls: any[] = [];
+
             // Return a response with the stream handler
-            return {
+            const response: ChatResponse = {
                 text: '', // Initial empty text, will be populated during streaming
                 model: params.model,
                 provider: this.getName(),
+                // Add tool_calls property that will be populated during streaming
+                tool_calls: [],
                 stream: async (callback) => {
                     let completeText = '';
 
                     try {
                         // Process the stream
                         if (Symbol.asyncIterator in stream) {
                             for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
                                 const content = chunk.choices[0]?.delta?.content || '';
                                 const isDone = !!chunk.choices[0]?.finish_reason;
 
+                                // Check for tool calls in the delta
+                                const deltaToolCalls = chunk.choices[0]?.delta?.tool_calls;
+
+                                if (deltaToolCalls) {
+                                    // Process and accumulate tool calls from this chunk
+                                    for (const deltaToolCall of deltaToolCalls) {
+                                        const toolCallId = deltaToolCall.index;
+
+                                        // Initialize or update the accumulated tool call
+                                        if (!accumulatedToolCalls[toolCallId]) {
+                                            accumulatedToolCalls[toolCallId] = {
+                                                id: deltaToolCall.id || `call_${toolCallId}`,
+                                                type: deltaToolCall.type || 'function',
+                                                function: {
+                                                    name: '',
+                                                    arguments: ''
+                                                }
+                                            };
+                                        }
+
+                                        // Update function name if present
+                                        if (deltaToolCall.function?.name) {
+                                            accumulatedToolCalls[toolCallId].function.name =
+                                                deltaToolCall.function.name;
+                                        }
+
+                                        // Append to function arguments if present
+                                        if (deltaToolCall.function?.arguments) {
+                                            accumulatedToolCalls[toolCallId].function.arguments +=
+                                                deltaToolCall.function.arguments;
+                                        }
+                                    }
+
+                                    // Important: Update the response's tool_calls with accumulated tool calls
+                                    response.tool_calls = accumulatedToolCalls.filter(Boolean);
+                                }
 
                                 if (content) {
                                     completeText += content;
                                 }
 
-                                // Send the chunk to the caller with raw data
+                                // Send the chunk to the caller with raw data and any accumulated tool calls
                                 await callback({
                                     text: content,
                                     done: isDone,
-                                    raw: chunk // Include the raw chunk for advanced processing
+                                    raw: chunk,
+                                    tool_calls: accumulatedToolCalls.length > 0 ? accumulatedToolCalls.filter(Boolean) : undefined
                                 });
 
                                 if (isDone) {
                                     break;
                                 }
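The core of this hunk: OpenAI streams tool calls as deltas keyed by `index` — the function name usually arrives in the first fragment, while `arguments` arrives as a JSON string split across many fragments that must be concatenated in order. A self-contained sketch of the same accumulation logic (names are illustrative):

```typescript
type ToolCall = { id: string; type: 'function'; function: { name: string; arguments: string } };
type ToolCallDelta = { index: number; id?: string; function?: { name?: string; arguments?: string } };

// Sketch: mirror the commit's accumulation logic outside the service class.
function accumulateDelta(acc: ToolCall[], delta: ToolCallDelta): void {
    const i = delta.index;
    if (!acc[i]) {
        acc[i] = { id: delta.id || `call_${i}`, type: 'function', function: { name: '', arguments: '' } };
    }
    if (delta.function?.name) acc[i].function.name = delta.function.name;
    if (delta.function?.arguments) acc[i].function.arguments += delta.function.arguments;
}

// Example: arguments arrive in fragments across chunks.
const acc: ToolCall[] = [];
accumulateDelta(acc, { index: 0, id: 'call_abc', function: { name: 'search_notes', arguments: '{"que' } });
accumulateDelta(acc, { index: 0, function: { arguments: 'ry":"todo"}' } });
// acc[0].function.arguments === '{"query":"todo"}'
```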
@@ -106,14 +149,22 @@ export class OpenAIService extends BaseAIService {
                         } else {
                             // Fallback for non-iterable response
                             console.warn('Stream is not iterable, falling back to non-streaming response');
 
                             if ('choices' in stream) {
                                 const content = stream.choices[0]?.message?.content || '';
                                 completeText = content;
+
+                                // Check if there are tool calls in the non-stream response
+                                const toolCalls = stream.choices[0]?.message?.tool_calls;
+                                if (toolCalls) {
+                                    response.tool_calls = toolCalls;
+                                }
+
                                 await callback({
                                     text: content,
                                     done: true,
-                                    raw: stream
+                                    raw: stream,
+                                    tool_calls: toolCalls
                                 });
                             }
                         }
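This fallback covers the case where the SDK handed back a complete `ChatCompletion` rather than an async-iterable stream. A small type guard (illustrative, not in the commit) would make the branch explicit instead of probing with `Symbol.asyncIterator in stream` and `'choices' in stream`:

```typescript
import OpenAI from 'openai';

// Sketch: distinguish a streamed response from a complete one.
function isChunkStream(
    value: object
): value is AsyncIterable<OpenAI.Chat.ChatCompletionChunk> {
    return Symbol.asyncIterator in value;
}
```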
@@ -121,16 +172,22 @@ export class OpenAIService extends BaseAIService {
                         console.error('Error processing stream:', error);
                         throw error;
                     }
 
+                    // Update the response's text with the complete text
+                    response.text = completeText;
+
+                    // Return the complete text
                     return completeText;
                 }
             };
+
+            return response;
         } else {
             // Non-streaming response
             params.stream = false;
 
             const completion = await client.chat.completions.create(params);
 
             if (!('choices' in completion)) {
                 throw new Error('Unexpected response format from OpenAI API');
             }
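Taken together, a caller drives the stream via the returned handler and reads the accumulated tool calls afterwards. A hypothetical usage sketch (the method name `generateChatCompletion` and the options shape are assumptions, not confirmed by this diff):

```typescript
// Hypothetical usage of the streaming response sketched above.
const service = new OpenAIService();
const messages = [{ role: 'user' as const, content: 'Find my TODO notes' }];
const response = await service.generateChatCompletion(messages, { stream: true });

if (response.stream) {
    await response.stream(async (chunk) => {
        if (chunk.text) process.stdout.write(chunk.text);
        if (chunk.done && chunk.tool_calls?.length) {
            for (const call of chunk.tool_calls) {
                // Arguments are complete JSON only once done is true.
                console.log(call.function.name, JSON.parse(call.function.arguments || '{}'));
            }
        }
    });
    // response.text and response.tool_calls are populated after streaming ends.
}
```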
|
Loading…
x
Reference in New Issue
Block a user