mirror of
https://github.com/TriliumNext/Notes.git
synced 2025-07-29 02:52:27 +08:00
put more websocket logic into the stream handler
This commit is contained in:
parent
d1edf59f97
commit
c9bb0fb219
@ -8,7 +8,13 @@ import type { OllamaOptions } from './provider_options.js';
|
||||
import { getOllamaOptions } from './providers.js';
|
||||
import { Ollama, type ChatRequest, type ChatResponse as OllamaChatResponse } from 'ollama';
|
||||
import options from '../../options.js';
|
||||
import { StreamProcessor, createStreamHandler } from './stream_handler.js';
|
||||
import {
|
||||
StreamProcessor,
|
||||
createStreamHandler,
|
||||
performProviderHealthCheck,
|
||||
processProviderStream,
|
||||
extractStreamStats
|
||||
} from './stream_handler.js';
|
||||
|
||||
// Add an interface for tool execution feedback status
|
||||
interface ToolExecutionStatus {
|
||||
@ -256,7 +262,7 @@ export class OllamaService extends BaseAIService {
|
||||
/**
|
||||
* Handle streaming response from Ollama
|
||||
*
|
||||
* Simplified implementation that leverages the Ollama SDK's streaming capabilities
|
||||
* Uses reusable stream handling utilities for processing
|
||||
*/
|
||||
private async handleStreamingResponse(
|
||||
client: Ollama,
|
||||
@ -269,6 +275,49 @@ export class OllamaService extends BaseAIService {
|
||||
// Log detailed information about the streaming setup
|
||||
log.info(`Ollama streaming details: model=${providerOptions.model}, streamCallback=${opts.streamCallback ? 'provided' : 'not provided'}`);
|
||||
|
||||
// Create streaming request
|
||||
const streamingRequest = {
|
||||
...requestOptions,
|
||||
stream: true as const
|
||||
};
|
||||
|
||||
// Handle direct streamCallback if provided
|
||||
if (opts.streamCallback) {
|
||||
try {
|
||||
// Perform health check before streaming
|
||||
await performProviderHealthCheck(
|
||||
async () => await client.list(),
|
||||
this.getName()
|
||||
);
|
||||
|
||||
log.info(`Making Ollama streaming request after successful health check`);
|
||||
// Get the stream iterator
|
||||
const streamIterator = await client.chat(streamingRequest);
|
||||
|
||||
// Process the stream with our reusable utility
|
||||
const streamResult = await processProviderStream(
|
||||
streamIterator,
|
||||
{
|
||||
providerName: this.getName(),
|
||||
modelName: providerOptions.model
|
||||
},
|
||||
opts.streamCallback
|
||||
);
|
||||
|
||||
// Create the final response after streaming is complete
|
||||
return {
|
||||
text: streamResult.completeText,
|
||||
model: providerOptions.model,
|
||||
provider: this.getName(),
|
||||
tool_calls: this.transformToolCalls(streamResult.toolCalls),
|
||||
usage: extractStreamStats(streamResult.finalChunk, this.getName())
|
||||
};
|
||||
} catch (error) {
|
||||
log.error(`Error in Ollama streaming with callback: ${error}`);
|
||||
log.error(`Error details: ${error instanceof Error ? error.stack : 'No stack trace available'}`);
|
||||
throw error;
|
||||
}
|
||||
} else {
|
||||
// Create a stream handler using our reusable StreamProcessor
|
||||
const streamHandler = createStreamHandler(
|
||||
{
|
||||
@ -282,23 +331,11 @@ export class OllamaService extends BaseAIService {
|
||||
let chunkCount = 0;
|
||||
|
||||
try {
|
||||
// Create streaming request
|
||||
const streamingRequest = {
|
||||
...requestOptions,
|
||||
stream: true as const
|
||||
};
|
||||
|
||||
log.info(`Creating Ollama streaming request with options: model=${streamingRequest.model}, stream=${streamingRequest.stream}, tools=${streamingRequest.tools ? streamingRequest.tools.length : 0}`);
|
||||
|
||||
// Perform health check
|
||||
try {
|
||||
log.info(`Performing Ollama health check...`);
|
||||
const healthCheck = await client.list();
|
||||
log.info(`Ollama health check successful. Available models: ${healthCheck.models.map(m => m.name).join(', ')}`);
|
||||
} catch (healthError) {
|
||||
log.error(`Ollama health check failed: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
|
||||
throw new Error(`Unable to connect to Ollama server: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
|
||||
}
|
||||
await performProviderHealthCheck(
|
||||
async () => await client.list(),
|
||||
this.getName()
|
||||
);
|
||||
|
||||
// Get the stream iterator
|
||||
log.info(`Getting stream iterator from Ollama`);
|
||||
@ -350,157 +387,15 @@ export class OllamaService extends BaseAIService {
|
||||
}
|
||||
);
|
||||
|
||||
// Handle direct streamCallback if provided
|
||||
if (opts.streamCallback) {
|
||||
let completeText = '';
|
||||
let responseToolCalls: any[] = [];
|
||||
let finalChunk: OllamaChatResponse | null = null;
|
||||
let chunkCount = 0;
|
||||
|
||||
try {
|
||||
// Create streaming request
|
||||
const streamingRequest = {
|
||||
...requestOptions,
|
||||
stream: true as const
|
||||
};
|
||||
|
||||
log.info(`Starting Ollama direct streamCallback processing with model ${providerOptions.model}`);
|
||||
|
||||
// Get the async iterator
|
||||
log.info(`Calling Ollama chat API for direct streaming`);
|
||||
let streamIterator;
|
||||
try {
|
||||
log.info(`About to call client.chat with streaming request to ${options.getOption('ollamaBaseUrl')}`);
|
||||
log.info(`Model: ${streamingRequest.model}, Stream: ${streamingRequest.stream}`);
|
||||
log.info(`Messages count: ${streamingRequest.messages.length}`);
|
||||
log.info(`First message: role=${streamingRequest.messages[0].role}, content preview=${streamingRequest.messages[0].content?.substring(0, 50) || 'empty'}`);
|
||||
|
||||
// Perform health check before streaming
|
||||
try {
|
||||
log.info(`Performing Ollama health check before direct streaming...`);
|
||||
const healthCheck = await client.list();
|
||||
log.info(`Ollama health check successful. Available models: ${healthCheck.models.map(m => m.name).join(', ')}`);
|
||||
} catch (healthError) {
|
||||
log.error(`Ollama health check failed: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
|
||||
log.error(`This indicates a connection issue to the Ollama server at ${options.getOption('ollamaBaseUrl')}`);
|
||||
throw new Error(`Unable to connect to Ollama server: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
|
||||
}
|
||||
|
||||
// Proceed with streaming after successful health check
|
||||
log.info(`Making Ollama streaming request after successful health check`);
|
||||
streamIterator = await client.chat(streamingRequest);
|
||||
|
||||
log.info(`Successfully obtained Ollama stream iterator for direct callback`);
|
||||
|
||||
// Check if the stream iterator is valid
|
||||
if (!streamIterator || typeof streamIterator[Symbol.asyncIterator] !== 'function') {
|
||||
log.error(`Invalid stream iterator returned from Ollama: ${JSON.stringify(streamIterator)}`);
|
||||
throw new Error('Invalid stream iterator returned from Ollama');
|
||||
}
|
||||
|
||||
log.info(`Stream iterator is valid, beginning processing`);
|
||||
} catch (error) {
|
||||
log.error(`Error getting stream iterator from Ollama: ${error instanceof Error ? error.message : String(error)}`);
|
||||
log.error(`Error stack: ${error instanceof Error ? error.stack : 'No stack trace'}`);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Process each chunk
|
||||
try {
|
||||
log.info(`Starting to iterate through stream chunks`);
|
||||
for await (const chunk of streamIterator) {
|
||||
chunkCount++;
|
||||
finalChunk = chunk;
|
||||
|
||||
// Process chunk with StreamProcessor
|
||||
const result = await StreamProcessor.processChunk(
|
||||
chunk,
|
||||
completeText,
|
||||
chunkCount,
|
||||
{ providerName: this.getName(), modelName: providerOptions.model }
|
||||
);
|
||||
|
||||
completeText = result.completeText;
|
||||
|
||||
// Extract tool calls
|
||||
const toolCalls = StreamProcessor.extractToolCalls(chunk);
|
||||
if (toolCalls.length > 0) {
|
||||
responseToolCalls = toolCalls;
|
||||
}
|
||||
|
||||
// Call the callback with the current chunk content
|
||||
if (opts.streamCallback) {
|
||||
// For chunks with content, send the content directly
|
||||
if (chunk.message?.content) {
|
||||
log.info(`Sending direct chunk #${chunkCount} with content: "${chunk.message.content.substring(0, 50)}${chunk.message.content.length > 50 ? '...' : ''}"`);
|
||||
|
||||
await StreamProcessor.sendChunkToCallback(
|
||||
opts.streamCallback,
|
||||
chunk.message.content,
|
||||
!!chunk.done, // Mark as done if done flag is set
|
||||
chunk,
|
||||
chunkCount
|
||||
);
|
||||
} else if (chunk.done) {
|
||||
// Send empty done message for final chunk with no content
|
||||
await StreamProcessor.sendChunkToCallback(
|
||||
opts.streamCallback,
|
||||
'',
|
||||
true,
|
||||
chunk,
|
||||
chunkCount
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If this is the done chunk, log it
|
||||
if (chunk.done && !result.logged) {
|
||||
log.info(`Reached final direct chunk (done=true) after ${chunkCount} chunks, total content length: ${completeText.length}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Send one final callback with done=true after all chunks have been processed
|
||||
// Only send this if the last chunk didn't already have done=true
|
||||
if (opts.streamCallback && (!finalChunk || !finalChunk.done)) {
|
||||
log.info(`Sending explicit final callback with done=true flag after all chunks processed`);
|
||||
await StreamProcessor.sendFinalCallback(opts.streamCallback, completeText);
|
||||
}
|
||||
|
||||
log.info(`Completed direct streaming from Ollama: processed ${chunkCount} chunks, final content: ${completeText.length} chars`);
|
||||
} catch (iterationError) {
|
||||
log.error(`Error iterating through Ollama stream chunks: ${iterationError instanceof Error ? iterationError.message : String(iterationError)}`);
|
||||
log.error(`Iteration error stack: ${iterationError instanceof Error ? iterationError.stack : 'No stack trace'}`);
|
||||
throw iterationError;
|
||||
}
|
||||
|
||||
// Create the final response after streaming is complete
|
||||
return StreamProcessor.createFinalResponse(
|
||||
completeText,
|
||||
providerOptions.model,
|
||||
this.getName(),
|
||||
this.transformToolCalls(responseToolCalls),
|
||||
{
|
||||
promptTokens: finalChunk?.prompt_eval_count || 0,
|
||||
completionTokens: finalChunk?.eval_count || 0,
|
||||
totalTokens: (finalChunk?.prompt_eval_count || 0) + (finalChunk?.eval_count || 0)
|
||||
}
|
||||
);
|
||||
} catch (error) {
|
||||
log.error(`Error in Ollama streaming with callback: ${error}`);
|
||||
log.error(`Error details: ${error instanceof Error ? error.stack : 'No stack trace available'}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
// Return a response object with the stream handler
|
||||
return {
|
||||
text: '', // Initial text is empty, will be populated during streaming
|
||||
model: providerOptions.model,
|
||||
provider: this.getName(),
|
||||
stream: streamHandler as (callback: (chunk: StreamChunk) => Promise<void> | void) => Promise<string>
|
||||
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform Ollama tool calls to the standard format expected by the pipeline
|
||||
|
@ -210,3 +210,194 @@ export function createStreamHandler(
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for provider-specific stream options
|
||||
*/
|
||||
export interface ProviderStreamOptions {
|
||||
providerName: string;
|
||||
modelName: string;
|
||||
apiConfig?: any;
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for streaming response stats
|
||||
*/
|
||||
export interface StreamStats {
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
totalTokens?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a health check against an API endpoint
|
||||
* @param checkFn Function that performs the actual health check API call
|
||||
* @param providerName Name of the provider for logging
|
||||
* @returns Promise resolving to true if healthy, or throwing an error if not
|
||||
*/
|
||||
export async function performProviderHealthCheck(
|
||||
checkFn: () => Promise<any>,
|
||||
providerName: string
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
log.info(`Performing ${providerName} health check...`);
|
||||
const healthResponse = await checkFn();
|
||||
log.info(`${providerName} health check successful`);
|
||||
return true;
|
||||
} catch (healthError) {
|
||||
log.error(`${providerName} health check failed: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
|
||||
throw new Error(`Unable to connect to ${providerName} server: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a stream from an LLM provider using a callback-based approach
|
||||
* @param streamIterator Async iterator returned from the provider's API
|
||||
* @param options Provider information and configuration
|
||||
* @param streamCallback Optional callback function for streaming updates
|
||||
* @returns Promise resolving to the complete response including text and tool calls
|
||||
*/
|
||||
export async function processProviderStream(
|
||||
streamIterator: AsyncIterable<any>,
|
||||
options: ProviderStreamOptions,
|
||||
streamCallback?: (text: string, done: boolean, chunk?: any) => Promise<void> | void
|
||||
): Promise<{
|
||||
completeText: string;
|
||||
toolCalls: any[];
|
||||
finalChunk: any | null;
|
||||
chunkCount: number;
|
||||
}> {
|
||||
let completeText = '';
|
||||
let responseToolCalls: any[] = [];
|
||||
let finalChunk: any | null = null;
|
||||
let chunkCount = 0;
|
||||
|
||||
try {
|
||||
log.info(`Starting ${options.providerName} stream processing with model ${options.modelName}`);
|
||||
|
||||
// Validate stream iterator
|
||||
if (!streamIterator || typeof streamIterator[Symbol.asyncIterator] !== 'function') {
|
||||
log.error(`Invalid stream iterator returned from ${options.providerName}`);
|
||||
throw new Error(`Invalid stream iterator returned from ${options.providerName}`);
|
||||
}
|
||||
|
||||
// Process each chunk
|
||||
for await (const chunk of streamIterator) {
|
||||
chunkCount++;
|
||||
finalChunk = chunk;
|
||||
|
||||
// Process chunk with StreamProcessor
|
||||
const result = await StreamProcessor.processChunk(
|
||||
chunk,
|
||||
completeText,
|
||||
chunkCount,
|
||||
{ providerName: options.providerName, modelName: options.modelName }
|
||||
);
|
||||
|
||||
completeText = result.completeText;
|
||||
|
||||
// Extract tool calls
|
||||
const toolCalls = StreamProcessor.extractToolCalls(chunk);
|
||||
if (toolCalls.length > 0) {
|
||||
responseToolCalls = toolCalls;
|
||||
}
|
||||
|
||||
// Call the callback with the current chunk content if provided
|
||||
if (streamCallback) {
|
||||
// For chunks with content, send the content directly
|
||||
const contentProperty = getChunkContentProperty(chunk);
|
||||
if (contentProperty) {
|
||||
await StreamProcessor.sendChunkToCallback(
|
||||
streamCallback,
|
||||
contentProperty,
|
||||
!!chunk.done, // Mark as done if done flag is set
|
||||
chunk,
|
||||
chunkCount
|
||||
);
|
||||
} else if (chunk.done) {
|
||||
// Send empty done message for final chunk with no content
|
||||
await StreamProcessor.sendChunkToCallback(
|
||||
streamCallback,
|
||||
'',
|
||||
true,
|
||||
chunk,
|
||||
chunkCount
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Log final chunk
|
||||
if (chunk.done && !result.logged) {
|
||||
log.info(`Reached final chunk (done=true) after ${chunkCount} chunks, total content length: ${completeText.length}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Send one final callback with done=true if the last chunk didn't have done=true
|
||||
if (streamCallback && (!finalChunk || !finalChunk.done)) {
|
||||
log.info(`Sending explicit final callback with done=true flag after all chunks processed`);
|
||||
await StreamProcessor.sendFinalCallback(streamCallback, completeText);
|
||||
}
|
||||
|
||||
log.info(`Completed ${options.providerName} streaming: processed ${chunkCount} chunks, final content: ${completeText.length} chars`);
|
||||
|
||||
return {
|
||||
completeText,
|
||||
toolCalls: responseToolCalls,
|
||||
finalChunk,
|
||||
chunkCount
|
||||
};
|
||||
} catch (error) {
|
||||
log.error(`Error in ${options.providerName} stream processing: ${error instanceof Error ? error.message : String(error)}`);
|
||||
log.error(`Error details: ${error instanceof Error ? error.stack : 'No stack trace available'}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to extract content from a chunk based on provider's response format
|
||||
* Different providers may have different chunk structures
|
||||
*/
|
||||
function getChunkContentProperty(chunk: any): string | null {
|
||||
// Check common content locations in different provider responses
|
||||
if (chunk.message?.content) {
|
||||
return chunk.message.content;
|
||||
}
|
||||
if (chunk.content) {
|
||||
return chunk.content;
|
||||
}
|
||||
if (chunk.choices?.[0]?.delta?.content) {
|
||||
return chunk.choices[0].delta.content;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract usage statistics from the final chunk based on provider format
|
||||
*/
|
||||
export function extractStreamStats(finalChunk: any | null, providerName: string): StreamStats {
|
||||
// Handle provider-specific response formats
|
||||
if (!finalChunk) {
|
||||
return { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
|
||||
}
|
||||
|
||||
// Ollama format
|
||||
if (finalChunk.prompt_eval_count !== undefined && finalChunk.eval_count !== undefined) {
|
||||
return {
|
||||
promptTokens: finalChunk.prompt_eval_count || 0,
|
||||
completionTokens: finalChunk.eval_count || 0,
|
||||
totalTokens: (finalChunk.prompt_eval_count || 0) + (finalChunk.eval_count || 0)
|
||||
};
|
||||
}
|
||||
|
||||
// OpenAI-like format
|
||||
if (finalChunk.usage) {
|
||||
return {
|
||||
promptTokens: finalChunk.usage.prompt_tokens || 0,
|
||||
completionTokens: finalChunk.usage.completion_tokens || 0,
|
||||
totalTokens: finalChunk.usage.total_tokens || 0
|
||||
};
|
||||
}
|
||||
|
||||
log.info(`No standard token usage found in ${providerName} final chunk`);
|
||||
return { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user