From 59a358a3ee43f5917842c9cbd7ddbf219d28a847 Mon Sep 17 00:00:00 2001
From: perf3ct
Date: Wed, 9 Apr 2025 19:21:34 +0000
Subject: [PATCH] use this new providerMetadata approach

Carry provider and model information on ChatCompletionOptions in a typed
providerMetadata field instead of encoding the provider in a
"provider:model" string. Model selection now records the selected provider
and the bare model name (legacy "provider:model" identifiers are still
parsed for backward compatibility), and rest_chat_service explicitly
enables streaming for GET requests.
---
 src/services/llm/ai_interface.ts              |   2 +
 .../pipeline/stages/llm_completion_stage.ts   |  26 ++-
 .../pipeline/stages/model_selection_stage.ts  | 150 ++++++++++++++++--
 src/services/llm/rest_chat_service.ts         |   3 +-
 4 files changed, 164 insertions(+), 17 deletions(-)

diff --git a/src/services/llm/ai_interface.ts b/src/services/llm/ai_interface.ts
index 2e555c227..1bef407e4 100644
--- a/src/services/llm/ai_interface.ts
+++ b/src/services/llm/ai_interface.ts
@@ -1,4 +1,5 @@
 import type { ToolCall } from './tools/tool_interfaces.js';
+import type { ModelMetadata } from './providers/provider_options.js';
 
 export interface Message {
     role: 'user' | 'assistant' | 'system' | 'tool';
@@ -36,6 +37,7 @@ export interface ChatCompletionOptions {
     tools?: any[]; // Tools to provide to the LLM
     useAdvancedContext?: boolean; // Whether to use advanced context enrichment
     toolExecutionStatus?: any[]; // Status information about executed tools for feedback
+    providerMetadata?: ModelMetadata; // Metadata about the provider and model capabilities
 }
 
 export interface ChatResponse {
diff --git a/src/services/llm/pipeline/stages/llm_completion_stage.ts b/src/services/llm/pipeline/stages/llm_completion_stage.ts
index 8171a5a73..5efe108b0 100644
--- a/src/services/llm/pipeline/stages/llm_completion_stage.ts
+++ b/src/services/llm/pipeline/stages/llm_completion_stage.ts
@@ -18,15 +18,15 @@ export class LLMCompletionStage extends BasePipelineStage [...]
[...]
         const { messages, options, provider } = input;
-
+
         // Create a copy of options to avoid modifying the original
         const updatedOptions = { ...options };
-
+
         // Check if tools should be enabled
         if (updatedOptions.enableTools !== false) {
             // Get all available tools from the registry
             const toolDefinitions = toolRegistry.getAllToolDefinitions();
-
+
             if (toolDefinitions.length > 0) {
                 // Enable tools and add them to the options
                 updatedOptions.enableTools = true;
@@ -35,11 +35,23 @@ export class LLMCompletionStage extends BasePipelineStage [...]
[...]

diff --git a/src/services/llm/pipeline/stages/model_selection_stage.ts b/src/services/llm/pipeline/stages/model_selection_stage.ts
[...]
         if (providers.length > 0) {
             const firstProvider = providers[0];
+            defaultProvider = firstProvider;
 
             // Get provider-specific default model
             if (firstProvider === 'openai') {
                 const model = await options.getOption('openaiDefaultModel');
-                if (model) defaultModel = `openai:${model}`;
+                if (model) defaultModelName = model;
             } else if (firstProvider === 'anthropic') {
                 const model = await options.getOption('anthropicDefaultModel');
-                if (model) defaultModel = `anthropic:${model}`;
+                if (model) defaultModelName = model;
             } else if (firstProvider === 'ollama') {
                 const model = await options.getOption('ollamaDefaultModel');
                 if (model) {
-                    defaultModel = `ollama:${model}`;
+                    defaultModelName = model;
 
                     // Enable tools for all Ollama models
                     // The Ollama API will handle models that don't support tool calling
@@ -130,9 +146,125 @@ export class ModelSelectionStage extends BasePipelineStage [...]
             [...] > 10000 ? 'high' : 'medium';
         }
 
-        updatedOptions.model = defaultModel;
+        // Set the model and add provider metadata
+        updatedOptions.model = defaultModelName;
+        this.addProviderMetadata(updatedOptions, defaultProvider, defaultModelName);
 
-        log.info(`Selected model: ${updatedOptions.model} for query complexity: ${queryComplexity}`);
+        log.info(`Selected model: ${defaultModelName} from provider: ${defaultProvider} for query complexity: ${queryComplexity}`);
 
         return { options: updatedOptions };
     }
+
+    /**
+     * Helper to parse model identifier with provider prefix
+     * Handles legacy format "provider:model"
+     */
+    private parseModelIdentifier(modelId: string): { provider?: string, model: string } {
+        if (!modelId) return { model: '' };
+
+        const parts = modelId.split(':');
+        if (parts.length === 1) {
+            // No provider prefix
+            return { model: modelId };
+        } else {
+            // Extract provider and model
+            const provider = parts[0];
+            const model = parts.slice(1).join(':'); // Handle model names that might include ':'
+            return { provider, model };
+        }
+    }
+
+    /**
+     * Add provider metadata to the options based on model name
+     */
+    private addProviderMetadata(options: ChatCompletionOptions, provider: string, modelName: string): void {
+        // Check if we already have providerMetadata
+        if (options.providerMetadata) {
+            // If providerMetadata exists but not modelId, add the model name
+            if (!options.providerMetadata.modelId && modelName) {
+                options.providerMetadata.modelId = modelName;
+            }
+            return;
+        }
+
+        // If no provider could be determined, try to use precedence
+        let selectedProvider = provider;
+        if (!selectedProvider) {
+            // List of providers in precedence order
+            const providerPrecedence = ['anthropic', 'openai', 'ollama'];
+
+            // Find the first available provider
+            for (const p of providerPrecedence) {
+                if (aiServiceManager.isProviderAvailable(p)) {
+                    selectedProvider = p;
+                    break;
+                }
+            }
+        }
+
+        // Set the provider metadata in the options
+        if (selectedProvider) {
+            // Ensure the provider is one of the valid types
+            const validProvider = selectedProvider as 'openai' | 'anthropic' | 'ollama' | 'local';
+
+            options.providerMetadata = {
+                provider: validProvider,
+                modelId: modelName
+            };
+
+            // For backward compatibility, ensure model name is set without prefix
+            if (options.model && options.model.includes(':')) {
+                options.model = modelName || options.model.split(':')[1];
+            }
+
+            log.info(`Set provider metadata: provider=${selectedProvider}, model=${modelName}`);
+        }
+    }
+
+    /**
+     * Determine model based on provider precedence
+     */
+    private determineDefaultModel(input: ModelSelectionInput): string {
+        const providerPrecedence = ['anthropic', 'openai', 'ollama'];
+
+        // Use only providers that are available
+        const availableProviders = providerPrecedence.filter(provider =>
+            aiServiceManager.isProviderAvailable(provider));
+
+        if (availableProviders.length === 0) {
+            throw new Error('No AI providers are available');
+        }
+
+        // Get the first available provider and its default model
+        const defaultProvider = availableProviders[0] as 'openai' | 'anthropic' | 'ollama' | 'local';
+        let defaultModel = 'gpt-3.5-turbo'; // Default fallback
+
+        // Set provider metadata
+        if (!input.options.providerMetadata) {
+            input.options.providerMetadata = {
+                provider: defaultProvider,
+                modelId: defaultModel
+            };
+        }
+
+        log.info(`Selected default model ${defaultModel} from provider ${defaultProvider}`);
+        return defaultModel;
+    }
+
+    /**
+     * Get estimated context window for Ollama models
+     */
+    private getOllamaContextWindow(model: string): number {
+        // Estimate based on model family
+        if (model.includes('llama3')) {
+            return 8192;
+        } else if (model.includes('llama2')) {
+            return 4096;
+        } else if (model.includes('mistral') || model.includes('mixtral')) {
+            return 8192;
+        } else if (model.includes('gemma')) {
+            return 8192;
+        } else {
+            return 4096; // Default fallback
+        }
+    }
 }
diff --git a/src/services/llm/rest_chat_service.ts b/src/services/llm/rest_chat_service.ts
index 39b962ceb..7a4845e8b 100644
--- a/src/services/llm/rest_chat_service.ts
+++ b/src/services/llm/rest_chat_service.ts
@@ -457,7 +457,8 @@ class RestChatService {
                 systemPrompt: session.messages.find(m => m.role === 'system')?.content,
                 temperature: session.metadata.temperature,
                 maxTokens: session.metadata.maxTokens,
-                model: session.metadata.model
+                model: session.metadata.model,
+                stream: req.method === 'GET' ? true : undefined // Explicitly set stream: true for GET requests
             },
             streamCallback: req.method === 'GET' ? (data, done) => {
                 res.write(`data: ${JSON.stringify({ content: data, done })}\n\n`);
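
Usage note (illustrative, not part of the commit): the change replaces the
"provider:model" string convention with an explicit providerMetadata object on
ChatCompletionOptions. Below is a minimal standalone TypeScript sketch of the
intended round-trip. The two types are mirrored locally from the hunks above
(the real ModelMetadata lives in src/services/llm/providers/provider_options.ts
and may carry additional fields), and parseModelIdentifier repeats the parsing
rule from the patch:

    // Mirrors ModelMetadata from the diff; illustrative only.
    type ModelMetadata = {
        provider: 'openai' | 'anthropic' | 'ollama' | 'local';
        modelId: string;
    };

    // Subset of ChatCompletionOptions relevant to this patch.
    interface ChatCompletionOptions {
        model?: string;
        providerMetadata?: ModelMetadata;
    }

    // Same rule as parseModelIdentifier in model_selection_stage.ts:
    // an optional "provider:" prefix; any later ':' stays in the model name.
    function parseModelIdentifier(modelId: string): { provider?: string; model: string } {
        if (!modelId) return { model: '' };
        const parts = modelId.split(':');
        if (parts.length === 1) return { model: modelId };
        return { provider: parts[0], model: parts.slice(1).join(':') };
    }

    // New-style call: pin the provider explicitly and keep model unprefixed.
    const options: ChatCompletionOptions = {
        model: 'llama3',
        providerMetadata: { provider: 'ollama', modelId: 'llama3' }
    };

    // Legacy identifiers still resolve via the parser:
    console.log(parseModelIdentifier('ollama:llama3'));  // { provider: 'ollama', model: 'llama3' }
    console.log(parseModelIdentifier('gpt-4'));          // { model: 'gpt-4' }
    console.log(parseModelIdentifier('a:b:c'));          // { provider: 'a', model: 'b:c' }

Note that when providerMetadata is already present, addProviderMetadata leaves
it untouched (only filling in modelId if it is missing), so callers that set it
explicitly are not overridden by the precedence fallback.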