put more websocket logic into the stream handler

perf3ct 2025-04-13 19:44:04 +00:00
parent d1edf59f97
commit c9bb0fb219
2 changed files with 312 additions and 226 deletions

View File

@@ -8,7 +8,13 @@ import type { OllamaOptions } from './provider_options.js';
import { getOllamaOptions } from './providers.js';
import { Ollama, type ChatRequest, type ChatResponse as OllamaChatResponse } from 'ollama';
import options from '../../options.js';
import { StreamProcessor, createStreamHandler } from './stream_handler.js';
import {
    StreamProcessor,
    createStreamHandler,
    performProviderHealthCheck,
    processProviderStream,
    extractStreamStats
} from './stream_handler.js';
// Add an interface for tool execution feedback status
interface ToolExecutionStatus {
@@ -256,7 +262,7 @@ export class OllamaService extends BaseAIService {
/**
 * Handle streaming response from Ollama
 *
 * Simplified implementation that leverages the Ollama SDK's streaming capabilities
 * Uses reusable stream handling utilities for processing
 */
private async handleStreamingResponse(
client: Ollama,
@@ -269,6 +275,49 @@ export class OllamaService extends BaseAIService {
    // Log detailed information about the streaming setup
    log.info(`Ollama streaming details: model=${providerOptions.model}, streamCallback=${opts.streamCallback ? 'provided' : 'not provided'}`);

    // Create streaming request
    const streamingRequest = {
        ...requestOptions,
        stream: true as const
    };

    // Handle direct streamCallback if provided
    if (opts.streamCallback) {
        try {
            // Perform health check before streaming
            await performProviderHealthCheck(
                async () => await client.list(),
                this.getName()
            );

            log.info(`Making Ollama streaming request after successful health check`);

            // Get the stream iterator
            const streamIterator = await client.chat(streamingRequest);

            // Process the stream with our reusable utility
            const streamResult = await processProviderStream(
                streamIterator,
                {
                    providerName: this.getName(),
                    modelName: providerOptions.model
                },
                opts.streamCallback
            );

            // Create the final response after streaming is complete
            return {
                text: streamResult.completeText,
                model: providerOptions.model,
                provider: this.getName(),
                tool_calls: this.transformToolCalls(streamResult.toolCalls),
                usage: extractStreamStats(streamResult.finalChunk, this.getName())
            };
        } catch (error) {
            log.error(`Error in Ollama streaming with callback: ${error}`);
            log.error(`Error details: ${error instanceof Error ? error.stack : 'No stack trace available'}`);
            throw error;
        }
    } else {
        // Create a stream handler using our reusable StreamProcessor
        const streamHandler = createStreamHandler(
            {
@@ -282,23 +331,11 @@
let chunkCount = 0;
try {
// Create streaming request
const streamingRequest = {
...requestOptions,
stream: true as const
};
log.info(`Creating Ollama streaming request with options: model=${streamingRequest.model}, stream=${streamingRequest.stream}, tools=${streamingRequest.tools ? streamingRequest.tools.length : 0}`);
// Perform health check
try {
log.info(`Performing Ollama health check...`);
const healthCheck = await client.list();
log.info(`Ollama health check successful. Available models: ${healthCheck.models.map(m => m.name).join(', ')}`);
} catch (healthError) {
log.error(`Ollama health check failed: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
throw new Error(`Unable to connect to Ollama server: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
}
await performProviderHealthCheck(
async () => await client.list(),
this.getName()
);
// Get the stream iterator
log.info(`Getting stream iterator from Ollama`);
@@ -350,157 +387,15 @@
}
);
// Handle direct streamCallback if provided
if (opts.streamCallback) {
let completeText = '';
let responseToolCalls: any[] = [];
let finalChunk: OllamaChatResponse | null = null;
let chunkCount = 0;
try {
// Create streaming request
const streamingRequest = {
...requestOptions,
stream: true as const
};
log.info(`Starting Ollama direct streamCallback processing with model ${providerOptions.model}`);
// Get the async iterator
log.info(`Calling Ollama chat API for direct streaming`);
let streamIterator;
try {
log.info(`About to call client.chat with streaming request to ${options.getOption('ollamaBaseUrl')}`);
log.info(`Model: ${streamingRequest.model}, Stream: ${streamingRequest.stream}`);
log.info(`Messages count: ${streamingRequest.messages.length}`);
log.info(`First message: role=${streamingRequest.messages[0].role}, content preview=${streamingRequest.messages[0].content?.substring(0, 50) || 'empty'}`);
// Perform health check before streaming
try {
log.info(`Performing Ollama health check before direct streaming...`);
const healthCheck = await client.list();
log.info(`Ollama health check successful. Available models: ${healthCheck.models.map(m => m.name).join(', ')}`);
} catch (healthError) {
log.error(`Ollama health check failed: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
log.error(`This indicates a connection issue to the Ollama server at ${options.getOption('ollamaBaseUrl')}`);
throw new Error(`Unable to connect to Ollama server: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
}
// Proceed with streaming after successful health check
log.info(`Making Ollama streaming request after successful health check`);
streamIterator = await client.chat(streamingRequest);
log.info(`Successfully obtained Ollama stream iterator for direct callback`);
// Check if the stream iterator is valid
if (!streamIterator || typeof streamIterator[Symbol.asyncIterator] !== 'function') {
log.error(`Invalid stream iterator returned from Ollama: ${JSON.stringify(streamIterator)}`);
throw new Error('Invalid stream iterator returned from Ollama');
}
log.info(`Stream iterator is valid, beginning processing`);
} catch (error) {
log.error(`Error getting stream iterator from Ollama: ${error instanceof Error ? error.message : String(error)}`);
log.error(`Error stack: ${error instanceof Error ? error.stack : 'No stack trace'}`);
throw error;
}
// Process each chunk
try {
log.info(`Starting to iterate through stream chunks`);
for await (const chunk of streamIterator) {
chunkCount++;
finalChunk = chunk;
// Process chunk with StreamProcessor
const result = await StreamProcessor.processChunk(
chunk,
completeText,
chunkCount,
{ providerName: this.getName(), modelName: providerOptions.model }
);
completeText = result.completeText;
// Extract tool calls
const toolCalls = StreamProcessor.extractToolCalls(chunk);
if (toolCalls.length > 0) {
responseToolCalls = toolCalls;
}
// Call the callback with the current chunk content
if (opts.streamCallback) {
// For chunks with content, send the content directly
if (chunk.message?.content) {
log.info(`Sending direct chunk #${chunkCount} with content: "${chunk.message.content.substring(0, 50)}${chunk.message.content.length > 50 ? '...' : ''}"`);
await StreamProcessor.sendChunkToCallback(
opts.streamCallback,
chunk.message.content,
!!chunk.done, // Mark as done if done flag is set
chunk,
chunkCount
);
} else if (chunk.done) {
// Send empty done message for final chunk with no content
await StreamProcessor.sendChunkToCallback(
opts.streamCallback,
'',
true,
chunk,
chunkCount
);
}
}
// If this is the done chunk, log it
if (chunk.done && !result.logged) {
log.info(`Reached final direct chunk (done=true) after ${chunkCount} chunks, total content length: ${completeText.length}`);
}
}
// Send one final callback with done=true after all chunks have been processed
// Only send this if the last chunk didn't already have done=true
if (opts.streamCallback && (!finalChunk || !finalChunk.done)) {
log.info(`Sending explicit final callback with done=true flag after all chunks processed`);
await StreamProcessor.sendFinalCallback(opts.streamCallback, completeText);
}
log.info(`Completed direct streaming from Ollama: processed ${chunkCount} chunks, final content: ${completeText.length} chars`);
} catch (iterationError) {
log.error(`Error iterating through Ollama stream chunks: ${iterationError instanceof Error ? iterationError.message : String(iterationError)}`);
log.error(`Iteration error stack: ${iterationError instanceof Error ? iterationError.stack : 'No stack trace'}`);
throw iterationError;
}
// Create the final response after streaming is complete
return StreamProcessor.createFinalResponse(
completeText,
providerOptions.model,
this.getName(),
this.transformToolCalls(responseToolCalls),
{
promptTokens: finalChunk?.prompt_eval_count || 0,
completionTokens: finalChunk?.eval_count || 0,
totalTokens: (finalChunk?.prompt_eval_count || 0) + (finalChunk?.eval_count || 0)
}
);
} catch (error) {
log.error(`Error in Ollama streaming with callback: ${error}`);
log.error(`Error details: ${error instanceof Error ? error.stack : 'No stack trace available'}`);
throw error;
}
}
// Return a response object with the stream handler
return {
text: '', // Initial text is empty, will be populated during streaming
model: providerOptions.model,
provider: this.getName(),
stream: streamHandler as (callback: (chunk: StreamChunk) => Promise<void> | void) => Promise<string>
};
}
}
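For the non-callback path, the caller drives the returned stream function itself. A minimal consumption sketch follows; the call site (service.generateChatCompletion) and the StreamChunk fields (text, done) are assumptions not confirmed by this diff:

// Hypothetical consumer of the non-callback return shape above
const response = await service.generateChatCompletion(messages, { stream: true }); // assumed API
if (response.stream) {
    const fullText = await response.stream(async (chunk) => {
        if (chunk.text) {
            process.stdout.write(chunk.text); // assumed chunk field
        }
        if (chunk.done) {
            console.log('\n[stream complete]'); // assumed chunk field
        }
    });
    console.log(`Streamed ${fullText.length} characters in total`);
}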
/**
 * Transform Ollama tool calls to the standard format expected by the pipeline

View File

@@ -210,3 +210,194 @@ export function createStreamHandler(
}
};
}
/**
 * Interface for provider-specific stream options
 */
export interface ProviderStreamOptions {
    providerName: string;
    modelName: string;
    apiConfig?: any;
}

/**
 * Interface for streaming response stats
 */
export interface StreamStats {
    promptTokens?: number;
    completionTokens?: number;
    totalTokens?: number;
}
/**
 * Perform a health check against an API endpoint
 * @param checkFn Function that performs the actual health check API call
 * @param providerName Name of the provider for logging
 * @returns Promise resolving to true if healthy, or throwing an error if not
 */
export async function performProviderHealthCheck(
    checkFn: () => Promise<any>,
    providerName: string
): Promise<boolean> {
    try {
        log.info(`Performing ${providerName} health check...`);
        const healthResponse = await checkFn();
        log.info(`${providerName} health check successful`);
        return true;
    } catch (healthError) {
        log.error(`${providerName} health check failed: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
        throw new Error(`Unable to connect to ${providerName} server: ${healthError instanceof Error ? healthError.message : String(healthError)}`);
    }
}
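The helper is provider-agnostic: each caller supplies a thunk wrapping its own cheapest API call plus a name for logging. A short usage sketch; the Ollama call mirrors this commit, while the OpenAI-style call is an assumption about that SDK, shown only to illustrate the intended reuse:

// Ollama (as in this commit): listing local models doubles as the health probe
await performProviderHealthCheck(async () => await client.list(), 'Ollama');

// Hypothetical reuse with an OpenAI-compatible client (assumed SDK surface)
await performProviderHealthCheck(async () => await openai.models.list(), 'OpenAI');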
/**
 * Process a stream from an LLM provider using a callback-based approach
 * @param streamIterator Async iterator returned from the provider's API
 * @param options Provider information and configuration
 * @param streamCallback Optional callback function for streaming updates
 * @returns Promise resolving to the complete response including text and tool calls
 */
export async function processProviderStream(
    streamIterator: AsyncIterable<any>,
    options: ProviderStreamOptions,
    streamCallback?: (text: string, done: boolean, chunk?: any) => Promise<void> | void
): Promise<{
    completeText: string;
    toolCalls: any[];
    finalChunk: any | null;
    chunkCount: number;
}> {
    let completeText = '';
    let responseToolCalls: any[] = [];
    let finalChunk: any | null = null;
    let chunkCount = 0;

    try {
        log.info(`Starting ${options.providerName} stream processing with model ${options.modelName}`);

        // Validate stream iterator
        if (!streamIterator || typeof streamIterator[Symbol.asyncIterator] !== 'function') {
            log.error(`Invalid stream iterator returned from ${options.providerName}`);
            throw new Error(`Invalid stream iterator returned from ${options.providerName}`);
        }

        // Process each chunk
        for await (const chunk of streamIterator) {
            chunkCount++;
            finalChunk = chunk;

            // Process chunk with StreamProcessor
            const result = await StreamProcessor.processChunk(
                chunk,
                completeText,
                chunkCount,
                { providerName: options.providerName, modelName: options.modelName }
            );
            completeText = result.completeText;

            // Extract tool calls
            const toolCalls = StreamProcessor.extractToolCalls(chunk);
            if (toolCalls.length > 0) {
                responseToolCalls = toolCalls;
            }

            // Call the callback with the current chunk content if provided
            if (streamCallback) {
                // For chunks with content, send the content directly
                const contentProperty = getChunkContentProperty(chunk);
                if (contentProperty) {
                    await StreamProcessor.sendChunkToCallback(
                        streamCallback,
                        contentProperty,
                        !!chunk.done, // Mark as done if done flag is set
                        chunk,
                        chunkCount
                    );
                } else if (chunk.done) {
                    // Send empty done message for final chunk with no content
                    await StreamProcessor.sendChunkToCallback(
                        streamCallback,
                        '',
                        true,
                        chunk,
                        chunkCount
                    );
                }
            }

            // Log final chunk
            if (chunk.done && !result.logged) {
                log.info(`Reached final chunk (done=true) after ${chunkCount} chunks, total content length: ${completeText.length}`);
            }
        }

        // Send one final callback with done=true if the last chunk didn't have done=true
        if (streamCallback && (!finalChunk || !finalChunk.done)) {
            log.info(`Sending explicit final callback with done=true flag after all chunks processed`);
            await StreamProcessor.sendFinalCallback(streamCallback, completeText);
        }

        log.info(`Completed ${options.providerName} streaming: processed ${chunkCount} chunks, final content: ${completeText.length} chars`);

        return {
            completeText,
            toolCalls: responseToolCalls,
            finalChunk,
            chunkCount
        };
    } catch (error) {
        log.error(`Error in ${options.providerName} stream processing: ${error instanceof Error ? error.message : String(error)}`);
        log.error(`Error details: ${error instanceof Error ? error.stack : 'No stack trace available'}`);
        throw error;
    }
}
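A self-contained sketch of the call shape, driving the processor with a hand-rolled async generator instead of a real SDK stream. The accumulated completeText assumes StreamProcessor.processChunk appends message content, which this diff does not show, so treat the expected values as illustrative:

// Fake two-chunk Ollama-style stream for illustration only
async function* fakeOllamaStream() {
    yield { message: { content: 'Hello ' }, done: false };
    yield { message: { content: 'world' }, done: true, prompt_eval_count: 5, eval_count: 2 };
}

const result = await processProviderStream(
    fakeOllamaStream(),
    { providerName: 'Ollama', modelName: 'llama3' },
    async (text, done) => {
        if (text) {
            console.log(`chunk: "${text}"${done ? ' (final)' : ''}`);
        }
    }
);
// Illustrative expectations: result.chunkCount === 2, result.finalChunk is the second
// yielded object (so extractStreamStats can read its token counts), and
// result.completeText should accumulate to 'Hello world' under the assumption above.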
/**
 * Helper function to extract content from a chunk based on provider's response format
 * Different providers may have different chunk structures
 */
function getChunkContentProperty(chunk: any): string | null {
    // Check common content locations in different provider responses
    if (chunk.message?.content) {
        return chunk.message.content;
    }

    if (chunk.content) {
        return chunk.content;
    }

    if (chunk.choices?.[0]?.delta?.content) {
        return chunk.choices[0].delta.content;
    }

    return null;
}
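For reference, the chunk shapes this helper matches (calls shown as they would run within this module, since the function is not exported); the flat content case is a catch-all, the other two correspond to Ollama-style messages and OpenAI-style deltas:

getChunkContentProperty({ message: { content: 'hi' } });              // Ollama-style -> 'hi'
getChunkContentProperty({ content: 'hi' });                           // flat content -> 'hi'
getChunkContentProperty({ choices: [{ delta: { content: 'hi' } }] }); // OpenAI-style delta -> 'hi'
getChunkContentProperty({ done: true });                              // no content -> null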
/**
 * Extract usage statistics from the final chunk based on provider format
 */
export function extractStreamStats(finalChunk: any | null, providerName: string): StreamStats {
    // Handle provider-specific response formats
    if (!finalChunk) {
        return { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
    }

    // Ollama format
    if (finalChunk.prompt_eval_count !== undefined && finalChunk.eval_count !== undefined) {
        return {
            promptTokens: finalChunk.prompt_eval_count || 0,
            completionTokens: finalChunk.eval_count || 0,
            totalTokens: (finalChunk.prompt_eval_count || 0) + (finalChunk.eval_count || 0)
        };
    }

    // OpenAI-like format
    if (finalChunk.usage) {
        return {
            promptTokens: finalChunk.usage.prompt_tokens || 0,
            completionTokens: finalChunk.usage.completion_tokens || 0,
            totalTokens: finalChunk.usage.total_tokens || 0
        };
    }

    log.info(`No standard token usage found in ${providerName} final chunk`);
    return { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
}
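The resulting mapping, shown for the two recognized final-chunk shapes:

extractStreamStats({ prompt_eval_count: 12, eval_count: 34 }, 'Ollama');
// -> { promptTokens: 12, completionTokens: 34, totalTokens: 46 }

extractStreamStats({ usage: { prompt_tokens: 12, completion_tokens: 34, total_tokens: 46 } }, 'OpenAI');
// -> { promptTokens: 12, completionTokens: 34, totalTokens: 46 }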