From 697d348286f3e772ead40e12da7d6cf100f85030 Mon Sep 17 00:00:00 2001
From: perf3ct
Date: Sun, 16 Mar 2025 18:08:50 +0000
Subject: [PATCH] set up more reasonable context window and dimension sizes

---
 db/migrations/0229__ai_llm_options.sql     |   3 +-
 src/routes/api/llm.ts                      |  31 ++
 .../llm/embeddings/base_embeddings.ts      | 212 +++++++++++--
 .../llm/embeddings/embeddings_interface.ts |  10 +
 .../llm/embeddings/providers/anthropic.ts  | 184 ++++++++++--
 .../llm/embeddings/providers/ollama.ts     | 239 +++++++++------
 .../llm/embeddings/providers/openai.ts     | 282 +++++++++++++++---
 7 files changed, 787 insertions(+), 174 deletions(-)

diff --git a/db/migrations/0229__ai_llm_options.sql b/db/migrations/0229__ai_llm_options.sql
index 041ccea42..9d08f8bfd 100644
--- a/db/migrations/0229__ai_llm_options.sql
+++ b/db/migrations/0229__ai_llm_options.sql
@@ -26,4 +26,5 @@ INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('aiSystemPr
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingsDefaultProvider', 'openai', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('enableAutomaticIndexing', 'true', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingSimilarityThreshold', '0.65', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
-INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('maxNotesPerLlmQuery', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
\ No newline at end of file
+INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('maxNotesPerLlmQuery', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
+INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingBatchSize', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
\ No newline at end of file
diff --git a/src/routes/api/llm.ts b/src/routes/api/llm.ts
index 5a4deea85..b41809937 100644
--- a/src/routes/api/llm.ts
+++ b/src/routes/api/llm.ts
@@ -41,6 +41,37 @@ export const LLM_CONSTANTS = {
         }
     },
 
+    // Model-specific embedding dimensions for Ollama models
+    OLLAMA_MODEL_DIMENSIONS: {
+        "llama3": 4096,
+        "llama3.1": 4096,
+        "mistral": 4096,
+        "nomic": 768,
+        "mxbai": 1024,
+        "nomic-embed-text": 768,
+        "mxbai-embed-large": 1024,
+        "default": 384
+    },
+
+    // Model-specific context windows for Ollama models
+    OLLAMA_MODEL_CONTEXT_WINDOWS: {
+        "llama3": 8192,
+        "mistral": 8192,
+        "nomic": 32768,
+        "mxbai": 32768,
+        "nomic-embed-text": 32768,
+        "mxbai-embed-large": 32768,
+        "default": 4096
+    },
+
+    // Batch size configuration
+    BATCH_SIZE: {
+        OPENAI: 10,    // OpenAI can handle larger batches efficiently
+        ANTHROPIC: 5,  // More conservative for Anthropic
+        OLLAMA: 1,     // Ollama processes one at a time
+        DEFAULT: 5     // Conservative default
+    },
+
     // Chunking parameters
     CHUNKING: {
         DEFAULT_SIZE: 1500,
diff --git a/src/services/llm/embeddings/base_embeddings.ts b/src/services/llm/embeddings/base_embeddings.ts
index 29e52d1ce..4f1f09b60 100644
--- a/src/services/llm/embeddings/base_embeddings.ts
+++ b/src/services/llm/embeddings/base_embeddings.ts
@@ -1,22 +1,212 @@
 import type { EmbeddingProvider, EmbeddingConfig, NoteEmbeddingContext } from './embeddings_interface.js';
+import log from "../../log.js";
+import { LLM_CONSTANTS } from "../../../routes/api/llm.js";
+import options from "../../options.js";
 
 /**
  * Base class that implements common functionality for embedding providers
  */
 export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
-    abstract name: string;
+    name: string = "base";
 
     protected config: EmbeddingConfig;
+    protected apiKey?: string;
+    protected baseUrl: string;
+    protected modelInfoCache = new Map();
 
     constructor(config: EmbeddingConfig) {
         this.config = config;
+        this.apiKey = config.apiKey;
+        this.baseUrl = config.baseUrl || "";
     }
 
     getConfig(): EmbeddingConfig {
-        return this.config;
+        return { ...this.config };
     }
 
+    getDimension(): number {
+        return this.config.dimension;
+    }
+
+    async initialize(): Promise<void> {
+        // Default implementation does nothing
+        return;
+    }
+
+    /**
+     * Generate embeddings for a single text
+     */
     abstract generateEmbeddings(text: string): Promise<Float32Array>;
-    abstract generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]>;
+
+    /**
+     * Get the appropriate batch size for this provider
+     * Override in provider implementations if needed
+     */
+    protected async getBatchSize(): Promise<number> {
+        // Try to get the user-configured batch size
+        let configuredBatchSize: number | null = null;
+
+        try {
+            const batchSizeStr = await options.getOption('embeddingBatchSize');
+            if (batchSizeStr) {
+                configuredBatchSize = parseInt(batchSizeStr, 10);
+            }
+        } catch (error) {
+            log.error(`Error getting batch size from options: ${error}`);
+        }
+
+        // If user has configured a specific batch size, use that
+        if (configuredBatchSize && !isNaN(configuredBatchSize) && configuredBatchSize > 0) {
+            return configuredBatchSize;
+        }
+
+        // Otherwise use the provider-specific default from constants
+        return this.config.batchSize ||
+            LLM_CONSTANTS.BATCH_SIZE[this.name.toUpperCase() as keyof typeof LLM_CONSTANTS.BATCH_SIZE] ||
+            LLM_CONSTANTS.BATCH_SIZE.DEFAULT;
+    }
+
+    /**
+     * Process a batch of texts with adaptive handling
+     * This method will try to process the batch and reduce batch size if encountering errors
+     */
+    protected async processWithAdaptiveBatch<T>(
+        items: T[],
+        processFn: (batch: T[]) => Promise<any[]>,
+        isBatchSizeError: (error: any) => boolean
+    ): Promise<any[]> {
+        const results: any[] = [];
+        const failures: { index: number, error: string }[] = [];
+        let currentBatchSize = await this.getBatchSize();
+        let lastError: Error | null = null;
+
+        // Process items in batches
+        for (let i = 0; i < items.length;) {
+            const batch = items.slice(i, i + currentBatchSize);
+
+            try {
+                // Process the current batch
+                const batchResults = await processFn(batch);
+                results.push(...batchResults);
+                i += batch.length;
+            }
+            catch (error: any) {
+                lastError = error;
+                const errorMessage = error.message || 'Unknown error';
+
+                // Check if this is a batch size related error
+                if (isBatchSizeError(error) && currentBatchSize > 1) {
+                    // Reduce batch size and retry
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Batch size error detected, reducing batch size from ${currentBatchSize} to ${newBatchSize}: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+                else if (currentBatchSize === 1) {
+                    // If we're already at batch size 1, we can't reduce further, so log the error and skip this item
+                    log.error(`Error processing item at index ${i} with batch size 1: ${errorMessage}`);
+                    failures.push({ index: i, error: errorMessage });
+                    i++; // Move to the next item
+                }
+                else {
+                    // For other errors, retry with a smaller batch size as a precaution
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Error processing batch, reducing batch size from ${currentBatchSize} to ${newBatchSize} as a precaution: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+            }
+        }
+
+        // If all items failed and we have a last error, throw it
+        if (results.length === 0 && failures.length > 0 && lastError) {
+            throw lastError;
+        }
+
+        // If some items failed but others succeeded, log the summary
+        if (failures.length > 0) {
+            console.warn(`Processed ${results.length} items successfully, but ${failures.length} items failed`);
+        }
+
+        return results;
+    }
+
+    /**
+     * Detect if an error is related to batch size limits
+     * Override in provider-specific implementations
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || '';
+        const batchSizeErrorPatterns = [
+            'batch size', 'too many items', 'too many inputs',
+            'input too large', 'payload too large', 'context length',
+            'token limit', 'rate limit', 'request too large'
+        ];
+
+        return batchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
+    /**
+     * Generate embeddings for multiple texts
+     * Default implementation processes texts one by one
+     */
+    async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
+        if (texts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const batchResults = await Promise.all(
+                        batch.map(text => this.generateEmbeddings(text))
+                    );
+                    return batchResults;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Batch embedding error for provider ${this.name}: ${errorMessage}`);
+            throw new Error(`${this.name} batch embedding error: ${errorMessage}`);
+        }
+    }
+
+    /**
+     * Generate embeddings for a note with its context
+     */
+    async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
+        const text = [context.title || "", context.content || ""].filter(Boolean).join(" ");
+        return this.generateEmbeddings(text);
+    }
+
+    /**
+     * Generate embeddings for multiple notes with their contexts
+     */
+    async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
+        if (contexts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                contexts,
+                async (batch) => {
+                    const batchResults = await Promise.all(
+                        batch.map(context => this.generateNoteEmbeddings(context))
+                    );
+                    return batchResults;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Batch note embedding error for provider ${this.name}: ${errorMessage}`);
+            throw new Error(`${this.name} batch note embedding error: ${errorMessage}`);
+        }
+    }
 
     /**
      * Cleans and normalizes text for embeddings by removing excessive whitespace
@@ -157,20 +347,4 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
 
         return result;
     }
-
-    /**
-     * Default implementation that converts note context to text and generates embeddings
-     */
-    async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
-        const text = this.generateNoteContextText(context);
-        return this.generateEmbeddings(text);
-    }
-
-    /**
-     * Default implementation that processes notes in batch
-     */
-    async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
-        const texts = contexts.map(ctx => this.generateNoteContextText(ctx));
-        return this.generateBatchEmbeddings(texts);
-    }
 }
diff --git a/src/services/llm/embeddings/embeddings_interface.ts b/src/services/llm/embeddings/embeddings_interface.ts
index 7a6b11bfb..b5e512b6a 100644
--- a/src/services/llm/embeddings/embeddings_interface.ts
+++ b/src/services/llm/embeddings/embeddings_interface.ts
@@ -36,6 +36,14 @@ export interface NoteEmbeddingContext {
     templateTitles?: string[];
 }
 
+/**
+ * Information about an embedding model's capabilities
+ */
+export interface EmbeddingModelInfo {
+    dimension: number;
+    contextWindow: number;
+}
+
 /**
  * Configuration for how embeddings should be generated
  */
@@ -46,6 +54,8 @@ export interface EmbeddingConfig {
     normalize?: boolean;
     batchSize?: number;
     contextWindowSize?: number;
+    apiKey?: string;
+    baseUrl?: string;
 }
 
 /**
diff --git a/src/services/llm/embeddings/providers/anthropic.ts b/src/services/llm/embeddings/providers/anthropic.ts
index 3c5156f54..0634f3b9c 100644
--- a/src/services/llm/embeddings/providers/anthropic.ts
+++ b/src/services/llm/embeddings/providers/anthropic.ts
@@ -1,25 +1,117 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
-interface AnthropicEmbeddingConfig extends EmbeddingConfig {
-    apiKey: string;
-    baseUrl: string;
-}
+// Anthropic model context window sizes - as of current API version
+const ANTHROPIC_MODEL_CONTEXT_WINDOWS: Record<string, number> = {
+    "claude-3-opus-20240229": 200000,
+    "claude-3-sonnet-20240229": 180000,
+    "claude-3-haiku-20240307": 48000,
+    "claude-2.1": 200000,
+    "claude-2.0": 100000,
+    "claude-instant-1.2": 100000,
+    "default": 100000
+};
 
 /**
- * Anthropic (Claude) embedding provider implementation
+ * Anthropic embedding provider implementation
  */
 export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
     name = "anthropic";
-    private apiKey: string;
-    private baseUrl: string;
 
-    constructor(config: AnthropicEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.apiKey = config.apiKey;
-        this.baseUrl = config.baseUrl;
+    }
+
+    /**
+     * Initialize the provider by detecting model capabilities
+     */
+    async initialize(): Promise<void> {
+        const modelName = this.config.model || "claude-3-haiku-20240307";
+        try {
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`Anthropic model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+        } catch (error: any) {
+            log.error(`Error initializing Anthropic provider: ${error.message}`);
+        }
+    }
+
+    /**
+     * Try to determine Anthropic model capabilities
+     * Note: Anthropic doesn't have a public endpoint for model metadata, so we use a combination
+     * of known values and detection by test embeddings
+     */
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
+        // Anthropic doesn't have a model info endpoint, but we can look up known context sizes
+        // and detect embedding dimensions by making a test request
+
+        try {
+            // Get context window size from our local registry of known models
+            const modelBase = Object.keys(ANTHROPIC_MODEL_CONTEXT_WINDOWS).find(
+                model => modelName.startsWith(model)
+            ) || "default";
+
+            const contextWindow = ANTHROPIC_MODEL_CONTEXT_WINDOWS[modelBase];
+
+            // For embedding dimension, we'll return null and let getModelInfo detect it
+            return {
+                dimension: 0, // Will be detected by test embedding
+                contextWindow
+            };
+        } catch (error) {
+            log.info(`Could not determine capabilities for Anthropic model ${modelName}: ${error}`);
+            return null;
+        }
+    }
+
+    /**
+     * Get model information including embedding dimensions
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to determine model capabilities
+        const capabilities = await this.fetchModelCapabilities(modelName);
+        const contextWindow = capabilities?.contextWindow || LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC;
+
+        // For Anthropic, we need to detect embedding dimension with a test call
+        try {
+            // Detect dimension with a test embedding
+            const testEmbedding = await this.generateEmbeddings("Test");
+            const dimension = testEmbedding.length;
+
+            const modelInfo: EmbeddingModelInfo = {
+                dimension,
+                contextWindow
+            };
+
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            log.info(`Detected Anthropic model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
+            return modelInfo;
+        } catch (error: any) {
+            // If detection fails, use defaults
+            const dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.ANTHROPIC.DEFAULT;
+
+            log.info(`Using default parameters for Anthropic model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            return modelInfo;
+        }
     }
 
     /**
@@ -27,11 +119,23 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
+            // Get model info to check context window
+            const modelName = this.config.model || "claude-3-haiku-20240307";
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Trim text if it might exceed context window (rough character estimate)
+            const charLimit = modelInfo.contextWindow * 4; // Rough estimate: avg 4 chars per token
+            const trimmedText = text.length > charLimit ? text.substring(0, charLimit) : text;
+
             const response = await axios.post(
                 `${this.baseUrl}/embeddings`,
                 {
-                    model: this.config.model || "claude-3-haiku-20240307",
-                    input: text,
+                    model: modelName,
+                    input: trimmedText,
                     encoding_format: "float"
                 },
                 {
@@ -44,8 +148,7 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
             );
 
             if (response.data && response.data.embedding) {
-                const embedding = response.data.embedding;
-                return new Float32Array(embedding);
+                return new Float32Array(response.data.embedding);
             } else {
                 throw new Error("Unexpected response structure from Anthropic API");
             }
@@ -56,23 +159,60 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for Anthropic
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || error?.response?.data?.error?.message || '';
+        const anthropicBatchSizeErrorPatterns = [
+            'batch size', 'too many inputs', 'context length exceeded',
+            'token limit', 'rate limit', 'limit exceeded',
+            'too long', 'request too large', 'content too large'
+        ];
+
+        return anthropicBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
     /**
      * Generate embeddings for multiple texts in a single batch
      *
      * Note: Anthropic doesn't currently support batch embedding, so we process each text individually
+     * but using the adaptive batch processor to handle errors and retries
      */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
-        const results: Float32Array[] = [];
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const results: Float32Array[] = [];
 
-        for (const text of texts) {
-            const embedding = await this.generateEmbeddings(text);
-            results.push(embedding);
+                    // For Anthropic, we have to process one at a time
+                    for (const text of batch) {
+                        // Skip empty texts
+                        if (!text.trim()) {
+                            results.push(new Float32Array(this.config.dimension));
+                            continue;
+                        }
+
+                        const embedding = await this.generateEmbeddings(text);
+                        results.push(embedding);
+                    }
+
+                    return results;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Anthropic batch embedding error: ${errorMessage}`);
+            throw new Error(`Anthropic batch embedding error: ${errorMessage}`);
         }
-
-        return results;
     }
 }
diff --git a/src/services/llm/embeddings/providers/ollama.ts b/src/services/llm/embeddings/providers/ollama.ts
index aea0b5f00..43bdf6d8a 100644
--- a/src/services/llm/embeddings/providers/ollama.ts
+++ b/src/services/llm/embeddings/providers/ollama.ts
@@ -1,30 +1,17 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
-
-interface OllamaEmbeddingConfig extends EmbeddingConfig {
-    baseUrl: string;
-}
-
-// Model-specific embedding dimensions
-interface EmbeddingModelInfo {
-    dimension: number;
-    contextWindow: number;
-}
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
 /**
  * Ollama embedding provider implementation
  */
 export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     name = "ollama";
-    private baseUrl: string;
 
-    // Cache for model dimensions to avoid repeated API calls
-    private modelInfoCache = new Map();
-    constructor(config: OllamaEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.baseUrl = config.baseUrl;
     }
 
     /**
@@ -33,97 +20,148 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     async initialize(): Promise<void> {
         const modelName = this.config.model || "llama3";
         try {
-            await this.getModelInfo(modelName);
-            log.info(`Ollama embedding provider initialized with model ${modelName}`);
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`Ollama model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
         } catch (error: any) {
-            log.error(`Failed to initialize Ollama embedding provider: ${error.message}`);
-            // Still continue with default dimensions
+            log.error(`Error initializing Ollama provider: ${error.message}`);
         }
     }
 
     /**
-     * Get model information including embedding dimensions
+     * Fetch detailed model information from Ollama API
+     * @param modelName The name of the model to fetch information for
      */
-    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
-        // Check cache first
-        if (this.modelInfoCache.has(modelName)) {
-            return this.modelInfoCache.get(modelName)!;
-        }
-
-        // Default dimensions for common embedding models
-        const defaultDimensions: Record<string, number> = {
-            "nomic-embed-text": 768,
-            "mxbai-embed-large": 1024,
-            "llama3": 4096,
-            "all-minilm": 384,
-            "default": 4096
-        };
-
-        // Default context windows
-        const defaultContextWindows: Record<string, number> = {
-            "nomic-embed-text": 8192,
-            "mxbai-embed-large": 8192,
-            "llama3": 8192,
-            "all-minilm": 4096,
-            "default": 4096
-        };
-
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
         try {
-            // Try to detect if this is an embedding model
-            const testResponse = await axios.post(
-                `${this.baseUrl}/api/embeddings`,
-                {
-                    model: modelName,
-                    prompt: "Test"
-                },
+            // First try the /api/show endpoint which has detailed model information
+            const showResponse = await axios.get(
+                `${this.baseUrl}/api/show`,
                 {
+                    params: { name: modelName },
                     headers: { "Content-Type": "application/json" },
                     timeout: 10000
                 }
             );
 
-            let dimension = 0;
-            let contextWindow = 0;
+            if (showResponse.data && showResponse.data.parameters) {
+                const params = showResponse.data.parameters;
+                // Extract context length from parameters (different models might use different parameter names)
+                const contextWindow = params.context_length ||
+                    params.num_ctx ||
+                    params.context_window ||
+                    (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
 
-            if (testResponse.data && Array.isArray(testResponse.data.embedding)) {
-                dimension = testResponse.data.embedding.length;
+                // Some models might provide embedding dimensions
+                const embeddingDimension = params.embedding_length || params.dim || null;
 
-                // Set context window based on model name if we have it
-                const baseModelName = modelName.split(':')[0];
-                contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
+                log.info(`Fetched Ollama model info from API for ${modelName}: context window ${contextWindow}`);
 
-                log.info(`Detected Ollama model ${modelName} with dimension ${dimension}`);
-            } else {
-                throw new Error("Could not detect embedding dimensions");
+                return {
+                    dimension: embeddingDimension || 0, // We'll detect this separately if not provided
+                    contextWindow: contextWindow
+                };
             }
+        } catch (error: any) {
+            log.info(`Could not fetch model info from Ollama show API: ${error.message}. Will try embedding test.`);
+            // We'll fall back to embedding test if this fails
+        }
+
+        return null;
+    }
+
+    /**
+     * Get model information by probing the API
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to fetch model capabilities from API
+        const apiModelInfo = await this.fetchModelCapabilities(modelName);
+        if (apiModelInfo) {
+            // If we have context window but no embedding dimension, we need to detect the dimension
+            if (apiModelInfo.contextWindow && !apiModelInfo.dimension) {
+                try {
+                    // Detect dimension with a test embedding
+                    const dimension = await this.detectEmbeddingDimension(modelName);
+                    apiModelInfo.dimension = dimension;
+                } catch (error) {
+                    // If dimension detection fails, fall back to defaults
+                    const baseModelName = modelName.split(':')[0];
+                    apiModelInfo.dimension = (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>)[baseModelName] ||
+                        (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>).default;
+                }
+            }
+
+            // Cache and return the API-provided info
+            this.modelInfoCache.set(modelName, apiModelInfo);
+            this.config.dimension = apiModelInfo.dimension;
+            return apiModelInfo;
+        }
+
+        // If API info fetch fails, fall back to test embedding
+        try {
+            const dimension = await this.detectEmbeddingDimension(modelName);
+            const baseModelName = modelName.split(':')[0];
+            const contextWindow = (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>)[baseModelName] ||
+                (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
 
             const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
             this.modelInfoCache.set(modelName, modelInfo);
-
-            // Update the provider config dimension
             this.config.dimension = dimension;
 
+            log.info(`Detected Ollama model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
             return modelInfo;
         } catch (error: any) {
             log.error(`Error detecting Ollama model capabilities: ${error.message}`);
 
-            // If detection fails, use defaults based on model name
+            // If all detection fails, use defaults based on model name
             const baseModelName = modelName.split(':')[0];
-            const dimension = defaultDimensions[baseModelName] || defaultDimensions.default;
-            const contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
+            const dimension = (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>)[baseModelName] ||
+                (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>).default;
+            const contextWindow = (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>)[baseModelName] ||
+                (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
 
-            log.info(`Using default dimension ${dimension} for model ${modelName}`);
+            log.info(`Using default parameters for model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
 
             const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
             this.modelInfoCache.set(modelName, modelInfo);
-
-            // Update the provider config dimension
             this.config.dimension = dimension;
 
             return modelInfo;
         }
     }
 
+    /**
+     * Detect embedding dimension by making a test API call
+     */
+    private async detectEmbeddingDimension(modelName: string): Promise<number> {
+        const testResponse = await axios.post(
+            `${this.baseUrl}/api/embeddings`,
+            {
+                model: modelName,
+                prompt: "Test"
+            },
+            {
+                headers: { "Content-Type": "application/json" },
+                timeout: 10000
+            }
+        );
+
+        if (testResponse.data && Array.isArray(testResponse.data.embedding)) {
+            return testResponse.data.embedding.length;
+        } else {
+            throw new Error("Could not detect embedding dimensions");
+        }
+    }
+
     /**
      * Get the current embedding dimension
      */
@@ -136,6 +174,10 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
             const modelName = this.config.model || "llama3";
 
             // Ensure we have model info
@@ -173,29 +215,60 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for Ollama
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || '';
+        const ollamaBatchSizeErrorPatterns = [
+            'context length', 'token limit', 'out of memory',
+            'too large', 'overloaded', 'prompt too long',
+            'too many tokens', 'maximum size'
+        ];
+
+        return ollamaBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
     /**
      * Generate embeddings for multiple texts
      *
     * Note: Ollama API doesn't support batch embedding, so we process them sequentially
+     * but using the adaptive batch processor to handle rate limits and retries
      */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
-        const results: Float32Array[] = [];
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const results: Float32Array[] = [];
 
-        for (const text of texts) {
-            try {
-                const embedding = await this.generateEmbeddings(text);
-                results.push(embedding);
-            } catch (error: any) {
-                const errorMessage = error.response?.data?.error?.message || error.message || "Unknown error";
-                log.error(`Ollama batch embedding error: ${errorMessage}`);
-                throw new Error(`Ollama batch embedding error: ${errorMessage}`);
-            }
+                    // For Ollama, we have to process one at a time
+                    for (const text of batch) {
+                        // Skip empty texts
+                        if (!text.trim()) {
+                            results.push(new Float32Array(this.config.dimension));
+                            continue;
+                        }
+
+                        const embedding = await this.generateEmbeddings(text);
+                        results.push(embedding);
+                    }
+
+                    return results;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Ollama batch embedding error: ${errorMessage}`);
+            throw new Error(`Ollama batch embedding error: ${errorMessage}`);
         }
-
-        return results;
     }
 }
diff --git a/src/services/llm/embeddings/providers/openai.ts b/src/services/llm/embeddings/providers/openai.ts
index 0ad8ca51b..116fd5c9b 100644
--- a/src/services/llm/embeddings/providers/openai.ts
+++ b/src/services/llm/embeddings/providers/openai.ts
@@ -1,25 +1,165 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
-
-interface OpenAIEmbeddingConfig extends EmbeddingConfig {
-    apiKey: string;
-    baseUrl: string;
-}
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
 /**
  * OpenAI embedding provider implementation
  */
 export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     name = "openai";
-    private apiKey: string;
-    private baseUrl: string;
 
-    constructor(config: OpenAIEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.apiKey = config.apiKey;
-        this.baseUrl = config.baseUrl;
+    }
+
+    /**
+     * Initialize the provider by detecting model capabilities
+     */
+    async initialize(): Promise<void> {
+        const modelName = this.config.model || "text-embedding-3-small";
+        try {
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`OpenAI model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+        } catch (error: any) {
+            log.error(`Error initializing OpenAI provider: ${error.message}`);
+        }
+    }
+
+    /**
+     * Fetch model information from the OpenAI API
+     */
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
+        if (!this.apiKey) {
+            return null;
+        }
+
+        try {
+            // First try to get model details from the models API
+            const response = await axios.get(
+                `${this.baseUrl}/models/${modelName}`,
+                {
+                    headers: {
+                        "Authorization": `Bearer ${this.apiKey}`,
+                        "Content-Type": "application/json"
+                    },
+                    timeout: 10000
+                }
+            );
+
+            if (response.data) {
+                // Different model families may have different ways of exposing context window
+                let contextWindow = 0;
+                let dimension = 0;
+
+                // Extract context window if available
+                if (response.data.context_window) {
+                    contextWindow = response.data.context_window;
+                } else if (response.data.limits && response.data.limits.context_window) {
+                    contextWindow = response.data.limits.context_window;
+                } else if (response.data.limits && response.data.limits.context_length) {
+                    contextWindow = response.data.limits.context_length;
+                }
+
+                // Extract embedding dimensions if available
+                if (response.data.dimensions) {
+                    dimension = response.data.dimensions;
+                } else if (response.data.embedding_dimension) {
+                    dimension = response.data.embedding_dimension;
+                }
+
+                // If we didn't get all the info, use defaults for missing values
+                if (!contextWindow) {
+                    // Set default context window based on model name patterns
+                    if (modelName.includes('ada') || modelName.includes('embedding-ada')) {
+                        contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+                    } else if (modelName.includes('davinci')) {
+                        contextWindow = 8192;
+                    } else if (modelName.includes('embedding-3')) {
+                        contextWindow = 8191;
+                    } else {
+                        contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+                    }
+                }
+
+                if (!dimension) {
+                    // Set default dimensions based on model name patterns
+                    if (modelName.includes('ada') || modelName.includes('embedding-ada')) {
+                        dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.ADA;
+                    } else if (modelName.includes('embedding-3-small')) {
+                        dimension = 1536;
+                    } else if (modelName.includes('embedding-3-large')) {
+                        dimension = 3072;
+                    } else {
+                        dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.DEFAULT;
+                    }
+                }
+
+                log.info(`Fetched OpenAI model info for ${modelName}: context window ${contextWindow}, dimension ${dimension}`);
+
+                return {
+                    dimension,
+                    contextWindow
+                };
+            }
+        } catch (error: any) {
+            log.info(`Could not fetch model info from OpenAI API: ${error.message}. Will try embedding test.`);
+        }
+
+        return null;
+    }
+
+    /**
+     * Get model information including embedding dimensions
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to fetch model capabilities from API
+        const apiModelInfo = await this.fetchModelCapabilities(modelName);
+        if (apiModelInfo) {
+            // Cache and return the API-provided info
+            this.modelInfoCache.set(modelName, apiModelInfo);
+            this.config.dimension = apiModelInfo.dimension;
+            return apiModelInfo;
+        }
+
+        // If API info fetch fails, try to detect embedding dimension with a test call
+        try {
+            const testEmbedding = await this.generateEmbeddings("Test");
+            const dimension = testEmbedding.length;
+
+            // Use default context window
+            let contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            log.info(`Detected OpenAI model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
+            return modelInfo;
+        } catch (error: any) {
+            // If detection fails, use defaults
+            const dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.DEFAULT;
+            const contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+
+            log.info(`Using default parameters for OpenAI model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            return modelInfo;
+        }
     }
 
     /**
@@ -27,6 +167,10 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
             const response = await axios.post(
                 `${this.baseUrl}/embeddings`,
                 {
@@ -43,8 +187,7 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
             );
 
             if (response.data && response.data.data && response.data.data[0] && response.data.data[0].embedding) {
-                const embedding = response.data.data[0].embedding;
-                return new Float32Array(embedding);
+                return new Float32Array(response.data.data[0].embedding);
             } else {
                 throw new Error("Unexpected response structure from OpenAI API");
             }
@@ -55,53 +198,94 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for OpenAI
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || error?.response?.data?.error?.message || '';
+        const openAIBatchSizeErrorPatterns = [
+            'batch size', 'too many inputs', 'context length exceeded',
+            'maximum context length', 'token limit', 'rate limit exceeded',
+            'tokens in the messages', 'reduce the length', 'too long'
+        ];
+
+        return openAIBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
+    /**
+     * Custom implementation for batched OpenAI embeddings
+     */
+    async generateBatchEmbeddingsWithAPI(texts: string[]): Promise<Float32Array[]> {
+        if (texts.length === 0) {
+            return [];
+        }
+
+        const response = await axios.post(
+            `${this.baseUrl}/embeddings`,
+            {
+                input: texts,
+                model: this.config.model || "text-embedding-3-small",
+                encoding_format: "float"
+            },
+            {
+                headers: {
+                    "Content-Type": "application/json",
+                    "Authorization": `Bearer ${this.apiKey}`
+                }
+            }
+        );
+
+        if (response.data && response.data.data) {
+            // Sort the embeddings by index to ensure they match the input order
+            const sortedEmbeddings = response.data.data
+                .sort((a: any, b: any) => a.index - b.index)
+                .map((item: any) => new Float32Array(item.embedding));
+
+            return sortedEmbeddings;
+        } else {
+            throw new Error("Unexpected response structure from OpenAI API");
+        }
+    }
+
     /**
      * Generate embeddings for multiple texts in a single batch
+     * OpenAI API supports batch embedding, so we implement a custom version
      */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
-        const batchSize = this.config.batchSize || 10;
-        const results: Float32Array[] = [];
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    // Filter out empty texts and use the API batch functionality
+                    const filteredBatch = batch.filter(text => text.trim().length > 0);
 
-        // Process in batches to avoid API limits
-        for (let i = 0; i < texts.length; i += batchSize) {
-            const batch = texts.slice(i, i + batchSize);
-            try {
-                const response = await axios.post(
-                    `${this.baseUrl}/embeddings`,
-                    {
-                        input: batch,
-                        model: this.config.model || "text-embedding-3-small",
-                        encoding_format: "float"
-                    },
-                    {
-                        headers: {
-                            "Content-Type": "application/json",
-                            "Authorization": `Bearer ${this.apiKey}`
-                        }
+                    if (filteredBatch.length === 0) {
+                        // If all texts are empty after filtering, return empty embeddings
+                        return batch.map(() => new Float32Array(this.config.dimension));
                     }
-                );
+
+                    if (filteredBatch.length === 1) {
+                        // If only one text, use the single embedding endpoint
+                        const embedding = await this.generateEmbeddings(filteredBatch[0]);
+                        return [embedding];
+                    }
 
-                if (response.data && response.data.data) {
-                    // Sort the embeddings by index to ensure they match the input order
-                    const sortedEmbeddings = response.data.data
-                        .sort((a: any, b: any) => a.index - b.index)
-                        .map((item: any) => new Float32Array(item.embedding));
+                    // Use the batch API endpoint
+                    return this.generateBatchEmbeddingsWithAPI(filteredBatch);
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`OpenAI batch embedding error: ${errorMessage}`);
+            throw new Error(`OpenAI batch embedding error: ${errorMessage}`);
         }
-
-                    results.push(...sortedEmbeddings);
-                } else {
-                    throw new Error("Unexpected response structure from OpenAI API");
-                }
-            } catch (error: any) {
-                const errorMessage = error.response?.data?.error?.message || error.message || "Unknown error";
-                log.error(`OpenAI batch embedding error: ${errorMessage}`);
-                throw new Error(`OpenAI batch embedding error: ${errorMessage}`);
-            }
-        }
-
-        return results;
     }
 }
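
Usage sketch (editor's addition, not part of the patch): a minimal TypeScript example of how calling code might drive a provider reworked by this change. The import path, config fields, and endpoint URL are assumptions for illustration; only initialize(), generateBatchEmbeddings(), the Ollama constants, and the 'embeddingBatchSize' option come from the diff above.

// Hypothetical caller for the reworked providers (assumed import path and config shape).
import { OllamaEmbeddingProvider } from "./src/services/llm/embeddings/providers/ollama.js";

async function embedTexts(texts: string[]): Promise<Float32Array[]> {
    const provider = new OllamaEmbeddingProvider({
        model: "nomic-embed-text",        // maps to 768 dims in OLLAMA_MODEL_DIMENSIONS
        dimension: 768,                   // placeholder; initialize() overwrites it after detection
        contextWindowSize: 32768,
        baseUrl: "http://localhost:11434" // assumed local Ollama endpoint
    });

    // Probes /api/show for the context window, falling back to a test embedding for the dimension.
    await provider.initialize();

    // Batches according to the 'embeddingBatchSize' option (or BATCH_SIZE.OLLAMA),
    // halving the batch size automatically when the provider reports size-related errors.
    return provider.generateBatchEmbeddings(texts);
}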