dynamically adjust context window sizes based on conversation context

This commit is contained in:
perf3ct 2025-03-30 22:13:40 +00:00
parent 29845c343c
commit f2cb013e14
No known key found for this signature in database
GPG Key ID: 569C4EEC436F5232
10 changed files with 388 additions and 38 deletions

View File

@ -69,7 +69,7 @@ export interface SemanticContextService {
/**
* Retrieve semantic context based on relevance to user query
*/
getSemanticContext(noteId: string, userQuery: string, maxResults?: number): Promise<string>;
getSemanticContext(noteId: string, userQuery: string, maxResults?: number, messages?: Message[]): Promise<string>;
/**
* Get progressive context based on depth

View File

@ -3,6 +3,9 @@ import log from '../../../log.js';
import { CONTEXT_PROMPTS, FORMATTING_PROMPTS } from '../../constants/llm_prompt_constants.js';
import { LLM_CONSTANTS } from '../../constants/provider_constants.js';
import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
import modelCapabilitiesService from '../../model_capabilities_service.js';
import { calculateAvailableContextSize } from '../../interfaces/model_capabilities.js';
import type { Message } from '../../ai_interface.js';
// Use constants from the centralized file
// const CONTEXT_WINDOW = {
@ -20,26 +23,46 @@ import type { IContextFormatter, NoteSearchResult } from '../../interfaces/conte
*/
export class ContextFormatter implements IContextFormatter {
/**
* Build a structured context string from note sources
* Build formatted context from a list of note search results
*
* @param sources Array of note data with content and metadata
* @param query The user's query for context
* @param providerId Optional provider ID to customize formatting
* @param messages Optional conversation messages to adjust context size
* @returns Formatted context string
*/
async buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId: string = 'default'): Promise<string> {
async buildContextFromNotes(
sources: NoteSearchResult[],
query: string,
providerId: string = 'default',
messages: Message[] = []
): Promise<string> {
if (!sources || sources.length === 0) {
log.info('No sources provided to context formatter');
return CONTEXT_PROMPTS.NO_NOTES_CONTEXT;
}
try {
// Get appropriate context size based on provider
const maxTotalLength =
// Get model name from provider
let modelName = providerId;
// Look up model capabilities
const modelCapabilities = await modelCapabilitiesService.getModelCapabilities(modelName);
// Calculate available context size for this conversation
const availableContextSize = calculateAvailableContextSize(
modelCapabilities,
messages,
3 // Expected additional turns
);
// Use the calculated size or fall back to constants
const maxTotalLength = availableContextSize || (
providerId === 'openai' ? LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI :
providerId === 'anthropic' ? LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC :
providerId === 'ollama' ? LLM_CONSTANTS.CONTEXT_WINDOW.OLLAMA :
LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT;
LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT
);
// DEBUG: Log context window size
log.info(`Context window for provider ${providerId}: ${maxTotalLength} chars`);

View File

@ -10,6 +10,7 @@ import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
import becca from '../../../../becca/becca.js';
import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';
import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
import type { Message } from '../../ai_interface.js';
/**
* Main context service that integrates all context-related functionality
@ -635,14 +636,20 @@ export class ContextService {
}
/**
* Get semantic context for a note and query
* Get semantic context based on query
*
* @param noteId - The base note ID
* @param userQuery - The user's query
* @param maxResults - Maximum number of results to include
* @returns Formatted context string
* @param noteId - Note ID to start from
* @param userQuery - User query for context
* @param maxResults - Maximum number of results
* @param messages - Optional conversation messages to adjust context size
* @returns Formatted context
*/
async getSemanticContext(noteId: string, userQuery: string, maxResults: number = 5): Promise<string> {
async getSemanticContext(
noteId: string,
userQuery: string,
maxResults: number = 5,
messages: Message[] = []
): Promise<string> {
if (!this.initialized) {
await this.initialize();
}
@ -712,24 +719,39 @@ export class ContextService {
// Get content for the top N most relevant notes
const mostRelevantNotes = rankedNotes.slice(0, maxResults);
const relevantContent = await Promise.all(
// Get relevant search results to pass to context formatter
const searchResults = await Promise.all(
mostRelevantNotes.map(async note => {
const content = await this.contextExtractor.getNoteContent(note.noteId);
if (!content) return null;
// Format with relevance score and title
return `### ${note.title} (Relevance: ${Math.round(note.relevance * 100)}%)\n\n${content}`;
// Create a properly typed NoteSearchResult object
return {
noteId: note.noteId,
title: note.title,
content,
similarity: note.relevance
};
})
);
// Filter out nulls and empty content
const validResults: NoteSearchResult[] = searchResults
.filter(result => result !== null && result.content && result.content.trim().length > 0)
.map(result => result as NoteSearchResult);
// If no content retrieved, return empty string
if (!relevantContent.filter(Boolean).length) {
if (validResults.length === 0) {
return '';
}
return `# Relevant Context\n\nThe following notes are most relevant to your query:\n\n${
relevantContent.filter(Boolean).join('\n\n---\n\n')
}`;
// Get the provider information for formatting
const provider = await providerManager.getPreferredEmbeddingProvider();
const providerId = provider?.name || 'default';
// Format the context with the context formatter (which handles adjusting for conversation size)
return contextFormatter.buildContextFromNotes(validResults, userQuery, providerId, messages);
} catch (error) {
log.error(`Error getting semantic context: ${error}`);
return '';

View File

@ -154,10 +154,11 @@ class TriliumContextService {
* @param noteId - The note ID
* @param userQuery - The user's query
* @param maxResults - Maximum results to include
* @param messages - Optional conversation messages to adjust context size
* @returns Formatted context string
*/
async getSemanticContext(noteId: string, userQuery: string, maxResults = 5): Promise<string> {
return contextService.getSemanticContext(noteId, userQuery, maxResults);
async getSemanticContext(noteId: string, userQuery: string, maxResults = 5, messages: Message[] = []): Promise<string> {
return contextService.getSemanticContext(noteId, userQuery, maxResults, messages);
}
/**

View File

@ -46,7 +46,12 @@ export interface NoteSearchResult {
* Interface for context formatter
*/
export interface IContextFormatter {
buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId?: string): Promise<string>;
buildContextFromNotes(
sources: NoteSearchResult[],
query: string,
providerId?: string,
messages?: Array<{role: string, content: string}>
): Promise<string>;
}
/**

View File

@ -0,0 +1,138 @@
import type { Message } from '../ai_interface.js';
/**
* Interface for model capabilities information
*/
export interface ModelCapabilities {
contextWindowTokens: number; // Context window size in tokens
contextWindowChars: number; // Estimated context window size in characters (for planning)
maxCompletionTokens: number; // Maximum completion length
hasFunctionCalling: boolean; // Whether the model supports function calling
hasVision: boolean; // Whether the model supports image input
costPerInputToken: number; // Cost per input token (if applicable)
costPerOutputToken: number; // Cost per output token (if applicable)
}
/**
* Default model capabilities for unknown models
*/
export const DEFAULT_MODEL_CAPABILITIES: ModelCapabilities = {
contextWindowTokens: 4096,
contextWindowChars: 16000, // ~4 chars per token estimate
maxCompletionTokens: 1024,
hasFunctionCalling: false,
hasVision: false,
costPerInputToken: 0,
costPerOutputToken: 0
};
/**
* Model capabilities for common models
*/
export const MODEL_CAPABILITIES: Record<string, Partial<ModelCapabilities>> = {
// OpenAI models
'gpt-3.5-turbo': {
contextWindowTokens: 4096,
contextWindowChars: 16000,
hasFunctionCalling: true
},
'gpt-3.5-turbo-16k': {
contextWindowTokens: 16384,
contextWindowChars: 65000,
hasFunctionCalling: true
},
'gpt-4': {
contextWindowTokens: 8192,
contextWindowChars: 32000,
hasFunctionCalling: true
},
'gpt-4-32k': {
contextWindowTokens: 32768,
contextWindowChars: 130000,
hasFunctionCalling: true
},
'gpt-4-turbo': {
contextWindowTokens: 128000,
contextWindowChars: 512000,
hasFunctionCalling: true,
hasVision: true
},
'gpt-4o': {
contextWindowTokens: 128000,
contextWindowChars: 512000,
hasFunctionCalling: true,
hasVision: true
},
// Anthropic models
'claude-3-haiku': {
contextWindowTokens: 200000,
contextWindowChars: 800000,
hasVision: true
},
'claude-3-sonnet': {
contextWindowTokens: 200000,
contextWindowChars: 800000,
hasVision: true
},
'claude-3-opus': {
contextWindowTokens: 200000,
contextWindowChars: 800000,
hasVision: true
},
'claude-2': {
contextWindowTokens: 100000,
contextWindowChars: 400000
},
// Ollama models (defaults, will be updated dynamically)
'llama3': {
contextWindowTokens: 8192,
contextWindowChars: 32000
},
'mistral': {
contextWindowTokens: 8192,
contextWindowChars: 32000
},
'llama2': {
contextWindowTokens: 4096,
contextWindowChars: 16000
}
};
/**
* Calculate available context window size for context generation
* This takes into account expected message sizes and other overhead
*
* @param model Model name
* @param messages Current conversation messages
* @param expectedTurns Number of expected additional conversation turns
* @returns Available context size in characters
*/
export function calculateAvailableContextSize(
modelCapabilities: ModelCapabilities,
messages: Message[],
expectedTurns: number = 3
): number {
// Calculate current message token usage (rough estimate)
let currentMessageChars = 0;
for (const message of messages) {
currentMessageChars += message.content.length;
}
// Reserve space for system prompt and overhead
const systemPromptReserve = 1000;
// Reserve space for expected conversation turns
const turnReserve = expectedTurns * 2000; // Average 2000 chars per turn (including both user and assistant)
// Calculate available space
const totalReserved = currentMessageChars + systemPromptReserve + turnReserve;
const availableContextSize = Math.max(0, modelCapabilities.contextWindowChars - totalReserved);
// Use at most 70% of total context window size to be safe
const maxSafeContextSize = Math.floor(modelCapabilities.contextWindowChars * 0.7);
// Return the smaller of available size or max safe size
return Math.min(availableContextSize, maxSafeContextSize);
}

View File

@ -0,0 +1,159 @@
import log from '../log.js';
import type { ModelCapabilities } from './interfaces/model_capabilities.js';
import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES } from './interfaces/model_capabilities.js';
import aiServiceManager from './ai_service_manager.js';
import { getEmbeddingProvider } from './providers/providers.js';
import type { BaseEmbeddingProvider } from './embeddings/base_embeddings.js';
import type { EmbeddingModelInfo } from './interfaces/embedding_interfaces.js';
// Define a type for embedding providers that might have the getModelInfo method
interface EmbeddingProviderWithModelInfo {
getModelInfo?: (modelName: string) => Promise<EmbeddingModelInfo>;
}
/**
* Service for fetching and caching model capabilities
*/
export class ModelCapabilitiesService {
// Cache model capabilities
private capabilitiesCache: Map<string, ModelCapabilities> = new Map();
constructor() {
// Initialize cache with known models
this.initializeCache();
}
/**
* Initialize the cache with known model capabilities
*/
private initializeCache() {
// Add all predefined model capabilities to cache
for (const [model, capabilities] of Object.entries(MODEL_CAPABILITIES)) {
this.capabilitiesCache.set(model, {
...DEFAULT_MODEL_CAPABILITIES,
...capabilities
});
}
}
/**
* Get model capabilities, fetching from provider if needed
*
* @param modelName Full model name (with or without provider prefix)
* @returns Model capabilities
*/
async getModelCapabilities(modelName: string): Promise<ModelCapabilities> {
// Handle provider-prefixed model names (e.g., "openai:gpt-4")
let provider = 'default';
let baseModelName = modelName;
if (modelName.includes(':')) {
const parts = modelName.split(':');
provider = parts[0];
baseModelName = parts[1];
}
// Check cache first
const cacheKey = baseModelName;
if (this.capabilitiesCache.has(cacheKey)) {
return this.capabilitiesCache.get(cacheKey)!;
}
// Fetch from provider if possible
try {
// Get provider service
const providerService = aiServiceManager.getService(provider);
if (providerService && typeof (providerService as any).getModelCapabilities === 'function') {
// If provider supports direct capability fetching, use it
const capabilities = await (providerService as any).getModelCapabilities(baseModelName);
if (capabilities) {
// Merge with defaults and cache
const fullCapabilities = {
...DEFAULT_MODEL_CAPABILITIES,
...capabilities
};
this.capabilitiesCache.set(cacheKey, fullCapabilities);
log.info(`Fetched capabilities for ${modelName}: context window ${fullCapabilities.contextWindowTokens} tokens`);
return fullCapabilities;
}
}
// Try to fetch from embedding provider if available
const embeddingProvider = getEmbeddingProvider(provider);
if (embeddingProvider) {
try {
// Cast to a type that might have getModelInfo method
const providerWithModelInfo = embeddingProvider as unknown as EmbeddingProviderWithModelInfo;
if (providerWithModelInfo.getModelInfo) {
const modelInfo = await providerWithModelInfo.getModelInfo(baseModelName);
if (modelInfo && modelInfo.contextWidth) {
// Convert to our capabilities format
const capabilities: ModelCapabilities = {
...DEFAULT_MODEL_CAPABILITIES,
contextWindowTokens: modelInfo.contextWidth,
contextWindowChars: modelInfo.contextWidth * 4 // Rough estimate: 4 chars per token
};
this.capabilitiesCache.set(cacheKey, capabilities);
log.info(`Derived capabilities for ${modelName} from embedding provider: context window ${capabilities.contextWindowTokens} tokens`);
return capabilities;
}
}
} catch (error) {
log.info(`Could not get model info from embedding provider for ${modelName}: ${error}`);
}
}
} catch (error) {
log.error(`Error fetching model capabilities for ${modelName}: ${error}`);
}
// If we get here, try to find a similar model in our predefined list
for (const knownModel of Object.keys(MODEL_CAPABILITIES)) {
// Check if the model name contains this known model (e.g., "gpt-4-1106-preview" contains "gpt-4")
if (baseModelName.includes(knownModel)) {
const capabilities = {
...DEFAULT_MODEL_CAPABILITIES,
...MODEL_CAPABILITIES[knownModel]
};
this.capabilitiesCache.set(cacheKey, capabilities);
log.info(`Using similar model (${knownModel}) capabilities for ${modelName}`);
return capabilities;
}
}
// Fall back to defaults if nothing else works
log.info(`Using default capabilities for unknown model ${modelName}`);
this.capabilitiesCache.set(cacheKey, DEFAULT_MODEL_CAPABILITIES);
return DEFAULT_MODEL_CAPABILITIES;
}
/**
* Update model capabilities in the cache
*
* @param modelName Model name
* @param capabilities Capabilities to update
*/
updateModelCapabilities(modelName: string, capabilities: Partial<ModelCapabilities>) {
const currentCapabilities = this.capabilitiesCache.get(modelName) || DEFAULT_MODEL_CAPABILITIES;
this.capabilitiesCache.set(modelName, {
...currentCapabilities,
...capabilities
});
}
}
// Create and export singleton instance
const modelCapabilitiesService = new ModelCapabilitiesService();
export default modelCapabilitiesService;

View File

@ -106,7 +106,8 @@ export class ChatPipeline {
// Get semantic context for regular queries
const semanticContext = await this.stages.semanticContextExtraction.execute({
noteId: input.noteId,
query: input.query
query: input.query,
messages: input.messages
});
context = semanticContext.context;
this.updateStageMetrics('semanticContextExtraction', contextStartTime);
@ -136,10 +137,10 @@ export class ChatPipeline {
const llmStartTime = Date.now();
// Setup streaming handler if streaming is enabled and callback provided
const enableStreaming = this.config.enableStreaming &&
const enableStreaming = this.config.enableStreaming &&
modelSelection.options.stream !== false &&
typeof streamCallback === 'function';
if (enableStreaming) {
// Make sure stream is enabled in options
modelSelection.options.stream = true;
@ -157,10 +158,10 @@ export class ChatPipeline {
await completion.response.stream(async (chunk: StreamChunk) => {
// Process the chunk text
const processedChunk = await this.processStreamChunk(chunk, input.options);
// Accumulate text for final response
accumulatedText += processedChunk.text;
// Forward to callback
await streamCallback!(processedChunk.text, processedChunk.done);
});
@ -182,12 +183,12 @@ export class ChatPipeline {
const endTime = Date.now();
const executionTime = endTime - startTime;
// Update overall average execution time
this.metrics.averageExecutionTime =
this.metrics.averageExecutionTime =
(this.metrics.averageExecutionTime * (this.metrics.totalExecutions - 1) + executionTime) /
this.metrics.totalExecutions;
log.info(`Chat pipeline completed in ${executionTime}ms`);
return finalResponse;
@ -235,12 +236,12 @@ export class ChatPipeline {
*/
private updateStageMetrics(stageName: string, startTime: number) {
if (!this.config.enableMetrics) return;
const executionTime = Date.now() - startTime;
const metrics = this.metrics.stageMetrics[stageName];
metrics.totalExecutions++;
metrics.averageExecutionTime =
metrics.averageExecutionTime =
(metrics.averageExecutionTime * (metrics.totalExecutions - 1) + executionTime) /
metrics.totalExecutions;
}
@ -258,7 +259,7 @@ export class ChatPipeline {
resetMetrics(): void {
this.metrics.totalExecutions = 0;
this.metrics.averageExecutionTime = 0;
Object.keys(this.metrics.stageMetrics).forEach(stageName => {
this.metrics.stageMetrics[stageName] = {
totalExecutions: 0,

View File

@ -15,12 +15,12 @@ export interface ChatPipelineConfig {
* Whether to enable streaming support
*/
enableStreaming: boolean;
/**
* Whether to enable performance metrics
*/
enableMetrics: boolean;
/**
* Maximum number of tool call iterations
*/
@ -84,6 +84,7 @@ export interface SemanticContextExtractionInput extends PipelineInput {
noteId: string;
query: string;
maxResults?: number;
messages?: Message[];
}
/**

View File

@ -15,11 +15,11 @@ export class SemanticContextExtractionStage extends BasePipelineStage<SemanticCo
* Extract semantic context based on a query
*/
protected async process(input: SemanticContextExtractionInput): Promise<{ context: string }> {
const { noteId, query, maxResults = 5 } = input;
const { noteId, query, maxResults = 5, messages = [] } = input;
log.info(`Extracting semantic context from note ${noteId}, query: ${query?.substring(0, 50)}...`);
const contextService = aiServiceManager.getContextService();
const context = await contextService.getSemanticContext(noteId, query, maxResults);
const context = await contextService.getSemanticContext(noteId, query, maxResults, messages);
return { context };
}