diff --git a/src/public/app/widgets/llm_chat_panel.ts b/src/public/app/widgets/llm_chat_panel.ts
index 84ab68f77..68d4bec61 100644
--- a/src/public/app/widgets/llm_chat_panel.ts
+++ b/src/public/app/widgets/llm_chat_panel.ts
@@ -24,6 +24,7 @@ export default class LlmChatPanel extends BasicWidget {
private chatContainer!: HTMLElement;
private loadingIndicator!: HTMLElement;
private sourcesList!: HTMLElement;
+ private useAdvancedContextCheckbox!: HTMLInputElement;
private sessionId: string | null = null;
private currentNoteId: string | null = null;
@@ -45,15 +46,29 @@ export default class LlmChatPanel extends BasicWidget {
-
`);
@@ -66,6 +81,7 @@ export default class LlmChatPanel extends BasicWidget {
this.chatContainer = element.querySelector('.note-context-chat-container') as HTMLElement;
this.loadingIndicator = element.querySelector('.loading-indicator') as HTMLElement;
this.sourcesList = element.querySelector('.sources-list') as HTMLElement;
+ this.useAdvancedContextCheckbox = element.querySelector('.use-advanced-context-checkbox') as HTMLInputElement;
this.initializeEventListeners();
@@ -109,47 +125,67 @@ export default class LlmChatPanel extends BasicWidget {
return;
}
+ this.addMessageToChat('user', content);
+ this.noteContextChatInput.value = '';
this.showLoadingIndicator();
+ this.hideSources();
try {
- // Add user message to chat
- this.addMessageToChat('user', content);
- this.noteContextChatInput.value = '';
+ const useAdvancedContext = this.useAdvancedContextCheckbox.checked;
- // Get AI settings
- const useRAG = true; // Always use RAG for this widget
+ // Setup streaming
+ const source = new EventSource(`./api/llm/messages?sessionId=${this.sessionId}&format=stream`);
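+ // Chunks arrive as `data: {"content": "..."}` events and the stream is terminated by `data: [DONE]`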
+ let assistantResponse = '';
- // Send message to server
- const response = await server.post('llm/sessions/' + this.sessionId + '/messages', {
- sessionId: this.sessionId,
- content: content,
- options: {
- useRAG: useRAG
+ // Handle streaming response
+ source.onmessage = (event) => {
+ if (event.data === '[DONE]') {
+ // Stream completed
+ source.close();
+ this.hideLoadingIndicator();
+ return;
}
+
+ try {
+ const data = JSON.parse(event.data);
+ if (data.content) {
+ assistantResponse += data.content;
+ // Update the UI with the accumulated response
+ const assistantElement = this.noteContextChatMessages.querySelector('.assistant-message:last-child .message-content');
+ if (assistantElement) {
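+ // formatMarkdown (below) returns HTML markup, so assign via innerHTML rather than textContent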
+ assistantElement.innerHTML = this.formatMarkdown(assistantResponse);
+ } else {
+ this.addMessageToChat('assistant', assistantResponse);
+ }
+ // Scroll to the bottom
+ this.chatContainer.scrollTop = this.chatContainer.scrollHeight;
+ }
+ } catch (e) {
+ console.error('Error parsing SSE message:', e);
+ }
+ };
+
+ source.onerror = () => {
+ source.close();
+ this.hideLoadingIndicator();
+ toastService.showError('Error connecting to the LLM service. Please try again.');
+ };
+
+ // Send the actual message
+ const response = await server.post('llm/messages', {
+ sessionId: this.sessionId,
+ content,
+ contextNoteId: this.currentNoteId,
+ useAdvancedContext
});
- // Get the assistant's message (last one)
- if (response?.messages?.length) {
- const messages = response.messages;
- const lastMessage = messages[messages.length - 1];
-
- if (lastMessage && lastMessage.role === 'assistant') {
- this.addMessageToChat('assistant', lastMessage.content);
- }
- }
-
- // Display sources if available
- if (response?.sources?.length) {
+ // Handle sources if returned in non-streaming response
+ if (response && response.sources && response.sources.length > 0) {
this.showSources(response.sources);
- } else {
- this.hideSources();
}
-
} catch (error) {
- console.error('Failed to send message:', error);
- toastService.showError('Failed to send message to AI');
- } finally {
this.hideLoadingIndicator();
+ toastService.showError('Error sending message: ' + (error as Error).message);
}
}
@@ -243,4 +279,17 @@ export default class LlmChatPanel extends BasicWidget {
}
});
}
+
+ /**
+ * Format markdown content for display
+ */
+ private formatMarkdown(content: string): string {
+ // Simple markdown formatting - could be replaced with a proper markdown library
+ return content
+ .replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')
+ .replace(/\*(.*?)\*/g, '<em>$1</em>')
+ .replace(/`(.*?)`/g, '<code>$1</code>')
+ .replace(/\n/g, '<br>')
+ .replace(/```(.*?)```/gs, '<pre>$1</pre>');
+ }
}
diff --git a/src/public/translations/en/translation.json b/src/public/translations/en/translation.json
index 25ab6263d..71fe62a08 100644
--- a/src/public/translations/en/translation.json
+++ b/src/public/translations/en/translation.json
@@ -1755,5 +1755,11 @@
"content_language": {
"title": "Content languages",
"description": "Select one or more languages that should appear in the language selection in the Basic Properties section of a read-only or editable text note. This will allow features such as spell-checking or right-to-left support."
+ },
+ "ai": {
+ "sources": "Sources",
+ "enter_message": "Enter your message...",
+ "use_advanced_context": "Use Advanced Context",
+ "advanced_context_helps": "Helps with large knowledge bases and limited context windows"
}
}
diff --git a/src/routes/api/llm.ts b/src/routes/api/llm.ts
index c987a84b0..ab1a43127 100644
--- a/src/routes/api/llm.ts
+++ b/src/routes/api/llm.ts
@@ -9,6 +9,8 @@ import providerManager from "../../services/llm/embeddings/providers.js";
import type { Message, ChatCompletionOptions } from "../../services/llm/ai_interface.js";
// Import this way to prevent immediate instantiation
import * as aiServiceManagerModule from "../../services/llm/ai_service_manager.js";
+import triliumContextService from "../../services/llm/trilium_context_service.js";
+import sql from "../../services/sql.js";
// Define basic interfaces
interface ChatMessage {
@@ -290,132 +292,126 @@ async function deleteSession(req: Request, res: Response) {
}
/**
- * Find relevant notes using vector embeddings
+ * Find relevant notes based on search query
*/
-async function findRelevantNotes(query: string, contextNoteId: string | null = null, limit = 5): Promise<NoteSource[]> {
+async function findRelevantNotes(content: string, contextNoteId: string | null = null, limit = 5): Promise<NoteSource[]> {
try {
- // Only proceed if database is initialized
+ // If the database is not initialized, we can't search for embeddings yet
if (!isDatabaseInitialized()) {
- log.info('Database not initialized, skipping vector search');
- return [{
- noteId: "root",
- title: "Database not initialized yet",
- content: "Please wait for database initialization to complete."
- }];
+ return [];
}
- // Get the default embedding provider
- let providerId;
- try {
- // @ts-ignore - embeddingsDefaultProvider exists but might not be in the TypeScript definitions
- providerId = await options.getOption('embeddingsDefaultProvider') || 'openai';
- } catch (error) {
- log.info('Could not get default embedding provider, using mock data');
- return [{
- noteId: "root",
- title: "Embeddings not configured",
- content: "Embedding provider not available"
- }];
+ // Check if embeddings are available
+ const enabledProviders = await providerManager.getEnabledEmbeddingProviders();
+ if (enabledProviders.length === 0) {
+ log.info("No embedding providers available, can't find relevant notes");
+ return [];
}
- const provider = providerManager.getEmbeddingProvider(providerId);
-
- if (!provider) {
- log.info(`Embedding provider ${providerId} not found, using mock data`);
- return [{
- noteId: "root",
- title: "Embeddings not available",
- content: "No embedding provider available"
- }];
+ // If content is too short, don't bother
+ if (content.length < 3) {
+ return [];
}
- // Generate embedding for the query
- const embedding = await provider.generateEmbeddings(query);
+ // Get the embedding for the query
+ const provider = enabledProviders[0];
+ const embedding = await provider.generateEmbeddings(content);
- // Find similar notes
- const modelId = 'default'; // Use default model for the provider
- const similarNotes = await vectorStore.findSimilarNotes(
- embedding, providerId, modelId, limit, 0.6 // Lower threshold to find more results
- );
-
- // If a context note was provided, check if we should include its children
+ let results;
if (contextNoteId) {
- const contextNote = becca.getNote(contextNoteId);
- if (contextNote) {
- const childNotes = contextNote.getChildNotes();
- if (childNotes.length > 0) {
- // Add relevant children that weren't already included
- const childIds = new Set(childNotes.map(note => note.noteId));
- const existingIds = new Set(similarNotes.map(note => note.noteId));
+ // For branch context, get notes specifically from that branch
- // Find children that aren't already in the similar notes
- const missingChildIds = Array.from(childIds).filter(id => !existingIds.has(id));
+ // TODO: This is a simplified implementation - we need to
+ // properly get all notes in the subtree starting from contextNoteId
- // Add up to 3 children that weren't already included
- for (const noteId of missingChildIds.slice(0, 3)) {
- similarNotes.push({
+ // For now, just get direct children of the context note
+ const contextNote = becca.notes[contextNoteId];
+ if (!contextNote) {
+ return [];
+ }
+
+ const childBranches = await sql.getRows(`
+ SELECT branches.* FROM branches
+ WHERE branches.parentNoteId = ?
+ AND branches.isDeleted = 0
+ `, [contextNoteId]);
+
+ const childNoteIds = childBranches.map((branch: any) => branch.noteId);
+
+ // Include the context note itself
+ childNoteIds.push(contextNoteId);
+
+ // Find similar notes in this context
+ results = [];
+
+ for (const noteId of childNoteIds) {
+ const noteEmbedding = await vectorStore.getEmbeddingForNote(
+ noteId,
+ provider.name,
+ provider.getConfig().model
+ );
+
+ if (noteEmbedding) {
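+ // Compare the query embedding against this note's embedding; only matches above the 0.65 threshold are kept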
+ const similarity = vectorStore.cosineSimilarity(
+ embedding,
+ noteEmbedding.embedding
+ );
+
+ if (similarity > 0.65) {
+ results.push({
noteId,
- similarity: 0.75 // Fixed similarity score for context children
+ similarity
});
}
}
}
+
+ // Sort by similarity
+ results.sort((a, b) => b.similarity - a.similarity);
+ results = results.slice(0, limit);
+ } else {
+ // General search across all notes
+ results = await vectorStore.findSimilarNotes(
+ embedding,
+ provider.name,
+ provider.getConfig().model,
+ limit
+ );
}
- // Get note content for context
- return await Promise.all(similarNotes.map(async ({ noteId, similarity }) => {
- const note = becca.getNote(noteId);
- if (!note) {
- return {
- noteId,
- title: "Unknown Note",
- similarity
- };
+ // Format the results
+ const sources: NoteSource[] = [];
+
+ for (const result of results) {
+ const note = becca.notes[result.noteId];
+ if (!note) continue;
+
+ let noteContent: string | undefined = undefined;
+ if (note.type === 'text') {
+ const content = note.getContent();
+ // Handle both string and Buffer types
+ noteContent = typeof content === 'string' ? content :
+ content instanceof Buffer ? content.toString('utf8') : undefined;
}
- // Get note content
- let content = '';
- try {
- // @ts-ignore - Content can be string or Buffer
- const noteContent = await note.getContent();
- content = typeof noteContent === 'string' ? noteContent : noteContent.toString('utf8');
-
- // Truncate content if it's too long (for performance)
- if (content.length > 2000) {
- content = content.substring(0, 2000) + "...";
- }
- } catch (e) {
- log.error(`Error getting content for note ${noteId}: ${e}`);
- }
-
- // Get a branch ID for navigation
- let branchId;
- try {
- const branches = note.getBranches();
- if (branches.length > 0) {
- branchId = branches[0].branchId;
- }
- } catch (e) {
- log.error(`Error getting branch for note ${noteId}: ${e}`);
- }
-
- return {
- noteId,
+ sources.push({
+ noteId: result.noteId,
title: note.title,
- content,
- similarity,
- branchId
- };
- }));
- } catch (error) {
- log.error(`Error finding relevant notes: ${error}`);
- // Return empty array on error
+ content: noteContent,
+ similarity: result.similarity,
+ branchId: note.getBranches()[0]?.branchId
+ });
+ }
+
+ return sources;
+ } catch (error: any) {
+ log.error(`Error finding relevant notes: ${error.message}`);
return [];
}
}
/**
- * Build a context string from relevant notes
+ * Build context from notes
*/
function buildContextFromNotes(sources: NoteSource[], query: string): string {
console.log("Building context from notes with query:", query);
@@ -449,265 +445,237 @@ Now, based on the above notes, please answer: ${query}`;
}
/**
- * Send a message to an LLM chat session and get a response
+ * Send a message to the AI
*/
async function sendMessage(req: Request, res: Response) {
try {
- const { sessionId, content, temperature, maxTokens, provider, model } = req.body;
-
- console.log("Received message request:", {
- sessionId,
- contentLength: content ? content.length : 0,
- contentPreview: content ? content.substring(0, 50) + (content.length > 50 ? '...' : '') : 'undefined',
- temperature,
- maxTokens,
- provider,
- model
- });
-
- if (!sessionId) {
- throw new Error('Session ID is required');
- }
+ // Extract the content from the request body
+ const { content, sessionId, useAdvancedContext = false } = req.body || {};
+ // Validate the content
if (!content || typeof content !== 'string' || content.trim().length === 0) {
throw new Error('Content cannot be empty');
}
- // Check if streaming is requested
- const wantsStream = (req.headers as any)['accept']?.includes('text/event-stream');
+ // Get or create the session
+ let session: ChatSession;
- // If client wants streaming, set up SSE response
- if (wantsStream) {
- res.setHeader('Content-Type', 'text/event-stream');
- res.setHeader('Cache-Control', 'no-cache');
- res.setHeader('Connection', 'keep-alive');
-
- // Get chat session
- let session = sessions.get(sessionId);
- if (!session) {
- const newSession = await createSession(req, res);
- if (!newSession) {
- throw new Error('Failed to create session');
- }
- // Add required properties to match ChatSession interface
- session = {
- ...newSession,
- messages: [],
- lastActive: new Date(),
- metadata: {}
- };
- sessions.set(sessionId, session);
+ if (sessionId && sessions.has(sessionId)) {
+ session = sessions.get(sessionId)!;
+ session.lastActive = new Date();
+ } else {
+ const result = await createSession(req, res);
+ if (!result?.id) {
+ throw new Error('Failed to create a new session');
}
+ session = sessions.get(result.id)!;
+ }
- // Add user message to session
- const userMessage: ChatMessage = {
- role: 'user',
- content: content,
- timestamp: new Date()
+ // Check if AI services are available
+ if (!safelyUseAIManager()) {
+ throw new Error('AI services are not available');
+ }
+
+ // Get the AI service manager
+ const aiServiceManager = aiServiceManagerModule.default.getInstance();
+ // Get the default service - just use the first available one
+ const availableProviders = aiServiceManager.getAvailableProviders();
+ let service = null;
+
+ if (availableProviders.length > 0) {
+ // Use the first available provider
+ const providerName = availableProviders[0];
+ // We know the manager has a 'services' property from our code inspection,
+ // but TypeScript doesn't know that from the interface.
+ // This is a workaround to access it
+ service = (aiServiceManager as any).services[providerName];
+ }
+
+ if (!service) {
+ throw new Error('No AI service is available');
+ }
+
+ // Create user message
+ const userMessage: Message = {
+ role: 'user',
+ content
+ };
+
+ // Add message to session
+ session.messages.push({
+ role: 'user',
+ content,
+ timestamp: new Date()
+ });
+
+ // Log a preview of the message
+ log.info(`Processing LLM message: "${content.substring(0, 50)}${content.length > 50 ? '...' : ''}"`);
+
+ // Information to return to the client
+ let aiResponse = '';
+ let sourceNotes: NoteSource[] = [];
+
+ // If Advanced Context is enabled, we use the improved method
+ if (useAdvancedContext) {
+ // Use the Trilium-specific approach
+ const contextNoteId = session.noteContext || null;
+ const results = await triliumContextService.processQuery(content, service, contextNoteId);
+
+ // Get the generated context
+ const context = results.context;
+ sourceNotes = results.notes;
+
+ // Add system message with the context
+ const contextMessage: Message = {
+ role: 'system',
+ content: context
};
- console.log("Created user message:", {
- role: userMessage.role,
- contentLength: userMessage.content?.length || 0,
- contentPreview: userMessage.content?.substring(0, 50) + (userMessage.content?.length > 50 ? '...' : '') || 'undefined'
- });
- session.messages.push(userMessage);
- // Get context for query
- const sources = await findRelevantNotes(content, session.noteContext || null);
-
- // Format messages for AI with proper type casting
+ // Format all messages for the AI
const aiMessages: Message[] = [
- { role: 'system', content: 'You are a helpful assistant for Trilium Notes. When providing answers, use only the context provided in the notes. If the information is not in the notes, say so.' },
- { role: 'user', content: buildContextFromNotes(sources, content) }
+ contextMessage,
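+ // Only the last 10 session messages are included, to keep the prompt size bounded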
+ ...session.messages.slice(-10).map(msg => ({
+ role: msg.role,
+ content: msg.content
+ }))
];
- // Ensure we're not sending empty content
- console.log("Final message content length:", aiMessages[1].content.length);
- console.log("Final message content preview:", aiMessages[1].content.substring(0, 100));
-
- try {
- // Send initial SSE message with session info
- const sourcesForResponse = sources.map(({ noteId, title, similarity, branchId }) => ({
- noteId,
- title,
- similarity: similarity ? Math.round(similarity * 100) / 100 : undefined,
- branchId
- }));
-
- res.write(`data: ${JSON.stringify({
- type: 'init',
- session: {
- id: sessionId,
- messages: session.messages.slice(0, -1), // Don't include the new message yet
- sources: sourcesForResponse
- }
- })}\n\n`);
-
- // Get AI response with streaming enabled
- const aiResponse = await aiServiceManagerModule.default.generateChatCompletion(aiMessages, {
- temperature,
- maxTokens,
- model: provider ? `${provider}:${model}` : model,
- stream: true
- });
-
- if (aiResponse.stream) {
- // Create an empty assistant message
- const assistantMessage: ChatMessage = {
- role: 'assistant',
- content: '',
- timestamp: new Date()
- };
- session.messages.push(assistantMessage);
-
- // Stream the response chunks
- await aiResponse.stream(async (chunk) => {
- if (chunk.text) {
- // Update the message content
- assistantMessage.content += chunk.text;
-
- // Send chunk to client
- res.write(`data: ${JSON.stringify({
- type: 'chunk',
- text: chunk.text,
- done: chunk.done
- })}\n\n`);
- }
-
- if (chunk.done) {
- // Send final message with complete response
- res.write(`data: ${JSON.stringify({
- type: 'done',
- session: {
- id: sessionId,
- messages: session.messages,
- sources: sourcesForResponse
- }
- })}\n\n`);
-
- res.end();
- }
- });
-
- return; // Early return for streaming
- } else {
- // Fallback for non-streaming response
- const assistantMessage: ChatMessage = {
- role: 'assistant',
- content: aiResponse.text,
- timestamp: new Date()
- };
- session.messages.push(assistantMessage);
-
- // Send complete response
- res.write(`data: ${JSON.stringify({
- type: 'done',
- session: {
- id: sessionId,
- messages: session.messages,
- sources: sourcesForResponse
- }
- })}\n\n`);
-
- res.end();
- return;
- }
- } catch (error: any) {
- // Send error in streaming format
- res.write(`data: ${JSON.stringify({
- type: 'error',
- error: `AI service error: ${error.message}`
- })}\n\n`);
-
- res.end();
- return;
- }
- }
-
- // Non-streaming API continues with normal JSON response...
-
- // Get chat session
- let session = sessions.get(sessionId);
- if (!session) {
- const newSession = await createSession(req, res);
- if (!newSession) {
- throw new Error('Failed to create session');
- }
- // Add required properties to match ChatSession interface
- session = {
- ...newSession,
- messages: [],
- lastActive: new Date(),
- metadata: {}
+ // Configure chat options from session metadata
+ const chatOptions: ChatCompletionOptions = {
+ temperature: session.metadata.temperature || 0.7,
+ maxTokens: session.metadata.maxTokens,
+ model: session.metadata.model
+ // 'provider' property has been removed as it's not in the ChatCompletionOptions type
};
- sessions.set(sessionId, session);
+
+ // Get streaming response if requested
+ const acceptHeader = req.get('Accept');
+ if (acceptHeader && acceptHeader.includes('text/event-stream')) {
+ res.setHeader('Content-Type', 'text/event-stream');
+ res.setHeader('Cache-Control', 'no-cache');
+ res.setHeader('Connection', 'keep-alive');
+
+ let messageContent = '';
+
+ // Stream the response
+ await service.sendChatCompletion(
+ aiMessages,
+ chatOptions,
+ (chunk: string) => {
+ messageContent += chunk;
+ res.write(`data: ${JSON.stringify({ content: chunk })}\n\n`);
+ }
+ );
+
+ // Close the stream
+ res.write('data: [DONE]\n\n');
+ res.end();
+
+ // Store the full response
+ aiResponse = messageContent;
+ } else {
+ // Non-streaming approach
+ aiResponse = await service.sendChatCompletion(aiMessages, chatOptions);
+ }
+ } else {
+ // Original approach - find relevant notes through direct embedding comparison
+ const relevantNotes = await findRelevantNotes(
+ content,
+ session.noteContext || null,
+ 5
+ );
+
+ sourceNotes = relevantNotes;
+
+ // Build context from relevant notes
+ const context = buildContextFromNotes(relevantNotes, content);
+
+ // Add system message with the context
+ const contextMessage: Message = {
+ role: 'system',
+ content: context
+ };
+
+ // Format all messages for the AI
+ const aiMessages: Message[] = [
+ contextMessage,
+ ...session.messages.slice(-10).map(msg => ({
+ role: msg.role,
+ content: msg.content
+ }))
+ ];
+
+ // Configure chat options from session metadata
+ const chatOptions: ChatCompletionOptions = {
+ temperature: session.metadata.temperature || 0.7,
+ maxTokens: session.metadata.maxTokens,
+ model: session.metadata.model
+ // 'provider' property has been removed as it's not in the ChatCompletionOptions type
+ };
+
+ // Get streaming response if requested
+ const acceptHeader = req.get('Accept');
+ if (acceptHeader && acceptHeader.includes('text/event-stream')) {
+ res.setHeader('Content-Type', 'text/event-stream');
+ res.setHeader('Cache-Control', 'no-cache');
+ res.setHeader('Connection', 'keep-alive');
+
+ let messageContent = '';
+
+ // Stream the response
+ await service.sendChatCompletion(
+ aiMessages,
+ chatOptions,
+ (chunk: string) => {
+ messageContent += chunk;
+ res.write(`data: ${JSON.stringify({ content: chunk })}\n\n`);
+ }
+ );
+
+ // Close the stream
+ res.write('data: [DONE]\n\n');
+ res.end();
+
+ // Store the full response
+ aiResponse = messageContent;
+ } else {
+ // Non-streaming approach
+ aiResponse = await service.sendChatCompletion(aiMessages, chatOptions);
+ }
}
- // Add user message to session
- const userMessage: ChatMessage = {
- role: 'user',
- content: content,
- timestamp: new Date()
- };
- console.log("Created user message:", {
- role: userMessage.role,
- contentLength: userMessage.content?.length || 0,
- contentPreview: userMessage.content?.substring(0, 50) + (userMessage.content?.length > 50 ? '...' : '') || 'undefined'
- });
- session.messages.push(userMessage);
-
- // Get context for query
- const sources = await findRelevantNotes(content, session.noteContext || null);
-
- // Format messages for AI with proper type casting
- const aiMessages: Message[] = [
- { role: 'system', content: 'You are a helpful assistant for Trilium Notes. When providing answers, use only the context provided in the notes. If the information is not in the notes, say so.' },
- { role: 'user', content: buildContextFromNotes(sources, content) }
- ];
-
- // Ensure we're not sending empty content
- console.log("Final message content length:", aiMessages[1].content.length);
- console.log("Final message content preview:", aiMessages[1].content.substring(0, 100));
-
- try {
- // Get AI response using the safe accessor methods
- const aiResponse = await aiServiceManagerModule.default.generateChatCompletion(aiMessages, {
- temperature,
- maxTokens,
- model: provider ? `${provider}:${model}` : model,
- stream: false
+ // For non-streaming requests, store the assistant's message and return the response as JSON
+ const acceptHeader = req.get('Accept');
+ if (!acceptHeader || !acceptHeader.includes('text/event-stream')) {
+ // Store the assistant's response in the session
+ session.messages.push({
+ role: 'assistant',
+ content: aiResponse,
+ timestamp: new Date()
});
- // Add assistant message to session
- const assistantMessage: ChatMessage = {
- role: 'assistant',
- content: aiResponse.text,
- timestamp: new Date()
- };
- session.messages.push(assistantMessage);
-
- // Format sources for the response (without content to reduce payload size)
- const sourcesForResponse = sources.map(({ noteId, title, similarity, branchId }) => ({
- noteId,
- title,
- similarity: similarity ? Math.round(similarity * 100) / 100 : undefined,
- branchId
- }));
-
+ // Return the response
return {
- id: sessionId,
- messages: session.messages,
- sources: sourcesForResponse,
- provider: aiResponse.provider,
- model: aiResponse.model
+ content: aiResponse,
+ sources: sourceNotes.map(note => ({
+ noteId: note.noteId,
+ title: note.title,
+ similarity: note.similarity,
+ branchId: note.branchId
+ }))
};
- } catch (error: any) {
- log.error(`AI service error: ${error.message}`);
- throw new Error(`AI service error: ${error.message}`);
+ } else {
+ // For streaming responses, we've already sent the data
+ // But we still need to add the message to the session
+ session.messages.push({
+ role: 'assistant',
+ content: aiResponse,
+ timestamp: new Date()
+ });
}
} catch (error: any) {
- log.error(`Error sending message: ${error.message}`);
- throw error;
+ log.error(`Error sending message to LLM: ${error.message}`);
+ throw new Error(`Failed to send message: ${error.message}`);
}
}
diff --git a/src/services/llm/trilium_context_service.ts b/src/services/llm/trilium_context_service.ts
new file mode 100644
index 000000000..8ff4fb2d6
--- /dev/null
+++ b/src/services/llm/trilium_context_service.ts
@@ -0,0 +1,410 @@
+import becca from "../../becca/becca.js";
+import vectorStore from "./embeddings/vector_store.js";
+import providerManager from "./embeddings/providers.js";
+import options from "../options.js";
+import log from "../log.js";
+import type { Message } from "./ai_interface.js";
+import { cosineSimilarity } from "./embeddings/vector_store.js";
+
+/**
+ * TriliumContextService provides intelligent context management for working with large knowledge bases
+ * through LLMs with limited context windows, such as local models served via Ollama.
+ *
+ * It uses a "meta-prompting" approach: a first LLM call determines what information is needed
+ * to answer the query, then only the relevant context is loaded before the final response
+ * is generated.
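+ *
+ * Illustrative flow (hypothetical query): processQuery("How do I configure sync?") would first ask
+ * the LLM for focused search queries (e.g. ["sync configuration", "sync server setup"]), embed each
+ * one, gather the most similar notes, and return { context, notes, queries } for the final completion.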
+ */
+class TriliumContextService {
+ private initialized = false;
+ private initPromise: Promise<void> | null = null;
+ private provider: any = null;
+
+ // Cache for recently used context to avoid repeated embedding lookups
+ private recentQueriesCache = new Map<string, { timestamp: number, relevantNotes: any[] }>();
+
+ // Configuration
+ private cacheExpiryMs = 5 * 60 * 1000; // 5 minutes
+ private metaPrompt = `You are an AI assistant that decides what information needs to be retrieved from a knowledge base to answer the user's question.
+Given the user's question, generate 3-5 specific search queries that would help find relevant information.
+Each query should be focused on a different aspect of the question.
+Format your answer as a JSON array of strings, with each string being a search query.
+Example: ["exact topic mentioned", "related concept 1", "related concept 2"]`;
+
+ constructor() {
+ this.setupCacheCleanup();
+ }
+
+ /**
+ * Initialize the service
+ */
+ async initialize() {
+ if (this.initialized) return;
+
+ // Use a promise to prevent multiple simultaneous initializations
+ if (this.initPromise) return this.initPromise;
+
+ this.initPromise = (async () => {
+ try {
+ const providerId = await options.getOption('embeddingsDefaultProvider') || 'ollama';
+ this.provider = providerManager.getEmbeddingProvider(providerId);
+
+ if (!this.provider) {
+ throw new Error(`Embedding provider ${providerId} not found`);
+ }
+
+ this.initialized = true;
+ log.info(`Trilium context service initialized with provider: ${providerId}`);
+ } catch (error: unknown) {
+ const errorMessage = error instanceof Error ? error.message : String(error);
+ log.error(`Failed to initialize Trilium context service: ${errorMessage}`);
+ throw error;
+ } finally {
+ this.initPromise = null;
+ }
+ })();
+
+ return this.initPromise;
+ }
+
+ /**
+ * Set up periodic cache cleanup
+ */
+ private setupCacheCleanup() {
+ setInterval(() => {
+ const now = Date.now();
+ for (const [key, data] of this.recentQueriesCache.entries()) {
+ if (now - data.timestamp > this.cacheExpiryMs) {
+ this.recentQueriesCache.delete(key);
+ }
+ }
+ }, 60000); // Run cleanup every minute
+ }
+
+ /**
+ * Generate search queries to find relevant information for the user question
+ * @param userQuestion - The user's question
+ * @param llmService - The LLM service to use for generating queries
+ * @returns Array of search queries
+ */
+ async generateSearchQueries(userQuestion: string, llmService: any): Promise<string[]> {
+ try {
+ const messages: Message[] = [
+ { role: "system", content: this.metaPrompt },
+ { role: "user", content: userQuestion }
+ ];
+
+ const options = {
+ temperature: 0.3,
+ maxTokens: 300
+ };
+
+ // Get the response from the LLM
+ const response = await llmService.sendTextCompletion(messages, options);
+
+ try {
+ // Parse the JSON response
+ const jsonStr = response.trim().replace(/```json|```/g, '').trim();
+ const queries = JSON.parse(jsonStr);
+
+ if (Array.isArray(queries) && queries.length > 0) {
+ return queries;
+ } else {
+ throw new Error("Invalid response format");
+ }
+ } catch (parseError) {
+ // Fallback: if JSON parsing fails, try to extract queries line by line
+ const lines = response.split('\n')
+ .map((line: string) => line.trim())
+ .filter((line: string) => line.length > 0 && !line.startsWith('```'));
+
+ if (lines.length > 0) {
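+ // Strip leading numbering, bullets, quotes and whitespace from each line (e.g. `1. sync setup` -> `sync setup`)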
+ return lines.map((line: string) => line.replace(/^["'\d\.\-\s]+/, '').trim());
+ }
+
+ // If all else fails, just use the original question
+ return [userQuestion];
+ }
+ } catch (error: unknown) {
+ const errorMessage = error instanceof Error ? error.message : String(error);
+ log.error(`Error generating search queries: ${errorMessage}`);
+ // Fallback to just using the original question
+ return [userQuestion];
+ }
+ }
+
+ /**
+ * Find relevant notes using multiple search queries
+ * @param queries - Array of search queries
+ * @param contextNoteId - Optional note ID to restrict search to a branch
+ * @param limit - Max notes to return
+ * @returns Array of relevant notes
+ */
+ async findRelevantNotesMultiQuery(
+ queries: string[],
+ contextNoteId: string | null = null,
+ limit = 10
+ ): Promise<{noteId: string, title: string, content: string | null, similarity: number, branchId?: string}[]> {
+ if (!this.initialized) {
+ await this.initialize();
+ }
+
+ try {
+ // Cache key combining all queries
+ const cacheKey = JSON.stringify({ queries, contextNoteId, limit });
+
+ // Check if we have a recent cache hit
+ const cached = this.recentQueriesCache.get(cacheKey);
+ if (cached) {
+ return cached.relevantNotes;
+ }
+
+ // Array to store all results with their similarity scores
+ const allResults: {
+ noteId: string,
+ title: string,
+ content: string | null,
+ similarity: number,
+ branchId?: string
+ }[] = [];
+
+ // Set to keep track of note IDs we've seen to avoid duplicates
+ const seenNoteIds = new Set<string>();
+
+ // Process each query
+ for (const query of queries) {
+ // Get embeddings for this query
+ const queryEmbedding = await this.provider.getEmbedding(query);
+
+ // Find notes similar to this query
+ let results;
+ if (contextNoteId) {
+ // Find within a specific context/branch
+ results = await this.findNotesInBranch(
+ queryEmbedding,
+ contextNoteId,
+ Math.min(limit, 5) // Limit per query
+ );
+ } else {
+ // Search all notes
+ results = await vectorStore.findSimilarNotes(
+ queryEmbedding,
+ this.provider.id,
+ this.provider.modelId,
+ Math.min(limit, 5), // Limit per query
+ 0.5 // Lower threshold to get more diverse results
+ );
+ }
+
+ // Process results
+ for (const result of results) {
+ if (!seenNoteIds.has(result.noteId)) {
+ seenNoteIds.add(result.noteId);
+
+ // Get the note from Becca
+ const note = becca.notes[result.noteId];
+ if (!note) continue;
+
+ // Add to our results
+ allResults.push({
+ noteId: result.noteId,
+ title: note.title,
+ content: note.type === 'text' ? note.getContent() as string : null,
+ similarity: result.similarity,
+ branchId: note.getBranches()[0]?.branchId
+ });
+ }
+ }
+ }
+
+ // Sort by similarity and take the top 'limit' results
+ const sortedResults = allResults
+ .sort((a, b) => b.similarity - a.similarity)
+ .slice(0, limit);
+
+ // Cache the results
+ this.recentQueriesCache.set(cacheKey, {
+ timestamp: Date.now(),
+ relevantNotes: sortedResults
+ });
+
+ return sortedResults;
+ } catch (error: unknown) {
+ const errorMessage = error instanceof Error ? error.message : String(error);
+ log.error(`Error finding relevant notes: ${errorMessage}`);
+ return [];
+ }
+ }
+
+ /**
+ * Find notes in a specific branch/context
+ * @param embedding - Query embedding
+ * @param contextNoteId - Note ID to restrict search to
+ * @param limit - Max notes to return
+ * @returns Array of relevant notes
+ */
+ private async findNotesInBranch(
+ embedding: Float32Array,
+ contextNoteId: string,
+ limit = 5
+ ): Promise<{noteId: string, similarity: number}[]> {
+ try {
+ // Get the subtree note IDs
+ const subtreeNoteIds = await this.getSubtreeNoteIds(contextNoteId);
+
+ if (subtreeNoteIds.length === 0) {
+ return [];
+ }
+
+ // Get all embeddings for these notes using vectorStore instead of direct SQL
+ const similarities: {noteId: string, similarity: number}[] = [];
+
+ for (const noteId of subtreeNoteIds) {
+ const noteEmbedding = await vectorStore.getEmbeddingForNote(
+ noteId,
+ this.provider.id,
+ this.provider.modelId
+ );
+
+ if (noteEmbedding) {
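+ // Cosine similarity = dot(a, b) / (|a|·|b|); values near 1 mean the note is semantically close to the query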
+ const similarity = cosineSimilarity(embedding, noteEmbedding.embedding);
+ if (similarity > 0.5) { // Apply similarity threshold
+ similarities.push({
+ noteId,
+ similarity
+ });
+ }
+ }
+ }
+
+ // Sort by similarity and return top results
+ return similarities
+ .sort((a, b) => b.similarity - a.similarity)
+ .slice(0, limit);
+ } catch (error: unknown) {
+ const errorMessage = error instanceof Error ? error.message : String(error);
+ log.error(`Error finding notes in branch: ${errorMessage}`);
+ return [];
+ }
+ }
+
+ /**
+ * Get all note IDs in a subtree (including the root note)
+ * @param rootNoteId - Root note ID
+ * @returns Array of note IDs
+ */
+ private async getSubtreeNoteIds(rootNoteId: string): Promise<string[]> {
+ const note = becca.notes[rootNoteId];
+ if (!note) {
+ return [];
+ }
+
+ // Use becca to walk the note tree instead of direct SQL
+ const noteIds = new Set([rootNoteId]);
+
+ // Helper function to collect all children
+ const collectChildNotes = (noteId: string) => {
+ // Use becca.getNote(noteId).getChildNotes() to get child notes
+ const parentNote = becca.notes[noteId];
+ if (!parentNote) return;
+
+ // Get all branches where this note is the parent
+ for (const branch of Object.values(becca.branches)) {
+ if (branch.parentNoteId === noteId && !branch.isDeleted) {
+ const childNoteId = branch.noteId;
+ if (!noteIds.has(childNoteId)) {
+ noteIds.add(childNoteId);
+ // Recursively collect children of this child
+ collectChildNotes(childNoteId);
+ }
+ }
+ }
+ };
+
+ // Start collecting from the root
+ collectChildNotes(rootNoteId);
+
+ return Array.from(noteIds);
+ }
+
+ /**
+ * Build a context string from relevant notes
+ * @param sources - Array of notes
+ * @param query - Original user query
+ * @returns Formatted context string
+ */
+ buildContextFromNotes(sources: any[], query: string): string {
+ if (!sources || sources.length === 0) {
+ return "";
+ }
+
+ let context = `The following are relevant notes from your knowledge base that may help answer the query: "${query}"\n\n`;
+
+ sources.forEach((source, index) => {
+ context += `--- NOTE ${index + 1}: ${source.title} ---\n`;
+
+ if (source.content) {
+ // Truncate content if it's too long
+ const maxContentLength = 1000;
+ let content = source.content;
+
+ if (content.length > maxContentLength) {
+ content = content.substring(0, maxContentLength) + " [content truncated due to length]";
+ }
+
+ context += `${content}\n`;
+ } else {
+ context += "[This note doesn't contain textual content]\n";
+ }
+
+ context += "\n";
+ });
+
+ context += "--- END OF NOTES ---\n\n";
+ context += "Please use the information above to help answer the query. If the information doesn't contain what you need, just say so and use your general knowledge instead.";
+
+ return context;
+ }
+
+ /**
+ * Process a user query with the Trilium-specific approach:
+ * 1. Generate search queries from the original question
+ * 2. Find relevant notes using those queries
+ * 3. Build a context string from the relevant notes
+ *
+ * @param userQuestion - The user's original question
+ * @param llmService - The LLM service to use
+ * @param contextNoteId - Optional note ID to restrict search to
+ * @returns Object with context and notes
+ */
+ async processQuery(userQuestion: string, llmService: any, contextNoteId: string | null = null) {
+ if (!this.initialized) {
+ await this.initialize();
+ }
+
+ // Step 1: Generate search queries
+ const searchQueries = await this.generateSearchQueries(userQuestion, llmService);
+ log.info(`Generated search queries: ${JSON.stringify(searchQueries)}`);
+
+ // Step 2: Find relevant notes using those queries
+ const relevantNotes = await this.findRelevantNotesMultiQuery(
+ searchQueries,
+ contextNoteId,
+ 8 // Get more notes since we're using multiple queries
+ );
+
+ // Step 3: Build context from the notes
+ const context = this.buildContextFromNotes(relevantNotes, userQuestion);
+
+ return {
+ context,
+ notes: relevantNotes,
+ queries: searchQueries
+ };
+ }
+}
+
+export default new TriliumContextService();