diff --git a/src/public/app/widgets/llm_chat_panel.ts b/src/public/app/widgets/llm_chat_panel.ts index 84ab68f77..68d4bec61 100644 --- a/src/public/app/widgets/llm_chat_panel.ts +++ b/src/public/app/widgets/llm_chat_panel.ts @@ -24,6 +24,7 @@ export default class LlmChatPanel extends BasicWidget { private chatContainer!: HTMLElement; private loadingIndicator!: HTMLElement; private sourcesList!: HTMLElement; + private useAdvancedContextCheckbox!: HTMLInputElement; private sessionId: string | null = null; private currentNoteId: string | null = null; @@ -45,15 +46,29 @@ export default class LlmChatPanel extends BasicWidget {
-            <div class="note-context-chat-input-container p-2 border-top">
-                <textarea class="note-context-chat-input form-control" placeholder="Enter your message..."></textarea>
-                <button class="note-context-chat-send-button btn btn-primary">Send</button>
-            </div>
+            <div class="note-context-chat-input-container p-2 border-top">
+                <div class="form-check mb-2">
+                    <input type="checkbox" class="form-check-input use-advanced-context-checkbox">
+                    <label class="form-check-label">${t('ai.use_advanced_context')}</label>
+                    <small class="form-text text-muted d-block">${t('ai.advanced_context_helps')}</small>
+                </div>
+                <div class="d-flex">
+                    <textarea class="note-context-chat-input form-control flex-grow-1" placeholder="${t('ai.enter_message')}"></textarea>
+                    <button class="note-context-chat-send-button btn btn-primary ms-2" type="button">${t('ai.send')}</button>
+                </div>
+            </div>
`); @@ -66,6 +81,7 @@ export default class LlmChatPanel extends BasicWidget { this.chatContainer = element.querySelector('.note-context-chat-container') as HTMLElement; this.loadingIndicator = element.querySelector('.loading-indicator') as HTMLElement; this.sourcesList = element.querySelector('.sources-list') as HTMLElement; + this.useAdvancedContextCheckbox = element.querySelector('.use-advanced-context-checkbox') as HTMLInputElement; this.initializeEventListeners(); @@ -109,47 +125,67 @@ export default class LlmChatPanel extends BasicWidget { return; } + this.addMessageToChat('user', content); + this.noteContextChatInput.value = ''; this.showLoadingIndicator(); + this.hideSources(); try { - // Add user message to chat - this.addMessageToChat('user', content); - this.noteContextChatInput.value = ''; + const useAdvancedContext = this.useAdvancedContextCheckbox.checked; - // Get AI settings - const useRAG = true; // Always use RAG for this widget + // Setup streaming + const source = new EventSource(`./api/llm/messages?sessionId=${this.sessionId}&format=stream`); + let assistantResponse = ''; - // Send message to server - const response = await server.post('llm/sessions/' + this.sessionId + '/messages', { - sessionId: this.sessionId, - content: content, - options: { - useRAG: useRAG + // Handle streaming response + source.onmessage = (event) => { + if (event.data === '[DONE]') { + // Stream completed + source.close(); + this.hideLoadingIndicator(); + return; } + + try { + const data = JSON.parse(event.data); + if (data.content) { + assistantResponse += data.content; + // Update the UI with the accumulated response + const assistantElement = this.noteContextChatMessages.querySelector('.assistant-message:last-child .message-content'); + if (assistantElement) { + assistantElement.innerHTML = this.formatMarkdown(assistantResponse); + } else { + this.addMessageToChat('assistant', assistantResponse); + } + // Scroll to the bottom + this.chatContainer.scrollTop = this.chatContainer.scrollHeight; + } + } catch (e) { + console.error('Error parsing SSE message:', e); + } + }; + + source.onerror = () => { + source.close(); + this.hideLoadingIndicator(); + toastService.showError('Error connecting to the LLM service. 
Please try again.'); + }; + + // Send the actual message + const response = await server.post('llm/messages', { + sessionId: this.sessionId, + content, + contextNoteId: this.currentNoteId, + useAdvancedContext }); - // Get the assistant's message (last one) - if (response?.messages?.length) { - const messages = response.messages; - const lastMessage = messages[messages.length - 1]; - - if (lastMessage && lastMessage.role === 'assistant') { - this.addMessageToChat('assistant', lastMessage.content); - } - } - - // Display sources if available - if (response?.sources?.length) { + // Handle sources if returned in non-streaming response + if (response && response.sources && response.sources.length > 0) { this.showSources(response.sources); - } else { - this.hideSources(); } - } catch (error) { - console.error('Failed to send message:', error); - toastService.showError('Failed to send message to AI'); - } finally { this.hideLoadingIndicator(); + toastService.showError('Error sending message: ' + (error as Error).message); } } @@ -243,4 +279,17 @@ export default class LlmChatPanel extends BasicWidget { } }); } + + /** + * Format markdown content for display + */ + private formatMarkdown(content: string): string { + // Simple markdown formatting - could be replaced with a proper markdown library + return content + .replace(/\*\*(.*?)\*\*/g, '$1') + .replace(/\*(.*?)\*/g, '$1') + .replace(/`(.*?)`/g, '$1') + .replace(/\n/g, '
<br>')
+            .replace(/```(.*?)```/gs, '<pre>$1</pre>
'); + } } diff --git a/src/public/translations/en/translation.json b/src/public/translations/en/translation.json index 25ab6263d..71fe62a08 100644 --- a/src/public/translations/en/translation.json +++ b/src/public/translations/en/translation.json @@ -1755,5 +1755,11 @@ "content_language": { "title": "Content languages", "description": "Select one or more languages that should appear in the language selection in the Basic Properties section of a read-only or editable text note. This will allow features such as spell-checking or right-to-left support." + }, + "ai": { + "sources": "Sources", + "enter_message": "Enter your message...", + "use_advanced_context": "Use Advanced Context", + "advanced_context_helps": "Helps with large knowledge bases and limited context windows" } } diff --git a/src/routes/api/llm.ts b/src/routes/api/llm.ts index c987a84b0..ab1a43127 100644 --- a/src/routes/api/llm.ts +++ b/src/routes/api/llm.ts @@ -9,6 +9,8 @@ import providerManager from "../../services/llm/embeddings/providers.js"; import type { Message, ChatCompletionOptions } from "../../services/llm/ai_interface.js"; // Import this way to prevent immediate instantiation import * as aiServiceManagerModule from "../../services/llm/ai_service_manager.js"; +import triliumContextService from "../../services/llm/trilium_context_service.js"; +import sql from "../../services/sql.js"; // Define basic interfaces interface ChatMessage { @@ -290,132 +292,126 @@ async function deleteSession(req: Request, res: Response) { } /** - * Find relevant notes using vector embeddings + * Find relevant notes based on search query */ -async function findRelevantNotes(query: string, contextNoteId: string | null = null, limit = 5): Promise { +async function findRelevantNotes(content: string, contextNoteId: string | null = null, limit = 5): Promise { try { - // Only proceed if database is initialized + // If database is not initialized, we can't do this if (!isDatabaseInitialized()) { - log.info('Database not initialized, skipping vector search'); - return [{ - noteId: "root", - title: "Database not initialized yet", - content: "Please wait for database initialization to complete." 
- }]; + return []; } - // Get the default embedding provider - let providerId; - try { - // @ts-ignore - embeddingsDefaultProvider exists but might not be in the TypeScript definitions - providerId = await options.getOption('embeddingsDefaultProvider') || 'openai'; - } catch (error) { - log.info('Could not get default embedding provider, using mock data'); - return [{ - noteId: "root", - title: "Embeddings not configured", - content: "Embedding provider not available" - }]; + // Check if embeddings are available + const enabledProviders = await providerManager.getEnabledEmbeddingProviders(); + if (enabledProviders.length === 0) { + log.info("No embedding providers available, can't find relevant notes"); + return []; } - const provider = providerManager.getEmbeddingProvider(providerId); - - if (!provider) { - log.info(`Embedding provider ${providerId} not found, using mock data`); - return [{ - noteId: "root", - title: "Embeddings not available", - content: "No embedding provider available" - }]; + // If content is too short, don't bother + if (content.length < 3) { + return []; } - // Generate embedding for the query - const embedding = await provider.generateEmbeddings(query); + // Get the embedding for the query + const provider = enabledProviders[0]; + const embedding = await provider.generateEmbeddings(content); - // Find similar notes - const modelId = 'default'; // Use default model for the provider - const similarNotes = await vectorStore.findSimilarNotes( - embedding, providerId, modelId, limit, 0.6 // Lower threshold to find more results - ); - - // If a context note was provided, check if we should include its children + let results; if (contextNoteId) { - const contextNote = becca.getNote(contextNoteId); - if (contextNote) { - const childNotes = contextNote.getChildNotes(); - if (childNotes.length > 0) { - // Add relevant children that weren't already included - const childIds = new Set(childNotes.map(note => note.noteId)); - const existingIds = new Set(similarNotes.map(note => note.noteId)); + // For branch context, get notes specifically from that branch - // Find children that aren't already in the similar notes - const missingChildIds = Array.from(childIds).filter(id => !existingIds.has(id)); + // TODO: This is a simplified implementation - we need to + // properly get all notes in the subtree starting from contextNoteId - // Add up to 3 children that weren't already included - for (const noteId of missingChildIds.slice(0, 3)) { - similarNotes.push({ + // For now, just get direct children of the context note + const contextNote = becca.notes[contextNoteId]; + if (!contextNote) { + return []; + } + + const childBranches = await sql.getRows(` + SELECT branches.* FROM branches + WHERE branches.parentNoteId = ? 
+ AND branches.isDeleted = 0 + `, [contextNoteId]); + + const childNoteIds = childBranches.map((branch: any) => branch.noteId); + + // Include the context note itself + childNoteIds.push(contextNoteId); + + // Find similar notes in this context + results = []; + + for (const noteId of childNoteIds) { + const noteEmbedding = await vectorStore.getEmbeddingForNote( + noteId, + provider.name, + provider.getConfig().model + ); + + if (noteEmbedding) { + const similarity = vectorStore.cosineSimilarity( + embedding, + noteEmbedding.embedding + ); + + if (similarity > 0.65) { + results.push({ noteId, - similarity: 0.75 // Fixed similarity score for context children + similarity }); } } } + + // Sort by similarity + results.sort((a, b) => b.similarity - a.similarity); + results = results.slice(0, limit); + } else { + // General search across all notes + results = await vectorStore.findSimilarNotes( + embedding, + provider.name, + provider.getConfig().model, + limit + ); } - // Get note content for context - return await Promise.all(similarNotes.map(async ({ noteId, similarity }) => { - const note = becca.getNote(noteId); - if (!note) { - return { - noteId, - title: "Unknown Note", - similarity - }; + // Format the results + const sources: NoteSource[] = []; + + for (const result of results) { + const note = becca.notes[result.noteId]; + if (!note) continue; + + let noteContent: string | undefined = undefined; + if (note.type === 'text') { + const content = note.getContent(); + // Handle both string and Buffer types + noteContent = typeof content === 'string' ? content : + content instanceof Buffer ? content.toString('utf8') : undefined; } - // Get note content - let content = ''; - try { - // @ts-ignore - Content can be string or Buffer - const noteContent = await note.getContent(); - content = typeof noteContent === 'string' ? noteContent : noteContent.toString('utf8'); - - // Truncate content if it's too long (for performance) - if (content.length > 2000) { - content = content.substring(0, 2000) + "..."; - } - } catch (e) { - log.error(`Error getting content for note ${noteId}: ${e}`); - } - - // Get a branch ID for navigation - let branchId; - try { - const branches = note.getBranches(); - if (branches.length > 0) { - branchId = branches[0].branchId; - } - } catch (e) { - log.error(`Error getting branch for note ${noteId}: ${e}`); - } - - return { - noteId, + sources.push({ + noteId: result.noteId, title: note.title, - content, - similarity, - branchId - }; - })); - } catch (error) { - log.error(`Error finding relevant notes: ${error}`); - // Return empty array on error + content: noteContent, + similarity: result.similarity, + branchId: note.getBranches()[0]?.branchId + }); + } + + return sources; + } catch (error: any) { + log.error(`Error finding relevant notes: ${error.message}`); return []; } } /** - * Build a context string from relevant notes + * Build context from notes */ function buildContextFromNotes(sources: NoteSource[], query: string): string { console.log("Building context from notes with query:", query); @@ -449,265 +445,237 @@ Now, based on the above notes, please answer: ${query}`; } /** - * Send a message to an LLM chat session and get a response + * Send a message to the AI */ async function sendMessage(req: Request, res: Response) { try { - const { sessionId, content, temperature, maxTokens, provider, model } = req.body; - - console.log("Received message request:", { - sessionId, - contentLength: content ? content.length : 0, - contentPreview: content ? 
content.substring(0, 50) + (content.length > 50 ? '...' : '') : 'undefined', - temperature, - maxTokens, - provider, - model - }); - - if (!sessionId) { - throw new Error('Session ID is required'); - } + // Extract the content from the request body + const { content, sessionId, useAdvancedContext = false } = req.body || {}; + // Validate the content if (!content || typeof content !== 'string' || content.trim().length === 0) { throw new Error('Content cannot be empty'); } - // Check if streaming is requested - const wantsStream = (req.headers as any)['accept']?.includes('text/event-stream'); + // Get or create the session + let session: ChatSession; - // If client wants streaming, set up SSE response - if (wantsStream) { - res.setHeader('Content-Type', 'text/event-stream'); - res.setHeader('Cache-Control', 'no-cache'); - res.setHeader('Connection', 'keep-alive'); - - // Get chat session - let session = sessions.get(sessionId); - if (!session) { - const newSession = await createSession(req, res); - if (!newSession) { - throw new Error('Failed to create session'); - } - // Add required properties to match ChatSession interface - session = { - ...newSession, - messages: [], - lastActive: new Date(), - metadata: {} - }; - sessions.set(sessionId, session); + if (sessionId && sessions.has(sessionId)) { + session = sessions.get(sessionId)!; + session.lastActive = new Date(); + } else { + const result = await createSession(req, res); + if (!result?.id) { + throw new Error('Failed to create a new session'); } + session = sessions.get(result.id)!; + } - // Add user message to session - const userMessage: ChatMessage = { - role: 'user', - content: content, - timestamp: new Date() + // Check if AI services are available + if (!safelyUseAIManager()) { + throw new Error('AI services are not available'); + } + + // Get the AI service manager + const aiServiceManager = aiServiceManagerModule.default.getInstance(); + // Get the default service - just use the first available one + const availableProviders = aiServiceManager.getAvailableProviders(); + let service = null; + + if (availableProviders.length > 0) { + // Use the first available provider + const providerName = availableProviders[0]; + // We know the manager has a 'services' property from our code inspection, + // but TypeScript doesn't know that from the interface. + // This is a workaround to access it + service = (aiServiceManager as any).services[providerName]; + } + + if (!service) { + throw new Error('No AI service is available'); + } + + // Create user message + const userMessage: Message = { + role: 'user', + content + }; + + // Add message to session + session.messages.push({ + role: 'user', + content, + timestamp: new Date() + }); + + // Log a preview of the message + log.info(`Processing LLM message: "${content.substring(0, 50)}${content.length > 50 ? '...' 
: ''}"`); + + // Information to return to the client + let aiResponse = ''; + let sourceNotes: NoteSource[] = []; + + // If Advanced Context is enabled, we use the improved method + if (useAdvancedContext) { + // Use the Trilium-specific approach + const contextNoteId = session.noteContext || null; + const results = await triliumContextService.processQuery(content, service, contextNoteId); + + // Get the generated context + const context = results.context; + sourceNotes = results.notes; + + // Add system message with the context + const contextMessage: Message = { + role: 'system', + content: context }; - console.log("Created user message:", { - role: userMessage.role, - contentLength: userMessage.content?.length || 0, - contentPreview: userMessage.content?.substring(0, 50) + (userMessage.content?.length > 50 ? '...' : '') || 'undefined' - }); - session.messages.push(userMessage); - // Get context for query - const sources = await findRelevantNotes(content, session.noteContext || null); - - // Format messages for AI with proper type casting + // Format all messages for the AI const aiMessages: Message[] = [ - { role: 'system', content: 'You are a helpful assistant for Trilium Notes. When providing answers, use only the context provided in the notes. If the information is not in the notes, say so.' }, - { role: 'user', content: buildContextFromNotes(sources, content) } + contextMessage, + ...session.messages.slice(-10).map(msg => ({ + role: msg.role, + content: msg.content + })) ]; - // Ensure we're not sending empty content - console.log("Final message content length:", aiMessages[1].content.length); - console.log("Final message content preview:", aiMessages[1].content.substring(0, 100)); - - try { - // Send initial SSE message with session info - const sourcesForResponse = sources.map(({ noteId, title, similarity, branchId }) => ({ - noteId, - title, - similarity: similarity ? Math.round(similarity * 100) / 100 : undefined, - branchId - })); - - res.write(`data: ${JSON.stringify({ - type: 'init', - session: { - id: sessionId, - messages: session.messages.slice(0, -1), // Don't include the new message yet - sources: sourcesForResponse - } - })}\n\n`); - - // Get AI response with streaming enabled - const aiResponse = await aiServiceManagerModule.default.generateChatCompletion(aiMessages, { - temperature, - maxTokens, - model: provider ? 
`${provider}:${model}` : model, - stream: true - }); - - if (aiResponse.stream) { - // Create an empty assistant message - const assistantMessage: ChatMessage = { - role: 'assistant', - content: '', - timestamp: new Date() - }; - session.messages.push(assistantMessage); - - // Stream the response chunks - await aiResponse.stream(async (chunk) => { - if (chunk.text) { - // Update the message content - assistantMessage.content += chunk.text; - - // Send chunk to client - res.write(`data: ${JSON.stringify({ - type: 'chunk', - text: chunk.text, - done: chunk.done - })}\n\n`); - } - - if (chunk.done) { - // Send final message with complete response - res.write(`data: ${JSON.stringify({ - type: 'done', - session: { - id: sessionId, - messages: session.messages, - sources: sourcesForResponse - } - })}\n\n`); - - res.end(); - } - }); - - return; // Early return for streaming - } else { - // Fallback for non-streaming response - const assistantMessage: ChatMessage = { - role: 'assistant', - content: aiResponse.text, - timestamp: new Date() - }; - session.messages.push(assistantMessage); - - // Send complete response - res.write(`data: ${JSON.stringify({ - type: 'done', - session: { - id: sessionId, - messages: session.messages, - sources: sourcesForResponse - } - })}\n\n`); - - res.end(); - return; - } - } catch (error: any) { - // Send error in streaming format - res.write(`data: ${JSON.stringify({ - type: 'error', - error: `AI service error: ${error.message}` - })}\n\n`); - - res.end(); - return; - } - } - - // Non-streaming API continues with normal JSON response... - - // Get chat session - let session = sessions.get(sessionId); - if (!session) { - const newSession = await createSession(req, res); - if (!newSession) { - throw new Error('Failed to create session'); - } - // Add required properties to match ChatSession interface - session = { - ...newSession, - messages: [], - lastActive: new Date(), - metadata: {} + // Configure chat options from session metadata + const chatOptions: ChatCompletionOptions = { + temperature: session.metadata.temperature || 0.7, + maxTokens: session.metadata.maxTokens, + model: session.metadata.model + // 'provider' property has been removed as it's not in the ChatCompletionOptions type }; - sessions.set(sessionId, session); + + // Get streaming response if requested + const acceptHeader = req.get('Accept'); + if (acceptHeader && acceptHeader.includes('text/event-stream')) { + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + + let messageContent = ''; + + // Stream the response + await service.sendChatCompletion( + aiMessages, + chatOptions, + (chunk: string) => { + messageContent += chunk; + res.write(`data: ${JSON.stringify({ content: chunk })}\n\n`); + } + ); + + // Close the stream + res.write('data: [DONE]\n\n'); + res.end(); + + // Store the full response + aiResponse = messageContent; + } else { + // Non-streaming approach + aiResponse = await service.sendChatCompletion(aiMessages, chatOptions); + } + } else { + // Original approach - find relevant notes through direct embedding comparison + const relevantNotes = await findRelevantNotes( + content, + session.noteContext || null, + 5 + ); + + sourceNotes = relevantNotes; + + // Build context from relevant notes + const context = buildContextFromNotes(relevantNotes, content); + + // Add system message with the context + const contextMessage: Message = { + role: 'system', + content: context + }; + + // Format all 
messages for the AI + const aiMessages: Message[] = [ + contextMessage, + ...session.messages.slice(-10).map(msg => ({ + role: msg.role, + content: msg.content + })) + ]; + + // Configure chat options from session metadata + const chatOptions: ChatCompletionOptions = { + temperature: session.metadata.temperature || 0.7, + maxTokens: session.metadata.maxTokens, + model: session.metadata.model + // 'provider' property has been removed as it's not in the ChatCompletionOptions type + }; + + // Get streaming response if requested + const acceptHeader = req.get('Accept'); + if (acceptHeader && acceptHeader.includes('text/event-stream')) { + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + + let messageContent = ''; + + // Stream the response + await service.sendChatCompletion( + aiMessages, + chatOptions, + (chunk: string) => { + messageContent += chunk; + res.write(`data: ${JSON.stringify({ content: chunk })}\n\n`); + } + ); + + // Close the stream + res.write('data: [DONE]\n\n'); + res.end(); + + // Store the full response + aiResponse = messageContent; + } else { + // Non-streaming approach + aiResponse = await service.sendChatCompletion(aiMessages, chatOptions); + } } - // Add user message to session - const userMessage: ChatMessage = { - role: 'user', - content: content, - timestamp: new Date() - }; - console.log("Created user message:", { - role: userMessage.role, - contentLength: userMessage.content?.length || 0, - contentPreview: userMessage.content?.substring(0, 50) + (userMessage.content?.length > 50 ? '...' : '') || 'undefined' - }); - session.messages.push(userMessage); - - // Get context for query - const sources = await findRelevantNotes(content, session.noteContext || null); - - // Format messages for AI with proper type casting - const aiMessages: Message[] = [ - { role: 'system', content: 'You are a helpful assistant for Trilium Notes. When providing answers, use only the context provided in the notes. If the information is not in the notes, say so.' }, - { role: 'user', content: buildContextFromNotes(sources, content) } - ]; - - // Ensure we're not sending empty content - console.log("Final message content length:", aiMessages[1].content.length); - console.log("Final message content preview:", aiMessages[1].content.substring(0, 100)); - - try { - // Get AI response using the safe accessor methods - const aiResponse = await aiServiceManagerModule.default.generateChatCompletion(aiMessages, { - temperature, - maxTokens, - model: provider ? `${provider}:${model}` : model, - stream: false + // Only store the assistant's message if we're not streaming (otherwise we already did) + const acceptHeader = req.get('Accept'); + if (!acceptHeader || !acceptHeader.includes('text/event-stream')) { + // Store the assistant's response in the session + session.messages.push({ + role: 'assistant', + content: aiResponse, + timestamp: new Date() }); - // Add assistant message to session - const assistantMessage: ChatMessage = { - role: 'assistant', - content: aiResponse.text, - timestamp: new Date() - }; - session.messages.push(assistantMessage); - - // Format sources for the response (without content to reduce payload size) - const sourcesForResponse = sources.map(({ noteId, title, similarity, branchId }) => ({ - noteId, - title, - similarity: similarity ? 
Math.round(similarity * 100) / 100 : undefined, - branchId - })); - + // Return the response return { - id: sessionId, - messages: session.messages, - sources: sourcesForResponse, - provider: aiResponse.provider, - model: aiResponse.model + content: aiResponse, + sources: sourceNotes.map(note => ({ + noteId: note.noteId, + title: note.title, + similarity: note.similarity, + branchId: note.branchId + })) }; - } catch (error: any) { - log.error(`AI service error: ${error.message}`); - throw new Error(`AI service error: ${error.message}`); + } else { + // For streaming responses, we've already sent the data + // But we still need to add the message to the session + session.messages.push({ + role: 'assistant', + content: aiResponse, + timestamp: new Date() + }); } } catch (error: any) { - log.error(`Error sending message: ${error.message}`); - throw error; + log.error(`Error sending message to LLM: ${error.message}`); + throw new Error(`Failed to send message: ${error.message}`); } } diff --git a/src/services/llm/trilium_context_service.ts b/src/services/llm/trilium_context_service.ts new file mode 100644 index 000000000..8ff4fb2d6 --- /dev/null +++ b/src/services/llm/trilium_context_service.ts @@ -0,0 +1,410 @@ +import becca from "../../becca/becca.js"; +import vectorStore from "./embeddings/vector_store.js"; +import providerManager from "./embeddings/providers.js"; +import options from "../options.js"; +import log from "../log.js"; +import type { Message } from "./ai_interface.js"; +import { cosineSimilarity } from "./embeddings/vector_store.js"; + +/** + * TriliumContextService provides intelligent context management for working with large knowledge bases + * through limited context window LLMs like Ollama. + * + * It creates a "meta-prompting" approach where the first LLM call is used + * to determine what information might be needed to answer the query, + * then only the relevant context is loaded, before making the final + * response. + */ +class TriliumContextService { + private initialized = false; + private initPromise: Promise | null = null; + private provider: any = null; + + // Cache for recently used context to avoid repeated embedding lookups + private recentQueriesCache = new Map(); + + // Configuration + private cacheExpiryMs = 5 * 60 * 1000; // 5 minutes + private metaPrompt = `You are an AI assistant that decides what information needs to be retrieved from a knowledge base to answer the user's question. +Given the user's question, generate 3-5 specific search queries that would help find relevant information. +Each query should be focused on a different aspect of the question. +Format your answer as a JSON array of strings, with each string being a search query. +Example: ["exact topic mentioned", "related concept 1", "related concept 2"]`; + + constructor() { + this.setupCacheCleanup(); + } + + /** + * Initialize the service + */ + async initialize() { + if (this.initialized) return; + + // Use a promise to prevent multiple simultaneous initializations + if (this.initPromise) return this.initPromise; + + this.initPromise = (async () => { + try { + const providerId = await options.getOption('embeddingsDefaultProvider') || 'ollama'; + this.provider = providerManager.getEmbeddingProvider(providerId); + + if (!this.provider) { + throw new Error(`Embedding provider ${providerId} not found`); + } + + this.initialized = true; + log.info(`Trilium context service initialized with provider: ${providerId}`); + } catch (error: unknown) { + const errorMessage = error instanceof Error ? 
error.message : String(error); + log.error(`Failed to initialize Trilium context service: ${errorMessage}`); + throw error; + } finally { + this.initPromise = null; + } + })(); + + return this.initPromise; + } + + /** + * Set up periodic cache cleanup + */ + private setupCacheCleanup() { + setInterval(() => { + const now = Date.now(); + for (const [key, data] of this.recentQueriesCache.entries()) { + if (now - data.timestamp > this.cacheExpiryMs) { + this.recentQueriesCache.delete(key); + } + } + }, 60000); // Run cleanup every minute + } + + /** + * Generate search queries to find relevant information for the user question + * @param userQuestion - The user's question + * @param llmService - The LLM service to use for generating queries + * @returns Array of search queries + */ + async generateSearchQueries(userQuestion: string, llmService: any): Promise { + try { + const messages: Message[] = [ + { role: "system", content: this.metaPrompt }, + { role: "user", content: userQuestion } + ]; + + const options = { + temperature: 0.3, + maxTokens: 300 + }; + + // Get the response from the LLM + const response = await llmService.sendTextCompletion(messages, options); + + try { + // Parse the JSON response + const jsonStr = response.trim().replace(/```json|```/g, '').trim(); + const queries = JSON.parse(jsonStr); + + if (Array.isArray(queries) && queries.length > 0) { + return queries; + } else { + throw new Error("Invalid response format"); + } + } catch (parseError) { + // Fallback: if JSON parsing fails, try to extract queries line by line + const lines = response.split('\n') + .map((line: string) => line.trim()) + .filter((line: string) => line.length > 0 && !line.startsWith('```')); + + if (lines.length > 0) { + return lines.map((line: string) => line.replace(/^["'\d\.\-\s]+/, '').trim()); + } + + // If all else fails, just use the original question + return [userQuestion]; + } + } catch (error: unknown) { + const errorMessage = error instanceof Error ? 
error.message : String(error); + log.error(`Error generating search queries: ${errorMessage}`); + // Fallback to just using the original question + return [userQuestion]; + } + } + + /** + * Find relevant notes using multiple search queries + * @param queries - Array of search queries + * @param contextNoteId - Optional note ID to restrict search to a branch + * @param limit - Max notes to return + * @returns Array of relevant notes + */ + async findRelevantNotesMultiQuery( + queries: string[], + contextNoteId: string | null = null, + limit = 10 + ): Promise { + if (!this.initialized) { + await this.initialize(); + } + + try { + // Cache key combining all queries + const cacheKey = JSON.stringify({ queries, contextNoteId, limit }); + + // Check if we have a recent cache hit + const cached = this.recentQueriesCache.get(cacheKey); + if (cached) { + return cached.relevantNotes; + } + + // Array to store all results with their similarity scores + const allResults: { + noteId: string, + title: string, + content: string | null, + similarity: number, + branchId?: string + }[] = []; + + // Set to keep track of note IDs we've seen to avoid duplicates + const seenNoteIds = new Set(); + + // Process each query + for (const query of queries) { + // Get embeddings for this query + const queryEmbedding = await this.provider.getEmbedding(query); + + // Find notes similar to this query + let results; + if (contextNoteId) { + // Find within a specific context/branch + results = await this.findNotesInBranch( + queryEmbedding, + contextNoteId, + Math.min(limit, 5) // Limit per query + ); + } else { + // Search all notes + results = await vectorStore.findSimilarNotes( + queryEmbedding, + this.provider.id, + this.provider.modelId, + Math.min(limit, 5), // Limit per query + 0.5 // Lower threshold to get more diverse results + ); + } + + // Process results + for (const result of results) { + if (!seenNoteIds.has(result.noteId)) { + seenNoteIds.add(result.noteId); + + // Get the note from Becca + const note = becca.notes[result.noteId]; + if (!note) continue; + + // Add to our results + allResults.push({ + noteId: result.noteId, + title: note.title, + content: note.type === 'text' ? note.getContent() as string : null, + similarity: result.similarity, + branchId: note.getBranches()[0]?.branchId + }); + } + } + } + + // Sort by similarity and take the top 'limit' results + const sortedResults = allResults + .sort((a, b) => b.similarity - a.similarity) + .slice(0, limit); + + // Cache the results + this.recentQueriesCache.set(cacheKey, { + timestamp: Date.now(), + relevantNotes: sortedResults + }); + + return sortedResults; + } catch (error: unknown) { + const errorMessage = error instanceof Error ? 
error.message : String(error); + log.error(`Error finding relevant notes: ${errorMessage}`); + return []; + } + } + + /** + * Find notes in a specific branch/context + * @param embedding - Query embedding + * @param contextNoteId - Note ID to restrict search to + * @param limit - Max notes to return + * @returns Array of relevant notes + */ + private async findNotesInBranch( + embedding: Float32Array, + contextNoteId: string, + limit = 5 + ): Promise<{noteId: string, similarity: number}[]> { + try { + // Get the subtree note IDs + const subtreeNoteIds = await this.getSubtreeNoteIds(contextNoteId); + + if (subtreeNoteIds.length === 0) { + return []; + } + + // Get all embeddings for these notes using vectorStore instead of direct SQL + const similarities: {noteId: string, similarity: number}[] = []; + + for (const noteId of subtreeNoteIds) { + const noteEmbedding = await vectorStore.getEmbeddingForNote( + noteId, + this.provider.id, + this.provider.modelId + ); + + if (noteEmbedding) { + const similarity = cosineSimilarity(embedding, noteEmbedding.embedding); + if (similarity > 0.5) { // Apply similarity threshold + similarities.push({ + noteId, + similarity + }); + } + } + } + + // Sort by similarity and return top results + return similarities + .sort((a, b) => b.similarity - a.similarity) + .slice(0, limit); + } catch (error: unknown) { + const errorMessage = error instanceof Error ? error.message : String(error); + log.error(`Error finding notes in branch: ${errorMessage}`); + return []; + } + } + + /** + * Get all note IDs in a subtree (including the root note) + * @param rootNoteId - Root note ID + * @returns Array of note IDs + */ + private async getSubtreeNoteIds(rootNoteId: string): Promise { + const note = becca.notes[rootNoteId]; + if (!note) { + return []; + } + + // Use becca to walk the note tree instead of direct SQL + const noteIds = new Set([rootNoteId]); + + // Helper function to collect all children + const collectChildNotes = (noteId: string) => { + // Use becca.getNote(noteId).getChildNotes() to get child notes + const parentNote = becca.notes[noteId]; + if (!parentNote) return; + + // Get all branches where this note is the parent + for (const branch of Object.values(becca.branches)) { + if (branch.parentNoteId === noteId && !branch.isDeleted) { + const childNoteId = branch.noteId; + if (!noteIds.has(childNoteId)) { + noteIds.add(childNoteId); + // Recursively collect children of this child + collectChildNotes(childNoteId); + } + } + } + }; + + // Start collecting from the root + collectChildNotes(rootNoteId); + + return Array.from(noteIds); + } + + /** + * Build a context string from relevant notes + * @param sources - Array of notes + * @param query - Original user query + * @returns Formatted context string + */ + buildContextFromNotes(sources: any[], query: string): string { + if (!sources || sources.length === 0) { + return ""; + } + + let context = `The following are relevant notes from your knowledge base that may help answer the query: "${query}"\n\n`; + + sources.forEach((source, index) => { + context += `--- NOTE ${index + 1}: ${source.title} ---\n`; + + if (source.content) { + // Truncate content if it's too long + const maxContentLength = 1000; + let content = source.content; + + if (content.length > maxContentLength) { + content = content.substring(0, maxContentLength) + " [content truncated due to length]"; + } + + context += `${content}\n`; + } else { + context += "[This note doesn't contain textual content]\n"; + } + + context += "\n"; + }); + + context 
+= "--- END OF NOTES ---\n\n"; + context += "Please use the information above to help answer the query. If the information doesn't contain what you need, just say so and use your general knowledge instead."; + + return context; + } + + /** + * Process a user query with the Trilium-specific approach: + * 1. Generate search queries from the original question + * 2. Find relevant notes using those queries + * 3. Build a context string from the relevant notes + * + * @param userQuestion - The user's original question + * @param llmService - The LLM service to use + * @param contextNoteId - Optional note ID to restrict search to + * @returns Object with context and notes + */ + async processQuery(userQuestion: string, llmService: any, contextNoteId: string | null = null) { + if (!this.initialized) { + await this.initialize(); + } + + // Step 1: Generate search queries + const searchQueries = await this.generateSearchQueries(userQuestion, llmService); + log.info(`Generated search queries: ${JSON.stringify(searchQueries)}`); + + // Step 2: Find relevant notes using those queries + const relevantNotes = await this.findRelevantNotesMultiQuery( + searchQueries, + contextNoteId, + 8 // Get more notes since we're using multiple queries + ); + + // Step 3: Build context from the notes + const context = this.buildContextFromNotes(relevantNotes, userQuestion); + + return { + context, + notes: relevantNotes, + queries: searchQueries + }; + } +} + +export default new TriliumContextService();