Mirror of https://github.com/TriliumNext/Notes.git (synced 2025-07-29)

Commit f2cb013e14 (parent 29845c343c): dynamically adjust context window sizes based on conversation context
@@ -69,7 +69,7 @@ export interface SemanticContextService {
     /**
      * Retrieve semantic context based on relevance to user query
      */
-    getSemanticContext(noteId: string, userQuery: string, maxResults?: number): Promise<string>;
+    getSemanticContext(noteId: string, userQuery: string, maxResults?: number, messages?: Message[]): Promise<string>;

     /**
      * Get progressive context based on depth
@@ -3,6 +3,9 @@ import log from '../../../log.js';
 import { CONTEXT_PROMPTS, FORMATTING_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import { LLM_CONSTANTS } from '../../constants/provider_constants.js';
 import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
+import modelCapabilitiesService from '../../model_capabilities_service.js';
+import { calculateAvailableContextSize } from '../../interfaces/model_capabilities.js';
+import type { Message } from '../../ai_interface.js';

 // Use constants from the centralized file
 // const CONTEXT_WINDOW = {
@@ -20,26 +23,46 @@ import type { IContextFormatter, NoteSearchResult } from '../../interfaces/context_interfaces.js';
  */
 export class ContextFormatter implements IContextFormatter {
     /**
-     * Build a structured context string from note sources
+     * Build formatted context from a list of note search results
      *
      * @param sources Array of note data with content and metadata
      * @param query The user's query for context
      * @param providerId Optional provider ID to customize formatting
+     * @param messages Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId: string = 'default'): Promise<string> {
+    async buildContextFromNotes(
+        sources: NoteSearchResult[],
+        query: string,
+        providerId: string = 'default',
+        messages: Message[] = []
+    ): Promise<string> {
         if (!sources || sources.length === 0) {
             log.info('No sources provided to context formatter');
             return CONTEXT_PROMPTS.NO_NOTES_CONTEXT;
         }

         try {
-            // Get appropriate context size based on provider
-            const maxTotalLength =
+            // Get model name from provider
+            let modelName = providerId;
+
+            // Look up model capabilities
+            const modelCapabilities = await modelCapabilitiesService.getModelCapabilities(modelName);
+
+            // Calculate available context size for this conversation
+            const availableContextSize = calculateAvailableContextSize(
+                modelCapabilities,
+                messages,
+                3 // Expected additional turns
+            );
+
+            // Use the calculated size or fall back to constants
+            const maxTotalLength = availableContextSize || (
                 providerId === 'openai' ? LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI :
                 providerId === 'anthropic' ? LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC :
                 providerId === 'ollama' ? LLM_CONSTANTS.CONTEXT_WINDOW.OLLAMA :
-                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT;
+                LLM_CONSTANTS.CONTEXT_WINDOW.DEFAULT
+            );

             // DEBUG: Log context window size
             log.info(`Context window for provider ${providerId}: ${maxTotalLength} chars`);
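Not part of the diff: a minimal caller-side sketch of the new buildContextFromNotes signature. The singleton-style contextFormatter import path and the exact NoteSearchResult/Message field sets are assumptions inferred from the other hunks in this commit.

import contextFormatter from './context_formatter.js';
import type { Message } from '../../ai_interface.js';
import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';

// Field names follow the mapping used in ContextService.getSemanticContext further down.
const sources = [
    { noteId: 'abc123', title: 'Meeting notes', content: 'Discussed the Q3 roadmap...', similarity: 0.92 }
] as NoteSearchResult[];

const conversation = [
    { role: 'user', content: 'Summarize my meeting notes' }
] as Message[];

// A long conversation shrinks maxTotalLength via calculateAvailableContextSize();
// an empty one falls back to the per-provider CONTEXT_WINDOW constants.
const context = await contextFormatter.buildContextFromNotes(sources, 'meeting summary', 'openai', conversation);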
@@ -10,6 +10,7 @@ import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import becca from '../../../../becca/becca.js';
 import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';
 import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
+import type { Message } from '../../ai_interface.js';

 /**
  * Main context service that integrates all context-related functionality
@@ -635,14 +636,20 @@ export class ContextService {
     }

     /**
-     * Get semantic context for a note and query
+     * Get semantic context based on query
      *
-     * @param noteId - The base note ID
-     * @param userQuery - The user's query
-     * @param maxResults - Maximum number of results to include
-     * @returns Formatted context string
+     * @param noteId - Note ID to start from
+     * @param userQuery - User query for context
+     * @param maxResults - Maximum number of results
+     * @param messages - Optional conversation messages to adjust context size
+     * @returns Formatted context
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults: number = 5): Promise<string> {
+    async getSemanticContext(
+        noteId: string,
+        userQuery: string,
+        maxResults: number = 5,
+        messages: Message[] = []
+    ): Promise<string> {
         if (!this.initialized) {
             await this.initialize();
         }
@@ -712,24 +719,39 @@ export class ContextService {

             // Get content for the top N most relevant notes
             const mostRelevantNotes = rankedNotes.slice(0, maxResults);
-            const relevantContent = await Promise.all(
+
+            // Get relevant search results to pass to context formatter
+            const searchResults = await Promise.all(
                 mostRelevantNotes.map(async note => {
                     const content = await this.contextExtractor.getNoteContent(note.noteId);
                     if (!content) return null;

-                    // Format with relevance score and title
-                    return `### ${note.title} (Relevance: ${Math.round(note.relevance * 100)}%)\n\n${content}`;
+                    // Create a properly typed NoteSearchResult object
+                    return {
+                        noteId: note.noteId,
+                        title: note.title,
+                        content,
+                        similarity: note.relevance
+                    };
                 })
             );

+            // Filter out nulls and empty content
+            const validResults: NoteSearchResult[] = searchResults
+                .filter(result => result !== null && result.content && result.content.trim().length > 0)
+                .map(result => result as NoteSearchResult);
+
             // If no content retrieved, return empty string
-            if (!relevantContent.filter(Boolean).length) {
+            if (validResults.length === 0) {
                 return '';
             }

-            return `# Relevant Context\n\nThe following notes are most relevant to your query:\n\n${
-                relevantContent.filter(Boolean).join('\n\n---\n\n')
-            }`;
+            // Get the provider information for formatting
+            const provider = await providerManager.getPreferredEmbeddingProvider();
+            const providerId = provider?.name || 'default';
+
+            // Format the context with the context formatter (which handles adjusting for conversation size)
+            return contextFormatter.buildContextFromNotes(validResults, userQuery, providerId, messages);
         } catch (error) {
             log.error(`Error getting semantic context: ${error}`);
             return '';
@@ -154,10 +154,11 @@ class TriliumContextService {
      * @param noteId - The note ID
      * @param userQuery - The user's query
      * @param maxResults - Maximum results to include
+     * @param messages - Optional conversation messages to adjust context size
      * @returns Formatted context string
      */
-    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5): Promise<string> {
-        return contextService.getSemanticContext(noteId, userQuery, maxResults);
+    async getSemanticContext(noteId: string, userQuery: string, maxResults = 5, messages: Message[] = []): Promise<string> {
+        return contextService.getSemanticContext(noteId, userQuery, maxResults, messages);
     }

     /**
@@ -46,7 +46,12 @@ export interface NoteSearchResult {
  * Interface for context formatter
  */
 export interface IContextFormatter {
-    buildContextFromNotes(sources: NoteSearchResult[], query: string, providerId?: string): Promise<string>;
+    buildContextFromNotes(
+        sources: NoteSearchResult[],
+        query: string,
+        providerId?: string,
+        messages?: Array<{role: string, content: string}>
+    ): Promise<string>;
 }

 /**
new file: src/services/llm/interfaces/model_capabilities.ts (138 lines)

import type { Message } from '../ai_interface.js';

/**
 * Interface for model capabilities information
 */
export interface ModelCapabilities {
    contextWindowTokens: number;   // Context window size in tokens
    contextWindowChars: number;    // Estimated context window size in characters (for planning)
    maxCompletionTokens: number;   // Maximum completion length
    hasFunctionCalling: boolean;   // Whether the model supports function calling
    hasVision: boolean;            // Whether the model supports image input
    costPerInputToken: number;     // Cost per input token (if applicable)
    costPerOutputToken: number;    // Cost per output token (if applicable)
}

/**
 * Default model capabilities for unknown models
 */
export const DEFAULT_MODEL_CAPABILITIES: ModelCapabilities = {
    contextWindowTokens: 4096,
    contextWindowChars: 16000, // ~4 chars per token estimate
    maxCompletionTokens: 1024,
    hasFunctionCalling: false,
    hasVision: false,
    costPerInputToken: 0,
    costPerOutputToken: 0
};

/**
 * Model capabilities for common models
 */
export const MODEL_CAPABILITIES: Record<string, Partial<ModelCapabilities>> = {
    // OpenAI models
    'gpt-3.5-turbo': {
        contextWindowTokens: 4096,
        contextWindowChars: 16000,
        hasFunctionCalling: true
    },
    'gpt-3.5-turbo-16k': {
        contextWindowTokens: 16384,
        contextWindowChars: 65000,
        hasFunctionCalling: true
    },
    'gpt-4': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000,
        hasFunctionCalling: true
    },
    'gpt-4-32k': {
        contextWindowTokens: 32768,
        contextWindowChars: 130000,
        hasFunctionCalling: true
    },
    'gpt-4-turbo': {
        contextWindowTokens: 128000,
        contextWindowChars: 512000,
        hasFunctionCalling: true,
        hasVision: true
    },
    'gpt-4o': {
        contextWindowTokens: 128000,
        contextWindowChars: 512000,
        hasFunctionCalling: true,
        hasVision: true
    },

    // Anthropic models
    'claude-3-haiku': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-3-sonnet': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-3-opus': {
        contextWindowTokens: 200000,
        contextWindowChars: 800000,
        hasVision: true
    },
    'claude-2': {
        contextWindowTokens: 100000,
        contextWindowChars: 400000
    },

    // Ollama models (defaults, will be updated dynamically)
    'llama3': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000
    },
    'mistral': {
        contextWindowTokens: 8192,
        contextWindowChars: 32000
    },
    'llama2': {
        contextWindowTokens: 4096,
        contextWindowChars: 16000
    }
};

/**
 * Calculate available context window size for context generation.
 * This takes into account expected message sizes and other overhead.
 *
 * @param modelCapabilities Capabilities of the target model
 * @param messages Current conversation messages
 * @param expectedTurns Number of expected additional conversation turns
 * @returns Available context size in characters
 */
export function calculateAvailableContextSize(
    modelCapabilities: ModelCapabilities,
    messages: Message[],
    expectedTurns: number = 3
): number {
    // Calculate current message usage (rough estimate, in characters)
    let currentMessageChars = 0;
    for (const message of messages) {
        currentMessageChars += message.content.length;
    }

    // Reserve space for system prompt and overhead
    const systemPromptReserve = 1000;

    // Reserve space for expected conversation turns
    const turnReserve = expectedTurns * 2000; // Average 2000 chars per turn (including both user and assistant)

    // Calculate available space
    const totalReserved = currentMessageChars + systemPromptReserve + turnReserve;
    const availableContextSize = Math.max(0, modelCapabilities.contextWindowChars - totalReserved);

    // Use at most 70% of total context window size to be safe
    const maxSafeContextSize = Math.floor(modelCapabilities.contextWindowChars * 0.7);

    // Return the smaller of available size or max safe size
    return Math.min(availableContextSize, maxSafeContextSize);
}
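Not part of the diff: a quick arithmetic check of calculateAvailableContextSize using the 'gpt-4' entry above (32000-character window); the Message shape is assumed to be roughly { role, content }.

import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES, calculateAvailableContextSize } from './model_capabilities.js';
import type { Message } from '../ai_interface.js';

const gpt4 = { ...DEFAULT_MODEL_CAPABILITIES, ...MODEL_CAPABILITIES['gpt-4'] };

// ~5000 characters of prior conversation
const history = [
    { role: 'user', content: 'x'.repeat(3000) },
    { role: 'assistant', content: 'x'.repeat(2000) }
] as Message[];

// Reserved: 5000 (history) + 1000 (system prompt) + 3 * 2000 (future turns) = 12000 chars
// Available: 32000 - 12000 = 20000; safety cap: 0.7 * 32000 = 22400; the smaller value wins.
console.log(calculateAvailableContextSize(gpt4, history, 3)); // 20000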
new file: src/services/llm/model_capabilities_service.ts (159 lines)

import log from '../log.js';
import type { ModelCapabilities } from './interfaces/model_capabilities.js';
import { MODEL_CAPABILITIES, DEFAULT_MODEL_CAPABILITIES } from './interfaces/model_capabilities.js';
import aiServiceManager from './ai_service_manager.js';
import { getEmbeddingProvider } from './providers/providers.js';
import type { BaseEmbeddingProvider } from './embeddings/base_embeddings.js';
import type { EmbeddingModelInfo } from './interfaces/embedding_interfaces.js';

// Define a type for embedding providers that might have the getModelInfo method
interface EmbeddingProviderWithModelInfo {
    getModelInfo?: (modelName: string) => Promise<EmbeddingModelInfo>;
}

/**
 * Service for fetching and caching model capabilities
 */
export class ModelCapabilitiesService {
    // Cache model capabilities
    private capabilitiesCache: Map<string, ModelCapabilities> = new Map();

    constructor() {
        // Initialize cache with known models
        this.initializeCache();
    }

    /**
     * Initialize the cache with known model capabilities
     */
    private initializeCache() {
        // Add all predefined model capabilities to cache
        for (const [model, capabilities] of Object.entries(MODEL_CAPABILITIES)) {
            this.capabilitiesCache.set(model, {
                ...DEFAULT_MODEL_CAPABILITIES,
                ...capabilities
            });
        }
    }

    /**
     * Get model capabilities, fetching from provider if needed
     *
     * @param modelName Full model name (with or without provider prefix)
     * @returns Model capabilities
     */
    async getModelCapabilities(modelName: string): Promise<ModelCapabilities> {
        // Handle provider-prefixed model names (e.g., "openai:gpt-4")
        let provider = 'default';
        let baseModelName = modelName;

        if (modelName.includes(':')) {
            const parts = modelName.split(':');
            provider = parts[0];
            baseModelName = parts[1];
        }

        // Check cache first
        const cacheKey = baseModelName;
        if (this.capabilitiesCache.has(cacheKey)) {
            return this.capabilitiesCache.get(cacheKey)!;
        }

        // Fetch from provider if possible
        try {
            // Get provider service
            const providerService = aiServiceManager.getService(provider);

            if (providerService && typeof (providerService as any).getModelCapabilities === 'function') {
                // If provider supports direct capability fetching, use it
                const capabilities = await (providerService as any).getModelCapabilities(baseModelName);

                if (capabilities) {
                    // Merge with defaults and cache
                    const fullCapabilities = {
                        ...DEFAULT_MODEL_CAPABILITIES,
                        ...capabilities
                    };

                    this.capabilitiesCache.set(cacheKey, fullCapabilities);
                    log.info(`Fetched capabilities for ${modelName}: context window ${fullCapabilities.contextWindowTokens} tokens`);

                    return fullCapabilities;
                }
            }

            // Try to fetch from embedding provider if available
            const embeddingProvider = getEmbeddingProvider(provider);

            if (embeddingProvider) {
                try {
                    // Cast to a type that might have getModelInfo method
                    const providerWithModelInfo = embeddingProvider as unknown as EmbeddingProviderWithModelInfo;

                    if (providerWithModelInfo.getModelInfo) {
                        const modelInfo = await providerWithModelInfo.getModelInfo(baseModelName);

                        if (modelInfo && modelInfo.contextWidth) {
                            // Convert to our capabilities format
                            const capabilities: ModelCapabilities = {
                                ...DEFAULT_MODEL_CAPABILITIES,
                                contextWindowTokens: modelInfo.contextWidth,
                                contextWindowChars: modelInfo.contextWidth * 4 // Rough estimate: 4 chars per token
                            };

                            this.capabilitiesCache.set(cacheKey, capabilities);
                            log.info(`Derived capabilities for ${modelName} from embedding provider: context window ${capabilities.contextWindowTokens} tokens`);

                            return capabilities;
                        }
                    }
                } catch (error) {
                    log.info(`Could not get model info from embedding provider for ${modelName}: ${error}`);
                }
            }
        } catch (error) {
            log.error(`Error fetching model capabilities for ${modelName}: ${error}`);
        }

        // If we get here, try to find a similar model in our predefined list
        for (const knownModel of Object.keys(MODEL_CAPABILITIES)) {
            // Check if the model name contains this known model (e.g., "gpt-4-1106-preview" contains "gpt-4")
            if (baseModelName.includes(knownModel)) {
                const capabilities = {
                    ...DEFAULT_MODEL_CAPABILITIES,
                    ...MODEL_CAPABILITIES[knownModel]
                };

                this.capabilitiesCache.set(cacheKey, capabilities);
                log.info(`Using similar model (${knownModel}) capabilities for ${modelName}`);

                return capabilities;
            }
        }

        // Fall back to defaults if nothing else works
        log.info(`Using default capabilities for unknown model ${modelName}`);
        this.capabilitiesCache.set(cacheKey, DEFAULT_MODEL_CAPABILITIES);

        return DEFAULT_MODEL_CAPABILITIES;
    }

    /**
     * Update model capabilities in the cache
     *
     * @param modelName Model name
     * @param capabilities Capabilities to update
     */
    updateModelCapabilities(modelName: string, capabilities: Partial<ModelCapabilities>) {
        const currentCapabilities = this.capabilitiesCache.get(modelName) || DEFAULT_MODEL_CAPABILITIES;

        this.capabilitiesCache.set(modelName, {
            ...currentCapabilities,
            ...capabilities
        });
    }
}

// Create and export singleton instance
const modelCapabilitiesService = new ModelCapabilitiesService();
export default modelCapabilitiesService;
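Not part of the diff: a usage sketch of the capabilities service. The first lookup strips the provider prefix and hits the cache seeded in the constructor; the second has no exact entry, so once the provider probes come up empty it falls back to the closest known prefix ('gpt-4').

import modelCapabilitiesService from './model_capabilities_service.js';

const gpt4 = await modelCapabilitiesService.getModelCapabilities('openai:gpt-4');
// gpt4.contextWindowTokens === 8192, gpt4.contextWindowChars === 32000

const preview = await modelCapabilitiesService.getModelCapabilities('gpt-4-1106-preview');
// preview inherits the 'gpt-4' entry merged over DEFAULT_MODEL_CAPABILITIES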
@@ -106,7 +106,8 @@ export class ChatPipeline {
             // Get semantic context for regular queries
             const semanticContext = await this.stages.semanticContextExtraction.execute({
                 noteId: input.noteId,
-                query: input.query
+                query: input.query,
+                messages: input.messages
             });
             context = semanticContext.context;
             this.updateStageMetrics('semanticContextExtraction', contextStartTime);
@@ -136,10 +137,10 @@ export class ChatPipeline {
         const llmStartTime = Date.now();

         // Setup streaming handler if streaming is enabled and callback provided
         const enableStreaming = this.config.enableStreaming &&
             modelSelection.options.stream !== false &&
             typeof streamCallback === 'function';

         if (enableStreaming) {
             // Make sure stream is enabled in options
             modelSelection.options.stream = true;
@@ -157,10 +158,10 @@ export class ChatPipeline {
             await completion.response.stream(async (chunk: StreamChunk) => {
                 // Process the chunk text
                 const processedChunk = await this.processStreamChunk(chunk, input.options);

                 // Accumulate text for final response
                 accumulatedText += processedChunk.text;

                 // Forward to callback
                 await streamCallback!(processedChunk.text, processedChunk.done);
             });
@@ -182,12 +183,12 @@ export class ChatPipeline {

         const endTime = Date.now();
         const executionTime = endTime - startTime;

         // Update overall average execution time
         this.metrics.averageExecutionTime =
             (this.metrics.averageExecutionTime * (this.metrics.totalExecutions - 1) + executionTime) /
             this.metrics.totalExecutions;

         log.info(`Chat pipeline completed in ${executionTime}ms`);

         return finalResponse;
@@ -235,12 +236,12 @@ export class ChatPipeline {
      */
     private updateStageMetrics(stageName: string, startTime: number) {
         if (!this.config.enableMetrics) return;

         const executionTime = Date.now() - startTime;
         const metrics = this.metrics.stageMetrics[stageName];

         metrics.totalExecutions++;
         metrics.averageExecutionTime =
             (metrics.averageExecutionTime * (metrics.totalExecutions - 1) + executionTime) /
             metrics.totalExecutions;
     }
@@ -258,7 +259,7 @@ export class ChatPipeline {
     resetMetrics(): void {
         this.metrics.totalExecutions = 0;
         this.metrics.averageExecutionTime = 0;

         Object.keys(this.metrics.stageMetrics).forEach(stageName => {
             this.metrics.stageMetrics[stageName] = {
                 totalExecutions: 0,
@@ -15,12 +15,12 @@ export interface ChatPipelineConfig {
      * Whether to enable streaming support
      */
     enableStreaming: boolean;

     /**
      * Whether to enable performance metrics
      */
     enableMetrics: boolean;

     /**
      * Maximum number of tool call iterations
      */
@@ -84,6 +84,7 @@ export interface SemanticContextExtractionInput extends PipelineInput {
     noteId: string;
     query: string;
     maxResults?: number;
+    messages?: Message[];
 }

 /**
@@ -15,11 +15,11 @@ export class SemanticContextExtractionStage extends BasePipelineStage<SemanticCo
      * Extract semantic context based on a query
      */
     protected async process(input: SemanticContextExtractionInput): Promise<{ context: string }> {
-        const { noteId, query, maxResults = 5 } = input;
+        const { noteId, query, maxResults = 5, messages = [] } = input;
         log.info(`Extracting semantic context from note ${noteId}, query: ${query?.substring(0, 50)}...`);

         const contextService = aiServiceManager.getContextService();
-        const context = await contextService.getSemanticContext(noteId, query, maxResults);
+        const context = await contextService.getSemanticContext(noteId, query, maxResults, messages);

         return { context };
     }
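Taken together, the commit threads the conversation history down one layer at a time: ChatPipeline.execute() passes input.messages to the semanticContextExtraction stage, the stage forwards them to ContextService.getSemanticContext(), which hands them (along with the typed NoteSearchResult list) to ContextFormatter.buildContextFromNotes(), and the formatter finally calls calculateAvailableContextSize(modelCapabilities, messages, 3) so the amount of note context packed into the prompt shrinks as the conversation grows instead of being fixed per provider.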