From 5ad730c15335e7331af094c348847987ba75d57c Mon Sep 17 00:00:00 2001 From: perf3ct Date: Mon, 17 Mar 2025 21:36:14 +0000 Subject: [PATCH] openai finally works, respect embedding precedence --- src/services/llm/embeddings/storage.ts | 119 ++++++++++++++++++++----- src/services/llm/index_service.ts | 52 ++++++++++- 2 files changed, 145 insertions(+), 26 deletions(-) diff --git a/src/services/llm/embeddings/storage.ts b/src/services/llm/embeddings/storage.ts index 7192e38d6..db939990c 100644 --- a/src/services/llm/embeddings/storage.ts +++ b/src/services/llm/embeddings/storage.ts @@ -150,18 +150,23 @@ export async function findSimilarNotes( providerId: string; modelId: string; count: number; + dimension: number; } - // Get all available embedding providers and models + // Get all available embedding providers and models with dimensions const availableEmbeddings = await sql.getRows(` - SELECT DISTINCT providerId, modelId, COUNT(*) as count + SELECT DISTINCT providerId, modelId, COUNT(*) as count, dimension FROM note_embeddings GROUP BY providerId, modelId ORDER BY count DESC` ) as EmbeddingMetadata[]; if (availableEmbeddings.length > 0) { - log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings)}`); + log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings.map(e => ({ + providerId: e.providerId, + modelId: e.modelId, + count: e.count + })))}`); // Import the AIServiceManager to get provider precedence const { default: aiManager } = await import('../ai_service_manager.js'); @@ -210,24 +215,81 @@ export async function findSimilarNotes( const providerEmbeddings = availableEmbeddings.filter(e => e.providerId === provider); if (providerEmbeddings.length > 0) { - // Use the model with the most embeddings - const bestModel = providerEmbeddings.sort((a, b) => b.count - a.count)[0]; + // Find models that match the current embedding's dimensions + const dimensionMatchingModels = providerEmbeddings.filter(e => e.dimension === embedding.length); - log.info(`Trying fallback provider: ${provider}, model: ${bestModel.modelId}`); + // If we have models with matching dimensions, use the one with most embeddings + if (dimensionMatchingModels.length > 0) { + const bestModel = dimensionMatchingModels.sort((a, b) => b.count - a.count)[0]; + log.info(`Found fallback provider with matching dimensions (${embedding.length}): ${provider}, model: ${bestModel.modelId}`); - // Recursive call with the new provider/model, but disable further fallbacks - return findSimilarNotes( - embedding, - provider, - bestModel.modelId, - limit, - threshold, - false // Prevent infinite recursion - ); + // Recursive call with the new provider/model, but disable further fallbacks + return findSimilarNotes( + embedding, + provider, + bestModel.modelId, + limit, + threshold, + false // Prevent infinite recursion + ); + } else { + // We need to regenerate embeddings with the new provider + log.info(`No models with matching dimensions found for ${provider}. Available models: ${JSON.stringify( + providerEmbeddings.map(e => ({ model: e.modelId, dimension: e.dimension })) + )}`); + + try { + // Import provider manager to get a provider instance + const { default: providerManager } = await import('./providers.js'); + const providerInstance = providerManager.getEmbeddingProvider(provider); + + if (providerInstance) { + // Use the model with the most embeddings + const bestModel = providerEmbeddings.sort((a, b) => b.count - a.count)[0]; + // Configure the model by setting it in the config + try { + // Access the config safely through the getConfig method + const config = providerInstance.getConfig(); + config.model = bestModel.modelId; + + log.info(`Trying to convert query to ${provider}/${bestModel.modelId} embedding format (dimension: ${bestModel.dimension})`); + + // Get the original query from the embedding cache if possible, or use a placeholder + // This is a hack - ideally we'd pass the query text through the whole chain + const originalQuery = "query"; // This is a placeholder, we'd need the original query text + + // Generate a new embedding with the fallback provider + const newEmbedding = await providerInstance.generateEmbeddings(originalQuery); + + log.info(`Successfully generated new embedding with provider ${provider}/${bestModel.modelId} (dimension: ${newEmbedding.length})`); + + // Now try finding similar notes with the new embedding + return findSimilarNotes( + newEmbedding, + provider, + bestModel.modelId, + limit, + threshold, + false // Prevent infinite recursion + ); + } catch (configErr: any) { + log.error(`Error configuring provider ${provider}: ${configErr.message}`); + } + } + } catch (err: any) { + log.error(`Error converting embedding format: ${err.message}`); + } + } } } - log.info(`No suitable fallback providers found. Available embeddings: ${JSON.stringify(availableEmbeddings)}`); + log.error(`No suitable fallback providers found with compatible dimensions. Current embedding dimension: ${embedding.length}`); + log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings.map(e => ({ + providerId: e.providerId, + modelId: e.modelId, + dimension: e.dimension, + count: e.count + })))}`); } else { log.info(`No embeddings found in the database at all. You need to generate embeddings first.`); } @@ -240,14 +302,27 @@ export async function findSimilarNotes( } // Calculate similarity for each embedding - const similarities = rows.map(row => { + const similarities = []; + for (const row of rows) { const rowData = row as any; const rowEmbedding = bufferToEmbedding(rowData.embedding, rowData.dimension); - return { - noteId: rowData.noteId, - similarity: cosineSimilarity(embedding, rowEmbedding) - }; - }); + + // Check if dimensions match before calculating similarity + if (rowEmbedding.length !== embedding.length) { + log.info(`Skipping embedding ${rowData.embedId} - dimension mismatch: ${rowEmbedding.length} vs ${embedding.length}`); + continue; + } + + try { + const similarity = cosineSimilarity(embedding, rowEmbedding); + similarities.push({ + noteId: rowData.noteId, + similarity + }); + } catch (err: any) { + log.error(`Error calculating similarity for note ${rowData.noteId}: ${err.message}`); + } + } // Filter by threshold and sort by similarity (highest first) const results = similarities diff --git a/src/services/llm/index_service.ts b/src/services/llm/index_service.ts index 64e2b3665..34c055fdd 100644 --- a/src/services/llm/index_service.ts +++ b/src/services/llm/index_service.ts @@ -490,16 +490,61 @@ class IndexService { } try { + // Get all enabled embedding providers const providers = await providerManager.getEnabledEmbeddingProviders(); if (!providers || providers.length === 0) { throw new Error("No embedding providers available"); } - // Use the first enabled provider - const provider = providers[0]; + // Get the embedding provider precedence + const options = (await import('../options.js')).default; + let preferredProviders: string[] = []; + + const embeddingPrecedence = await options.getOption('embeddingProviderPrecedence'); + let provider; + + if (embeddingPrecedence) { + // Parse the precedence string + if (embeddingPrecedence.startsWith('[') && embeddingPrecedence.endsWith(']')) { + preferredProviders = JSON.parse(embeddingPrecedence); + } else if (typeof embeddingPrecedence === 'string') { + if (embeddingPrecedence.includes(',')) { + preferredProviders = embeddingPrecedence.split(',').map(p => p.trim()); + } else { + preferredProviders = [embeddingPrecedence]; + } + } + + // Find first enabled provider by precedence order + for (const providerName of preferredProviders) { + const matchedProvider = providers.find(p => p.name === providerName); + if (matchedProvider) { + provider = matchedProvider; + break; + } + } + + // If no match found, use first available + if (!provider && providers.length > 0) { + provider = providers[0]; + } + } else { + // Default to first available provider + provider = providers[0]; + } + + if (!provider) { + throw new Error("No suitable embedding provider found"); + } + + log.info(`Searching with embedding provider: ${provider.name}, model: ${provider.getConfig().model}`); // Generate embedding for the query const embedding = await provider.generateEmbeddings(query); + log.info(`Generated embedding for query: "${query}" (${embedding.length} dimensions)`); + + // Get Note IDs to search, optionally filtered by branch + let similarNotes = []; // Check if we need to restrict search to a specific branch if (contextNoteId) { @@ -525,7 +570,6 @@ class IndexService { collectNoteIds(contextNoteId); // Get embeddings for all notes in the branch - const similarNotes = []; const config = provider.getConfig(); for (const noteId of branchNoteIds) { @@ -557,7 +601,7 @@ class IndexService { } else { // Search across all notes const config = provider.getConfig(); - const similarNotes = await vectorStore.findSimilarNotes( + similarNotes = await vectorStore.findSimilarNotes( embedding, provider.name, config.model,