diff --git a/src/services/llm/embeddings/providers/ollama.ts b/src/services/llm/embeddings/providers/ollama.ts
index 94da3d122..ee9254f70 100644
--- a/src/services/llm/embeddings/providers/ollama.ts
+++ b/src/services/llm/embeddings/providers/ollama.ts
@@ -1,9 +1,10 @@
 import axios from "axios";
 import log from "../../../log.js";
 import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import type { EmbeddingConfig } from "../embeddings_interface.js";
 import { NormalizationStatus } from "../embeddings_interface.js";
 import { LLM_CONSTANTS } from "../../constants/provider_constants.js";
+import type { EmbeddingModelInfo } from "../../interfaces/embedding_interfaces.js";
 
 /**
  * Ollama embedding provider implementation
@@ -27,7 +28,7 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
             // Update the config dimension
             this.config.dimension = modelInfo.dimension;
 
-            log.info(`Ollama model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+            log.info(`Ollama model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWidth}`);
         } catch (error: any) {
             log.error(`Error initializing Ollama provider: ${error.message}`);
         }
@@ -63,9 +64,10 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
 
             log.info(`Fetched Ollama model info from API for ${modelName}: context window ${contextWindow}`);
             return {
+                name: modelName,
                 dimension: embeddingDimension || 0, // We'll detect this separately if not provided
-                contextWindow: contextWindow,
-                guaranteesNormalization: false // Ollama models don't guarantee normalized embeddings
+                contextWidth: contextWindow,
+                type: 'float32'
             };
         }
     } catch (error: any) {
@@ -82,14 +84,14 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
         // Check cache first
         if (this.modelInfoCache.has(modelName)) {
-            return this.modelInfoCache.get(modelName);
+            return this.modelInfoCache.get(modelName)!;
         }
 
         // Try to fetch model capabilities from API
         const apiModelInfo = await this.fetchModelCapabilities(modelName);
         if (apiModelInfo) {
             // If we have context window but no embedding dimension, we need to detect the dimension
-            if (apiModelInfo.contextWindow && !apiModelInfo.dimension) {
+            if (apiModelInfo.contextWidth && !apiModelInfo.dimension) {
                 try {
                     // Detect dimension with a test embedding
                     const dimension = await this.detectEmbeddingDimension(modelName);
@@ -116,9 +118,10 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
                         (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
 
                     const modelInfo: EmbeddingModelInfo = {
+                        name: modelName,
                         dimension,
-                        contextWindow,
-                        guaranteesNormalization: false // Ollama models don't guarantee normalized embeddings
+                        contextWidth: contextWindow,
+                        type: 'float32'
                     };
                     this.modelInfoCache.set(modelName, modelInfo);
                     this.config.dimension = dimension;
@@ -138,9 +141,10 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
         log.info(`Using default parameters for model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
 
         const modelInfo: EmbeddingModelInfo = {
+            name: modelName,
             dimension,
-            contextWindow,
-            guaranteesNormalization: false // Ollama models don't guarantee normalized embeddings
+            contextWidth: contextWindow,
+            type: 'float32'
         };
         this.modelInfoCache.set(modelName, modelInfo);
         this.config.dimension = dimension;
@@ -202,7 +206,7 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
 
             // Trim text if it might exceed context window (rough character estimate)
             // This is a simplistic approach - ideally we'd count tokens properly
-            const charLimit = modelInfo.contextWindow * 4; // Rough estimate: avg 4 chars per token
+            const charLimit = (modelInfo.contextWidth || 4096) * 4; // Rough estimate: avg 4 chars per token
             const trimmedText = text.length > charLimit ? text.substring(0, charLimit) : text;
 
             const response = await axios.post(
diff --git a/src/services/llm/embeddings/providers/openai.ts b/src/services/llm/embeddings/providers/openai.ts
index 902ff474d..5a76e2032 100644
--- a/src/services/llm/embeddings/providers/openai.ts
+++ b/src/services/llm/embeddings/providers/openai.ts
@@ -1,9 +1,10 @@
 import axios from "axios";
 import log from "../../../log.js";
 import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import type { EmbeddingConfig } from "../embeddings_interface.js";
 import { NormalizationStatus } from "../embeddings_interface.js";
 import { LLM_CONSTANTS } from "../../constants/provider_constants.js";
+import type { EmbeddingModelInfo } from "../../interfaces/embedding_interfaces.js";
 
 /**
  * OpenAI embedding provider implementation
@@ -27,7 +28,7 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
             // Update the config dimension
             this.config.dimension = modelInfo.dimension;
 
-            log.info(`OpenAI model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+            log.info(`OpenAI model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWidth}`);
         } catch (error: any) {
             log.error(`Error initializing OpenAI provider: ${error.message}`);
         }
@@ -105,9 +106,10 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
 
             log.info(`Fetched OpenAI model info for ${modelName}: context window ${contextWindow}, dimension ${dimension}`);
             return {
+                name: modelName,
                 dimension,
-                contextWindow,
-                guaranteesNormalization: true // OpenAI embeddings are normalized to unit length
+                contextWidth: contextWindow,
+                type: 'float32'
             };
         }
     } catch (error: any) {
@@ -123,7 +125,7 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
         // Check cache first
         if (this.modelInfoCache.has(modelName)) {
-            return this.modelInfoCache.get(modelName);
+            return this.modelInfoCache.get(modelName)!;
         }
 
         // Try to fetch model capabilities from API
@@ -144,9 +146,10 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
             let contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
 
             const modelInfo: EmbeddingModelInfo = {
+                name: modelName,
                 dimension,
-                contextWindow,
-                guaranteesNormalization: true // OpenAI embeddings are normalized to unit length
+                contextWidth: contextWindow,
+                type: 'float32'
             };
             this.modelInfoCache.set(modelName, modelInfo);
             this.config.dimension = dimension;
@@ -161,9 +164,10 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
         log.info(`Using default parameters for OpenAI model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
 
         const modelInfo: EmbeddingModelInfo = {
+            name: modelName,
             dimension,
-            contextWindow,
-            guaranteesNormalization: true // OpenAI embeddings are normalized to unit length
+            contextWidth: contextWindow,
+            type: 'float32'
         };
         this.modelInfoCache.set(modelName, modelInfo);
         this.config.dimension = dimension;
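
Note: the new EmbeddingModelInfo type that both providers now import from "../../interfaces/embedding_interfaces.js" is not included in this diff. Judging only from the fields the providers construct (name, dimension, contextWidth, type: 'float32') and from the `modelInfo.contextWidth || 4096` fallback, the interface is assumed to look roughly like the sketch below; the exact field types and optionality are inferred, not the file's actual contents.

// Hypothetical sketch of src/services/llm/interfaces/embedding_interfaces.ts,
// inferred from how the two providers above build their modelInfo objects.
export interface EmbeddingModelInfo {
    name: string;           // model identifier, populated from modelName
    dimension: number;      // embedding vector size; 0 means "detect separately"
    contextWidth?: number;  // context window in tokens; optional, hence the `|| 4096` fallback above
    type: string;           // vector element type; this diff always passes 'float32'
}

One behavioral note: the old guaranteesNormalization flag (false for Ollama, true for OpenAI) has no counterpart in the new object literals, so any consumer that branched on it presumably now relies on NormalizationStatus from "../embeddings_interface.js", which both files still import.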