diff --git a/src/services/llm/context/services/vector_search_service.ts b/src/services/llm/context/services/vector_search_service.ts
index 0d350fa6c..11e50a833 100644
--- a/src/services/llm/context/services/vector_search_service.ts
+++ b/src/services/llm/context/services/vector_search_service.ts
@@ -376,6 +376,76 @@ export class VectorSearchService {
             return '';
         }
     }
+
+    /**
+     * Find notes that are semantically relevant to multiple queries
+     * Combines results from multiple queries, deduplicates them, and returns the most relevant ones
+     *
+     * @param queries - Array of search queries
+     * @param contextNoteId - Optional note ID to restrict search to a branch
+     * @param options - Search options including result limit and summarization preference
+     * @returns Array of relevant notes with similarity scores, deduplicated and sorted
+     */
+    async findRelevantNotesMultiQuery(
+        queries: string[],
+        contextNoteId: string | null = null,
+        options: VectorSearchOptions = {}
+    ): Promise<NoteSearchResult[]> {
+        if (!queries || queries.length === 0) {
+            log.info('No queries provided to findRelevantNotesMultiQuery');
+            return [];
+        }
+
+        log.info(`VectorSearchService: Finding relevant notes for ${queries.length} queries`);
+        log.info(`Multi-query parameters: contextNoteId=${contextNoteId || 'global'}, queries=${JSON.stringify(queries.map(q => q.substring(0, 20) + '...'))}`);
+
+        try {
+            // Create a Map to deduplicate results across queries
+            const allResults = new Map();
+
+            // For each query, adjust maxResults to avoid getting too many total results
+            const adjustedMaxResults = options.maxResults ?
+                Math.ceil(options.maxResults / queries.length) :
+                Math.ceil(SEARCH_CONSTANTS.VECTOR_SEARCH.DEFAULT_MAX_RESULTS / queries.length);
+
+            // Search for each query and combine results
+            for (const query of queries) {
+                try {
+                    const queryOptions = {
+                        ...options,
+                        maxResults: adjustedMaxResults,
+                        useEnhancedQueries: false // We're already using enhanced queries
+                    };
+
+                    const results = await this.findRelevantNotes(query, contextNoteId, queryOptions);
+
+                    // Merge results, keeping the highest similarity score for duplicates
+                    for (const note of results) {
+                        if (!allResults.has(note.noteId) ||
+                            (allResults.has(note.noteId) && note.similarity > (allResults.get(note.noteId)?.similarity || 0))) {
+                            allResults.set(note.noteId, note);
+                        }
+                    }
+
+                    log.info(`Found ${results.length} results for query: "${query.substring(0, 30)}..."`);
+                } catch (error) {
+                    log.error(`Error searching for query "${query}": ${error}`);
+                }
+            }
+
+            // Convert map to array and sort by similarity
+            const combinedResults = Array.from(allResults.values())
+                .sort((a, b) => b.similarity - a.similarity)
+                .slice(0, options.maxResults || SEARCH_CONSTANTS.VECTOR_SEARCH.DEFAULT_MAX_RESULTS);
+
+            log.info(`VectorSearchService: Found ${combinedResults.length} total deduplicated results across ${queries.length} queries`);
+
+            return combinedResults;
+        } catch (error) {
+            log.error(`Error in findRelevantNotesMultiQuery: ${error}`);
+            return [];
+        }
+    }
 }
 
 // Export a singleton instance
diff --git a/src/services/llm/pipeline/chat_pipeline.ts b/src/services/llm/pipeline/chat_pipeline.ts
index 888511f92..947a562e6 100644
--- a/src/services/llm/pipeline/chat_pipeline.ts
+++ b/src/services/llm/pipeline/chat_pipeline.ts
@@ -202,21 +202,23 @@ export class ChatPipeline {
         const vectorSearchStartTime = Date.now();
         log.info(`========== STAGE 3: VECTOR SEARCH ==========`);
         log.info('Using VectorSearchStage pipeline component to find relevant notes');
+        log.info(`Searching with ${searchQueries.length} queries from decomposition`);
 
+        // Use the vectorSearchStage with multiple queries
         const vectorSearchResult = await this.stages.vectorSearch.execute({
-            query: userQuery,
+            query: userQuery, // Original query as fallback
+            queries: searchQueries, // All decomposed queries
             noteId: input.noteId || 'global',
             options: {
                 maxResults: SEARCH_CONSTANTS.CONTEXT.MAX_SIMILAR_NOTES,
-                useEnhancedQueries: true,
+                useEnhancedQueries: false, // We're already using enhanced queries from decomposition
                 threshold: SEARCH_CONSTANTS.VECTOR_SEARCH.DEFAULT_THRESHOLD,
                 llmService: llmService || undefined
             }
         });
 
         this.updateStageMetrics('vectorSearch', vectorSearchStartTime);
-
-        log.info(`Vector search found ${vectorSearchResult.searchResults.length} relevant notes`);
+        log.info(`Vector search found ${vectorSearchResult.searchResults.length} relevant notes across ${searchQueries.length} queries`);
 
         // Extract context from search results
         log.info(`========== SEMANTIC CONTEXT EXTRACTION ==========`);
diff --git a/src/services/llm/pipeline/stages/vector_search_stage.ts b/src/services/llm/pipeline/stages/vector_search_stage.ts
index 9766d55e3..306b5079d 100644
--- a/src/services/llm/pipeline/stages/vector_search_stage.ts
+++ b/src/services/llm/pipeline/stages/vector_search_stage.ts
@@ -13,6 +13,7 @@ import { SEARCH_CONSTANTS } from '../../constants/search_constants.js';
 
 export interface VectorSearchInput {
     query: string;
+    queries?: string[];
     noteId?: string;
     options?: {
         maxResults?: number;
@@ -42,6 +43,7 @@ export class VectorSearchStage {
     async execute(input: VectorSearchInput): Promise {
         const {
             query,
+            queries = [],
             noteId = 'global',
             options = {}
         } = input;
@@ -53,6 +55,42 @@ export class VectorSearchStage {
             llmService = undefined
         } = options;
 
+        // If queries array is provided, use multi-query search
+        if (queries && queries.length > 0) {
+            log.info(`VectorSearchStage: Searching with ${queries.length} queries`);
+            log.info(`Parameters: noteId=${noteId}, maxResults=${maxResults}, threshold=${threshold}`);
+
+            try {
+                // Use the new multi-query method
+                const searchResults = await vectorSearchService.findRelevantNotesMultiQuery(
+                    queries,
+                    noteId === 'global' ? null : noteId,
+                    {
+                        maxResults,
+                        threshold,
+                        llmService: llmService || null
+                    }
+                );
+
+                log.info(`VectorSearchStage: Found ${searchResults.length} relevant notes from multi-query search`);
+
+                return {
+                    searchResults,
+                    originalQuery: query,
+                    noteId
+                };
+            } catch (error) {
+                log.error(`Error in vector search stage multi-query: ${error}`);
+                // Return empty results on error
+                return {
+                    searchResults: [],
+                    originalQuery: query,
+                    noteId
+                };
+            }
+        }
+
+        // Fallback to single query search
         log.info(`VectorSearchStage: Searching for "${query.substring(0, 50)}..."`);
         log.info(`Parameters: noteId=${noteId}, maxResults=${maxResults}, threshold=${threshold}`);
 
@@ -64,7 +102,7 @@ export class VectorSearchStage {
             {
                 maxResults,
                 threshold,
-                llmService
+                llmService: llmService || null
             }
         );
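
For illustration only, not part of the patch: a minimal sketch of how the new multi-query API could be called directly. It assumes the service file's singleton is its default export and that results expose the noteId and similarity fields used in the diff; the import path, the function name searchWithDecomposedQueries, and the sample queries are hypothetical.

// Hypothetical usage sketch -- not part of the diff above.
import vectorSearchService from './services/llm/context/services/vector_search_service.js';

async function searchWithDecomposedQueries(): Promise<void> {
    // Sub-queries as an earlier decomposition stage might produce them.
    const queries = [
        'how are notes encrypted at rest',
        'where are encryption keys stored'
    ];

    // null contextNoteId = global search; maxResults caps the combined,
    // deduplicated result set (per-query limits are derived internally).
    const results = await vectorSearchService.findRelevantNotesMultiQuery(queries, null, {
        maxResults: 10,
        threshold: 0.65
    });

    for (const note of results) {
        console.log(`${note.noteId} -> similarity ${note.similarity.toFixed(3)}`);
    }
}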