diff --git a/src/services/llm/ai_interface.ts b/src/services/llm/ai_interface.ts
index 310c73847..c539f33a8 100644
--- a/src/services/llm/ai_interface.ts
+++ b/src/services/llm/ai_interface.ts
@@ -69,7 +69,7 @@ export interface SemanticContextService {
     /**
      * Retrieve semantic context based on relevance to user query
      */
-    getSemanticContext(noteId: string, userQuery: string, maxResults?: number): Promise<string>;
+    getSemanticContext(noteId: string, userQuery: string, maxResults?: number, messages?: Message[]): Promise<string>;
 
     /**
      * Get progressive context based on depth
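For orientation, a minimal sketch of how a caller might use the widened interface. The helper and its argument names are hypothetical; only the `getSemanticContext` signature comes from the diff.

```ts
import type { Message, SemanticContextService } from './ai_interface.js';

// Hypothetical helper; only the widened getSemanticContext signature is taken from the diff.
async function buildContext(
    service: SemanticContextService,
    noteId: string,
    userQuery: string,
    messages: Message[]
): Promise<string> {
    // Existing call sites that omit `messages` keep compiling; passing the
    // conversation lets the service size the injected context to what is left.
    return service.getSemanticContext(noteId, userQuery, 5, messages);
}
```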
diff --git a/src/services/llm/context/modules/context_formatter.ts b/src/services/llm/context/modules/context_formatter.ts
index f5d5ef47f..8a9819706 100644
--- a/src/services/llm/context/modules/context_formatter.ts
+++ b/src/services/llm/context/modules/context_formatter.ts
@@ -3,6 +3,9 @@ import log from '../../../log.js';
 import { CONTEXT_PROMPTS, FORMATTING_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import { LLM_CONSTANTS } from '../../constants/provider_constants.js';
 import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
+import modelCapabilitiesService from '../../model_capabilities_service.js';
+import { calculateAvailableContextSize } from '../../interfaces/model_capabilities.js';
+import type { Message } from '../../ai_interface.js';
 
 // Use constants from the centralized file
 // const CONTEXT_WINDOW = {
@@ -20,26 +23,46 @@ import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
  */
 export class ContextFormatter implements IContextFormatter {
     /**
-     * Build a structured context string from note sources
+     * Build formatted context from a list of note search results
      *
      * @param sources Array of note data with content and metadata
      * @param query The user's query for context
      * @param providerId Optional provider ID to customize formatting
+     * @param messages Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId: string = 'default'): Promise<string> {
+    async buildContextFromNotes(
+        sources: NoteSearchResult[],
+        query: string,
+        providerId: string = 'default',
+        messages: Message[] = []
+    ): Promise<string> {
         if (!sources || sources.length === 0) {
             log.info('No sources provided to context formatter');
             return CONTEXT_PROMPTS.NO_NOTES_CONTEXT;
         }
 
         try {
-            // Get appropriate context size based on provider
-            const maxTotalLength =
+            // Use the provider ID as a stand-in for the model name when looking up capabilities
+            let modelName = providerId;
+
+            // Look up model capabilities
+            const modelCapabilities = await modelCapabilitiesService.getModelCapabilities(modelName);
+
+            // Calculate available context size for this conversation
+            const availableContextSize = calculateAvailableContextSize(
+                modelCapabilities,
+                messages,
+                3 // Expected additional turns
+            );
+
+            // Use the calculated size or fall back to constants
+            const maxTotalLength = availableContextSize || (
                 providerId === 'openai' ? LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI :
                 providerId === 'anthropic' ? LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC :
                 providerId === 'ollama' ? LLM_CONSTANTS.CONTEXT_WINDOW.OLLAMA :
-                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT;
+                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT
+            );
 
             // DEBUG: Log context window size
             log.info(`Context window for provider ${providerId}: ${maxTotalLength} chars`);
diff --git a/src/services/llm/context/modules/context_service.ts b/src/services/llm/context/modules/context_service.ts
index c831af597..ac290c7ec 100644
--- a/src/services/llm/context/modules/context_service.ts
+++ b/src/services/llm/context/modules/context_service.ts
@@ -10,6 +10,7 @@ import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import becca from '../../../../becca/becca.js';
 import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';
 import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
+import type { Message } from '../../ai_interface.js';
 
 /**
  * Main context service that integrates all context-related functionality
@@ -635,14 +636,20 @@ export class ContextService {
     }
 
     /**
-     * Get semantic context for a note and query
+     * Get semantic context based on query
      *
-     * @param noteId - The base note ID
-     * @param userQuery - The user's query
-     * @param maxResults - Maximum number of results to include
-     * @returns Formatted context string
+     * @param noteId - Note ID to start from
+     * @param userQuery - User query for context
+     * @param maxResults - Maximum number of results
+     * @param messages - Optional conversation messages to adjust context size
+     * @returns Formatted context
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults: number = 5): Promise<string> {
+    async getSemanticContext(
+        noteId: string,
+        userQuery: string,
+        maxResults: number = 5,
+        messages: Message[] = []
+    ): Promise<string> {
         if (!this.initialized) {
             await this.initialize();
         }
@@ -712,24 +719,39 @@ export class ContextService {
             // Get content for the top N most relevant notes
             const mostRelevantNotes = rankedNotes.slice(0, maxResults);
-            const relevantContent = await Promise.all(
+
+            // Get relevant search results to pass to context formatter
+            const searchResults = await Promise.all(
                 mostRelevantNotes.map(async note => {
                     const content = await this.contextExtractor.getNoteContent(note.noteId);
                     if (!content) return null;
 
-                    // Format with relevance score and title
-                    return `### ${note.title} (Relevance: ${Math.round(note.relevance * 100)}%)\n\n${content}`;
+                    // Create a properly typed NoteSearchResult object
+                    return {
+                        noteId: note.noteId,
+                        title: note.title,
+                        content,
+                        similarity: note.relevance
+                    };
                 })
             );
 
+            // Filter out nulls and empty content
+            const validResults: NoteSearchResult[] = searchResults
+                .filter(result => result !== null && result.content && result.content.trim().length > 0)
+                .map(result => result as NoteSearchResult);
+
             // If no content retrieved, return empty string
-            if (!relevantContent.filter(Boolean).length) {
+            if (validResults.length === 0) {
                 return '';
             }
 
-            return `# Relevant Context\n\nThe following notes are most relevant to your query:\n\n${
-                relevantContent.filter(Boolean).join('\n\n---\n\n')
-            }`;
+            // Get the provider information for formatting
+            const provider = await providerManager.getPreferredEmbeddingProvider();
+            const providerId = provider?.name || 'default';
+
+            // Format the context with the context formatter (which handles adjusting for conversation size)
+            return contextFormatter.buildContextFromNotes(validResults, userQuery, providerId, messages);
         } catch (error) {
             log.error(`Error getting semantic context: ${error}`);
             return '';
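The budget selection in `buildContextFromNotes` reduces to the following standalone sketch; it is not the actual module, and `windows` stands in for `LLM_CONSTANTS.CONTEXT_WINDOW`.

```ts
// Standalone sketch of the budget selection above; `availableContextSize` is
// the value computed by calculateAvailableContextSize.
function pickContextBudget(
    availableContextSize: number,
    providerId: string,
    windows: { OPENAI: number; ANTHROPIC: number; OLLAMA: number; DEFAULT: number }
): number {
    // A zero from calculateAvailableContextSize falls through `||` to the static window.
    return availableContextSize || (
        providerId === 'openai' ? windows.OPENAI :
        providerId === 'anthropic' ? windows.ANTHROPIC :
        providerId === 'ollama' ? windows.OLLAMA :
        windows.DEFAULT
    );
}
```

Note that when the conversation already consumes the whole window, `calculateAvailableContextSize` returns 0, so the `||` reinstates the full static constant rather than a smaller budget.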
diff --git a/src/services/llm/context_service.ts b/src/services/llm/context_service.ts
index c8a619421..074eb296b 100644
--- a/src/services/llm/context_service.ts
+++ b/src/services/llm/context_service.ts
@@ -154,10 +154,11 @@ class TriliumContextService {
      * @param noteId - The note ID
      * @param userQuery - The user's query
      * @param maxResults - Maximum results to include
+     * @param messages - Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5): Promise<string> {
-        return contextService.getSemanticContext(noteId, userQuery, maxResults);
+    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5, messages: Message[] = []): Promise<string> {
+        return contextService.getSemanticContext(noteId, userQuery, maxResults, messages);
     }
 
     /**
diff --git a/src/services/llm/interfaces/context_interfaces.ts b/src/services/llm/interfaces/context_interfaces.ts
index b82b6a77b..589d01503 100644
--- a/src/services/llm/interfaces/context_interfaces.ts
+++ b/src/services/llm/interfaces/context_interfaces.ts
@@ -46,7 +46,12 @@ export interface NoteSearchResult {
  * Interface for context formatter
  */
 export interface IContextFormatter {
-    buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId?: string): Promise<string>;
+    buildContextFromNotes(
+        sources: NoteSearchResult[],
+        query: string,
+        providerId?: string,
+        messages?: Array<{role: string, content: string}>
+    ): Promise<string>;
 }
 
 /**
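`IContextFormatter` accepts a structural `{role, content}` array while the service layer passes `Message[]`; since `Message` exposes string `role` and `content` fields, the two signatures stay compatible. A small illustrative check, with import paths assumed:

```ts
import type { Message } from './ai_interface.js';
import type { IContextFormatter, NoteSearchResult } from './interfaces/context_interfaces.js';

// Message[] is structurally assignable to Array<{role: string, content: string}>,
// so the conversation can flow straight through to the formatter.
async function formatWithHistory(
    formatter: IContextFormatter,
    sources: NoteSearchResult[],
    query: string,
    messages: Message[]
): Promise<string> {
    return formatter.buildContextFromNotes(sources, query, 'default', messages);
}
```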
diff --git a/src/services/llm/interfaces/model_capabilities.ts b/src/services/llm/interfaces/model_capabilities.ts
new file mode 100644
index 000000000..75dc4251b
--- /dev/null
+++ b/src/services/llm/interfaces/model_capabilities.ts
@@ -0,0 +1,138 @@
+import type { Message } from '../ai_interface.js';
+
+/**
+ * Interface for model capabilities information
+ */
+export interface ModelCapabilities {
+    contextWindowTokens: number;  // Context window size in tokens
+    contextWindowChars: number;   // Estimated context window size in characters (for planning)
+    maxCompletionTokens: number;  // Maximum completion length
+    hasFunctionCalling: boolean;  // Whether the model supports function calling
+    hasVision: boolean;           // Whether the model supports image input
+    costPerInputToken: number;    // Cost per input token (if applicable)
+    costPerOutputToken: number;   // Cost per output token (if applicable)
+}
+
+/**
+ * Default model capabilities for unknown models
+ */
+export const DEFAULT_MODEL_CAPABILITIES: ModelCapabilities = {
+    contextWindowTokens: 4096,
+    contextWindowChars: 16000, // ~4 chars per token estimate
+    maxCompletionTokens: 1024,
+    hasFunctionCalling: false,
+    hasVision: false,
+    costPerInputToken: 0,
+    costPerOutputToken: 0
+};
+
+/**
+ * Model capabilities for common models
+ */
+export const MODEL_CAPABILITIES: Record<string, Partial<ModelCapabilities>> = {
+    // OpenAI models
+    'gpt-3.5-turbo': {
+        contextWindowTokens: 4096,
+        contextWindowChars: 16000,
+        hasFunctionCalling: true
+    },
+    'gpt-3.5-turbo-16k': {
+        contextWindowTokens: 16384,
+        contextWindowChars: 65000,
+        hasFunctionCalling: true
+    },
+    'gpt-4': {
+        contextWindowTokens: 8192,
+        contextWindowChars: 32000,
+        hasFunctionCalling: true
+    },
+    'gpt-4-32k': {
+        contextWindowTokens: 32768,
+        contextWindowChars: 130000,
+        hasFunctionCalling: true
+    },
+    'gpt-4-turbo': {
+        contextWindowTokens: 128000,
+        contextWindowChars: 512000,
+        hasFunctionCalling: true,
+        hasVision: true
+    },
+    'gpt-4o': {
+        contextWindowTokens: 128000,
+        contextWindowChars: 512000,
+        hasFunctionCalling: true,
+        hasVision: true
+    },
+
+    // Anthropic models
+    'claude-3-haiku': {
+        contextWindowTokens: 200000,
+        contextWindowChars: 800000,
+        hasVision: true
+    },
+    'claude-3-sonnet': {
+        contextWindowTokens: 200000,
+        contextWindowChars: 800000,
+        hasVision: true
+    },
+    'claude-3-opus': {
+        contextWindowTokens: 200000,
+        contextWindowChars: 800000,
+        hasVision: true
+    },
+    'claude-2': {
+        contextWindowTokens: 100000,
+        contextWindowChars: 400000
+    },
+
+    // Ollama models (defaults, will be updated dynamically)
+    'llama3': {
+        contextWindowTokens: 8192,
+        contextWindowChars: 32000
+    },
+    'mistral': {
+        contextWindowTokens: 8192,
+        contextWindowChars: 32000
+    },
+    'llama2': {
+        contextWindowTokens: 4096,
+        contextWindowChars: 16000
+    }
+};
+
+/**
+ * Calculate available context window size for context generation
+ * This takes into account expected message sizes and other overhead
+ *
+ * @param modelCapabilities Capabilities of the target model
+ * @param messages Current conversation messages
+ * @param expectedTurns Number of expected additional conversation turns
+ * @returns Available context size in characters
+ */
+export function calculateAvailableContextSize(
+    modelCapabilities: ModelCapabilities,
+    messages: Message[],
+    expectedTurns: number = 3
+): number {
+    // Calculate current message token usage (rough estimate)
+    let currentMessageChars = 0;
+    for (const message of messages) {
+        currentMessageChars += message.content.length;
+    }
+
+    // Reserve space for system prompt and overhead
+    const systemPromptReserve = 1000;
+
+    // Reserve space for expected conversation turns
+    const turnReserve = expectedTurns * 2000; // Average 2000 chars per turn (including both user and assistant)
+
+    // Calculate available space
+    const totalReserved = currentMessageChars + systemPromptReserve + turnReserve;
+    const availableContextSize = Math.max(0, modelCapabilities.contextWindowChars - totalReserved);
+
+    // Use at most 70% of total context window size to be safe
+    const maxSafeContextSize = Math.floor(modelCapabilities.contextWindowChars * 0.7);
+
+    // Return the smaller of available size or max safe size
+    return Math.min(availableContextSize, maxSafeContextSize);
+}
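A worked example of the sizing arithmetic under the default capabilities; the `Message` literals assume the simple `{ role, content }` shape used elsewhere in this diff, and the import paths are assumptions.

```ts
import type { Message } from './ai_interface.js';
import {
    DEFAULT_MODEL_CAPABILITIES,
    calculateAvailableContextSize
} from './interfaces/model_capabilities.js';

// With the 16000-char default window, ~2500 chars of history and 3 expected turns:
//   reserved  = 2500 (messages) + 1000 (system prompt) + 3 * 2000 (turns) = 9500
//   available = max(0, 16000 - 9500)                                      = 6500
//   safe cap  = floor(16000 * 0.7)                                        = 11200
//   budget    = min(6500, 11200)                                          = 6500
const history: Message[] = [
    { role: 'user', content: 'x'.repeat(1200) },
    { role: 'assistant', content: 'y'.repeat(1300) }
];

const budget = calculateAvailableContextSize(DEFAULT_MODEL_CAPABILITIES, history, 3);
console.log(budget); // 6500
```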
diff --git a/src/services/llm/model_capabilities_service.ts b/src/services/llm/model_capabilities_service.ts
new file mode 100644
index 000000000..c327ebc9d
--- /dev/null
+++ b/src/services/llm/model_capabilities_service.ts
@@ -0,0 +1,159 @@
+import log from '../log.js';
+import type { ModelCapabilities } from './interfaces/model_capabilities.js';
+import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES } from './interfaces/model_capabilities.js';
+import aiServiceManager from './ai_service_manager.js';
+import { getEmbeddingProvider } from './providers/providers.js';
+import type { BaseEmbeddingProvider } from './embeddings/base_embeddings.js';
+import type { EmbeddingModelInfo } from './interfaces/embedding_interfaces.js';
+
+// Define a type for embedding providers that might have the getModelInfo method
+interface EmbeddingProviderWithModelInfo {
+    getModelInfo?: (modelName: string) => Promise<EmbeddingModelInfo | null>;
+}
+
+/**
+ * Service for fetching and caching model capabilities
+ */
+export class ModelCapabilitiesService {
+    // Cache model capabilities
+    private capabilitiesCache: Map<string, ModelCapabilities> = new Map();
+
+    constructor() {
+        // Initialize cache with known models
+        this.initializeCache();
+    }
+
+    /**
+     * Initialize the cache with known model capabilities
+     */
+    private initializeCache() {
+        // Add all predefined model capabilities to cache
+        for (const [model, capabilities] of Object.entries(MODEL_CAPABILITIES)) {
+            this.capabilitiesCache.set(model, {
+                ...DEFAULT_MODEL_CAPABILITIES,
+                ...capabilities
+            });
+        }
+    }
+
+    /**
+     * Get model capabilities, fetching from provider if needed
+     *
+     * @param modelName Full model name (with or without provider prefix)
+     * @returns Model capabilities
+     */
+    async getModelCapabilities(modelName: string): Promise<ModelCapabilities> {
+        // Handle provider-prefixed model names (e.g., "openai:gpt-4")
+        let provider = 'default';
+        let baseModelName = modelName;
+
+        if (modelName.includes(':')) {
+            const parts = modelName.split(':');
+            provider = parts[0];
+            baseModelName = parts[1];
+        }
+
+        // Check cache first
+        const cacheKey = baseModelName;
+        if (this.capabilitiesCache.has(cacheKey)) {
+            return this.capabilitiesCache.get(cacheKey)!;
+        }
+
+        // Fetch from provider if possible
+        try {
+            // Get provider service
+            const providerService = aiServiceManager.getService(provider);
+
+            if (providerService && typeof (providerService as any).getModelCapabilities === 'function') {
+                // If provider supports direct capability fetching, use it
+                const capabilities = await (providerService as any).getModelCapabilities(baseModelName);
+
+                if (capabilities) {
+                    // Merge with defaults and cache
+                    const fullCapabilities = {
+                        ...DEFAULT_MODEL_CAPABILITIES,
+                        ...capabilities
+                    };
+
+                    this.capabilitiesCache.set(cacheKey, fullCapabilities);
+                    log.info(`Fetched capabilities for ${modelName}: context window ${fullCapabilities.contextWindowTokens} tokens`);
+
+                    return fullCapabilities;
+                }
+            }
+
+            // Try to fetch from embedding provider if available
+            const embeddingProvider = getEmbeddingProvider(provider);
+
+            if (embeddingProvider) {
+                try {
+                    // Cast to a type that might have getModelInfo method
+                    const providerWithModelInfo = embeddingProvider as unknown as EmbeddingProviderWithModelInfo;
+
+                    if (providerWithModelInfo.getModelInfo) {
+                        const modelInfo = await providerWithModelInfo.getModelInfo(baseModelName);
+
+                        if (modelInfo && modelInfo.contextWidth) {
+                            // Convert to our capabilities format
+                            const capabilities: ModelCapabilities = {
+                                ...DEFAULT_MODEL_CAPABILITIES,
+                                contextWindowTokens: modelInfo.contextWidth,
+                                contextWindowChars: modelInfo.contextWidth * 4 // Rough estimate: 4 chars per token
+                            };
+
+                            this.capabilitiesCache.set(cacheKey, capabilities);
+                            log.info(`Derived capabilities for ${modelName} from embedding provider: context window ${capabilities.contextWindowTokens} tokens`);
+
+                            return capabilities;
+                        }
+                    }
+                } catch (error) {
+                    log.info(`Could not get model info from embedding provider for ${modelName}: ${error}`);
+                }
+            }
+        } catch (error) {
+            log.error(`Error fetching model capabilities for ${modelName}: ${error}`);
+        }
+
+        // If we get here, try to find a similar model in our predefined list
+        for (const knownModel of Object.keys(MODEL_CAPABILITIES)) {
+            // Check if the model name contains this known model (e.g., "gpt-4-1106-preview" contains "gpt-4")
+            if (baseModelName.includes(knownModel)) {
+                const capabilities = {
+                    ...DEFAULT_MODEL_CAPABILITIES,
+                    ...MODEL_CAPABILITIES[knownModel]
+                };
+
+                this.capabilitiesCache.set(cacheKey, capabilities);
+                log.info(`Using similar model (${knownModel}) capabilities for ${modelName}`);
+
+                return capabilities;
+            }
+        }
+
+        // Fall back to defaults if nothing else works
+        log.info(`Using default capabilities for unknown model ${modelName}`);
+        this.capabilitiesCache.set(cacheKey, DEFAULT_MODEL_CAPABILITIES);
+
+        return DEFAULT_MODEL_CAPABILITIES;
+    }
+
+    /**
+     * Update model capabilities in the cache
+     *
+     * @param modelName Model name
+     * @param capabilities Capabilities to update
+     */
+    updateModelCapabilities(modelName: string, capabilities: Partial<ModelCapabilities>) {
+        const currentCapabilities = this.capabilitiesCache.get(modelName) || DEFAULT_MODEL_CAPABILITIES;
+
+        this.capabilitiesCache.set(modelName, {
+            ...currentCapabilities,
+            ...capabilities
+        });
+    }
+}
+
+// Create and export singleton instance
+const modelCapabilitiesService = new ModelCapabilitiesService();
+export default modelCapabilitiesService;
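Illustrative usage of the new service; the import path, the surrounding function, and the Ollama numbers pushed into the cache are assumptions, not part of the diff.

```ts
import modelCapabilitiesService from './model_capabilities_service.js';

async function demo() {
    // "openai:gpt-4" is split on ":"; "gpt-4" was seeded into the cache from
    // MODEL_CAPABILITIES at construction, so this returns without a provider call.
    const gpt4 = await modelCapabilitiesService.getModelCapabilities('openai:gpt-4');
    console.log(gpt4.contextWindowTokens); // 8192, merged over DEFAULT_MODEL_CAPABILITIES

    // A dated variant is not an exact cache key; if neither the chat service nor the
    // embedding provider reports capabilities, the substring fallback matches "gpt-4"
    // and reuses its entry instead of the generic defaults.
    const preview = await modelCapabilitiesService.getModelCapabilities('openai:gpt-4-1106-preview');

    // Provider-reported numbers can also be pushed into the cache later (values here are made up).
    modelCapabilitiesService.updateModelCapabilities('llama3', {
        contextWindowTokens: 32768,
        contextWindowChars: 131072
    });

    return preview;
}
```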
diff --git a/src/services/llm/pipeline/chat_pipeline.ts b/src/services/llm/pipeline/chat_pipeline.ts
index f7e7c7432..312ba3711 100644
--- a/src/services/llm/pipeline/chat_pipeline.ts
+++ b/src/services/llm/pipeline/chat_pipeline.ts
@@ -106,7 +106,8 @@ export class ChatPipeline {
                 // Get semantic context for regular queries
                 const semanticContext = await this.stages.semanticContextExtraction.execute({
                     noteId: input.noteId,
-                    query: input.query
+                    query: input.query,
+                    messages: input.messages
                 });
                 context = semanticContext.context;
                 this.updateStageMetrics('semanticContextExtraction', contextStartTime);
@@ -136,10 +137,10 @@ export class ChatPipeline {
         const llmStartTime = Date.now();
 
         // Setup streaming handler if streaming is enabled and callback provided
-        const enableStreaming = this.config.enableStreaming && 
+        const enableStreaming = this.config.enableStreaming &&
             modelSelection.options.stream !== false &&
             typeof streamCallback === 'function';
-        
+
         if (enableStreaming) {
             // Make sure stream is enabled in options
             modelSelection.options.stream = true;
@@ -157,10 +158,10 @@ export class ChatPipeline {
             await completion.response.stream(async (chunk: StreamChunk) => {
                 // Process the chunk text
                 const processedChunk = await this.processStreamChunk(chunk, input.options);
-                
+
                 // Accumulate text for final response
                 accumulatedText += processedChunk.text;
-                
+
                 // Forward to callback
                 await streamCallback!(processedChunk.text, processedChunk.done);
             });
@@ -182,12 +183,12 @@ export class ChatPipeline {
 
         const endTime = Date.now();
         const executionTime = endTime - startTime;
-        
+
         // Update overall average execution time
-        this.metrics.averageExecutionTime = 
+        this.metrics.averageExecutionTime =
             (this.metrics.averageExecutionTime * (this.metrics.totalExecutions - 1) + executionTime) /
             this.metrics.totalExecutions;
-        
+
         log.info(`Chat pipeline completed in ${executionTime}ms`);
 
         return finalResponse;
@@ -235,12 +236,12 @@ export class ChatPipeline {
      */
     private updateStageMetrics(stageName: string, startTime: number) {
         if (!this.config.enableMetrics) return;
-        
+
         const executionTime = Date.now() - startTime;
         const metrics = this.metrics.stageMetrics[stageName];
-        
+
         metrics.totalExecutions++;
-        metrics.averageExecutionTime = 
+        metrics.averageExecutionTime =
             (metrics.averageExecutionTime * (metrics.totalExecutions - 1) + executionTime) /
             metrics.totalExecutions;
     }
@@ -258,7 +259,7 @@ export class ChatPipeline {
     resetMetrics(): void {
         this.metrics.totalExecutions = 0;
         this.metrics.averageExecutionTime = 0;
-        
+
         Object.keys(this.metrics.stageMetrics).forEach(stageName => {
             this.metrics.stageMetrics[stageName] = {
                 totalExecutions: 0,
diff --git a/src/services/llm/pipeline/interfaces.ts b/src/services/llm/pipeline/interfaces.ts
index 097883e55..0d85c3939 100644
--- a/src/services/llm/pipeline/interfaces.ts
+++ b/src/services/llm/pipeline/interfaces.ts
@@ -15,12 +15,12 @@ export interface ChatPipelineConfig {
     /**
      * Whether to enable streaming support
      */
    enableStreaming: boolean;
-    
+
     /**
      * Whether to enable performance metrics
      */
    enableMetrics: boolean;
-    
+
     /**
      * Maximum number of tool call iterations
      */
@@ -84,6 +84,7 @@ export interface SemanticContextExtractionInput extends PipelineInput {
     noteId: string;
     query: string;
     maxResults?: number;
+    messages?: Message[];
 }
 
 /**
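A hypothetical stage input under the widened `SemanticContextExtractionInput`, assuming `PipelineInput` adds no further required fields; the note ID, query, and import paths are made up.

```ts
import type { Message } from './ai_interface.js';
import type { SemanticContextExtractionInput } from './pipeline/interfaces.js';

// `messages` stays optional, so existing callers that only pass
// noteId/query/maxResults still satisfy the interface.
const history: Message[] = [
    { role: 'user', content: 'How do I rebuild the search index?' }
];

const input: SemanticContextExtractionInput = {
    noteId: 'abc123',
    query: 'rebuild search index',
    maxResults: 5,
    messages: history
};
```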
diff --git a/src/services/llm/pipeline/stages/semantic_context_extraction_stage.ts b/src/services/llm/pipeline/stages/semantic_context_extraction_stage.ts
index d1317a960..d74466ddc 100644
--- a/src/services/llm/pipeline/stages/semantic_context_extraction_stage.ts
+++ b/src/services/llm/pipeline/stages/semantic_context_extraction_stage.ts
@@ -15,11 +15,11 @@ export class SemanticContextExtractionStage extends BasePipelineStage<SemanticContextExtractionInput, { context: string }> {
-        const { noteId, query, maxResults = 5 } = input;
+        const { noteId, query, maxResults = 5, messages = [] } = input;
         log.info(`Extracting semantic context from note ${noteId}, query: ${query?.substring(0, 50)}...`);
 
         const contextService = aiServiceManager.getContextService();
-        const context = await contextService.getSemanticContext(noteId, query, maxResults);
+        const context = await contextService.getSemanticContext(noteId, query, maxResults, messages);
 
         return { context };
     }