Mirror of https://github.com/TriliumNext/Notes.git (synced 2025-09-02 21:42:15 +08:00)

Commit 697d348286: set up more reasonable context window and dimension sizes
Parent commit: 572a03a3f7
@@ -27,3 +27,4 @@ INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddings
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('enableAutomaticIndexing', 'true', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingSimilarityThreshold', '0.65', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('maxNotesPerLlmQuery', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
+INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingBatchSize', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
@@ -41,6 +41,37 @@ export const LLM_CONSTANTS = {
         }
     },
+
+    // Model-specific embedding dimensions for Ollama models
+    OLLAMA_MODEL_DIMENSIONS: {
+        "llama3": 4096,
+        "llama3.1": 4096,
+        "mistral": 4096,
+        "nomic": 768,
+        "mxbai": 1024,
+        "nomic-embed-text": 768,
+        "mxbai-embed-large": 1024,
+        "default": 384
+    },
+
+    // Model-specific context windows for Ollama models
+    OLLAMA_MODEL_CONTEXT_WINDOWS: {
+        "llama3": 8192,
+        "mistral": 8192,
+        "nomic": 32768,
+        "mxbai": 32768,
+        "nomic-embed-text": 32768,
+        "mxbai-embed-large": 32768,
+        "default": 4096
+    },
+
+    // Batch size configuration
+    BATCH_SIZE: {
+        OPENAI: 10,    // OpenAI can handle larger batches efficiently
+        ANTHROPIC: 5,  // More conservative for Anthropic
+        OLLAMA: 1,     // Ollama processes one at a time
+        DEFAULT: 5     // Conservative default
+    },
+
     // Chunking parameters
     CHUNKING: {
         DEFAULT_SIZE: 1500,
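Note (not part of the diff): the tables above are meant to be consulted by base model name, with a fallback to their "default" entries. A minimal lookup sketch; the helper name and import path are assumptions for illustration only:

    // Hypothetical helper: resolve Ollama model parameters from LLM_CONSTANTS,
    // falling back to the "default" entries for unknown models.
    import { LLM_CONSTANTS } from "./routes/api/llm.js"; // path assumed for the sketch

    function resolveOllamaModelInfo(modelName: string): { dimension: number; contextWindow: number } {
        // "llama3:8b" and similar tags share the base model's parameters
        const base = modelName.split(":")[0];
        const dims = LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>;
        const ctx = LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>;
        return {
            dimension: dims[base] ?? dims.default,
            contextWindow: ctx[base] ?? ctx.default
        };
    }

    // resolveOllamaModelInfo("nomic-embed-text") -> { dimension: 768, contextWindow: 32768 }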
@@ -1,22 +1,212 @@
 import type { EmbeddingProvider, EmbeddingConfig, NoteEmbeddingContext } from './embeddings_interface.js';
+import log from "../../log.js";
+import { LLM_CONSTANTS } from "../../../routes/api/llm.js";
+import options from "../../options.js";
+
 /**
  * Base class that implements common functionality for embedding providers
  */
 export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
-    abstract name: string;
+    name: string = "base";
     protected config: EmbeddingConfig;
+    protected apiKey?: string;
+    protected baseUrl: string;
+    protected modelInfoCache = new Map<string, any>();
+
     constructor(config: EmbeddingConfig) {
         this.config = config;
+        this.apiKey = config.apiKey;
+        this.baseUrl = config.baseUrl || "";
     }
 
     getConfig(): EmbeddingConfig {
-        return this.config;
+        return { ...this.config };
     }
+
+    getDimension(): number {
+        return this.config.dimension;
+    }
+
+    async initialize(): Promise<void> {
+        // Default implementation does nothing
+        return;
+    }
+
+    /**
+     * Generate embeddings for a single text
+     */
     abstract generateEmbeddings(text: string): Promise<Float32Array>;
-    abstract generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]>;
+
+    /**
+     * Get the appropriate batch size for this provider
+     * Override in provider implementations if needed
+     */
+    protected async getBatchSize(): Promise<number> {
+        // Try to get the user-configured batch size
+        let configuredBatchSize: number | null = null;
+
+        try {
+            const batchSizeStr = await options.getOption('embeddingBatchSize');
+            if (batchSizeStr) {
+                configuredBatchSize = parseInt(batchSizeStr, 10);
+            }
+        } catch (error) {
+            log.error(`Error getting batch size from options: ${error}`);
+        }
+
+        // If user has configured a specific batch size, use that
+        if (configuredBatchSize && !isNaN(configuredBatchSize) && configuredBatchSize > 0) {
+            return configuredBatchSize;
+        }
+
+        // Otherwise use the provider-specific default from constants
+        return this.config.batchSize ||
+            LLM_CONSTANTS.BATCH_SIZE[this.name.toUpperCase() as keyof typeof LLM_CONSTANTS.BATCH_SIZE] ||
+            LLM_CONSTANTS.BATCH_SIZE.DEFAULT;
+    }
+
+    /**
+     * Process a batch of texts with adaptive handling
+     * This method will try to process the batch and reduce batch size if encountering errors
+     */
+    protected async processWithAdaptiveBatch<T>(
+        items: T[],
+        processFn: (batch: T[]) => Promise<any[]>,
+        isBatchSizeError: (error: any) => boolean
+    ): Promise<any[]> {
+        const results: any[] = [];
+        const failures: { index: number, error: string }[] = [];
+        let currentBatchSize = await this.getBatchSize();
+        let lastError: Error | null = null;
+
+        // Process items in batches
+        for (let i = 0; i < items.length;) {
+            const batch = items.slice(i, i + currentBatchSize);
+
+            try {
+                // Process the current batch
+                const batchResults = await processFn(batch);
+                results.push(...batchResults);
+                i += batch.length;
+            }
+            catch (error: any) {
+                lastError = error;
+                const errorMessage = error.message || 'Unknown error';
+
+                // Check if this is a batch size related error
+                if (isBatchSizeError(error) && currentBatchSize > 1) {
+                    // Reduce batch size and retry
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Batch size error detected, reducing batch size from ${currentBatchSize} to ${newBatchSize}: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+                else if (currentBatchSize === 1) {
+                    // If we're already at batch size 1, we can't reduce further, so log the error and skip this item
+                    log.error(`Error processing item at index ${i} with batch size 1: ${errorMessage}`);
+                    failures.push({ index: i, error: errorMessage });
+                    i++; // Move to the next item
+                }
+                else {
+                    // For other errors, retry with a smaller batch size as a precaution
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Error processing batch, reducing batch size from ${currentBatchSize} to ${newBatchSize} as a precaution: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+            }
+        }
+
+        // If all items failed and we have a last error, throw it
+        if (results.length === 0 && failures.length > 0 && lastError) {
+            throw lastError;
+        }
+
+        // If some items failed but others succeeded, log the summary
+        if (failures.length > 0) {
+            console.warn(`Processed ${results.length} items successfully, but ${failures.length} items failed`);
+        }
+
+        return results;
+    }
+
+    /**
+     * Detect if an error is related to batch size limits
+     * Override in provider-specific implementations
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || '';
+        const batchSizeErrorPatterns = [
+            'batch size', 'too many items', 'too many inputs',
+            'input too large', 'payload too large', 'context length',
+            'token limit', 'rate limit', 'request too large'
+        ];
+
+        return batchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
+    /**
+     * Generate embeddings for multiple texts
+     * Default implementation processes texts one by one
+     */
+    async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
+        if (texts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const batchResults = await Promise.all(
+                        batch.map(text => this.generateEmbeddings(text))
+                    );
+                    return batchResults;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Batch embedding error for provider ${this.name}: ${errorMessage}`);
+            throw new Error(`${this.name} batch embedding error: ${errorMessage}`);
+        }
+    }
+
+    /**
+     * Generate embeddings for a note with its context
+     */
+    async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
+        const text = [context.title || "", context.content || ""].filter(Boolean).join(" ");
+        return this.generateEmbeddings(text);
+    }
+
+    /**
+     * Generate embeddings for multiple notes with their contexts
+     */
+    async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
+        if (contexts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                contexts,
+                async (batch) => {
+                    const batchResults = await Promise.all(
+                        batch.map(context => this.generateNoteEmbeddings(context))
+                    );
+                    return batchResults;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Batch note embedding error for provider ${this.name}: ${errorMessage}`);
+            throw new Error(`${this.name} batch note embedding error: ${errorMessage}`);
+        }
+    }
+
     /**
      * Cleans and normalizes text for embeddings by removing excessive whitespace
@@ -157,20 +347,4 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
 
         return result;
     }
-
-    /**
-     * Default implementation that converts note context to text and generates embeddings
-     */
-    async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
-        const text = this.generateNoteContextText(context);
-        return this.generateEmbeddings(text);
-    }
-
-    /**
-     * Default implementation that processes notes in batch
-     */
-    async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
-        const texts = contexts.map(ctx => this.generateNoteContextText(ctx));
-        return this.generateBatchEmbeddings(texts);
-    }
 }
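Note (not part of the diff): with the base class above, a subclass only needs to implement generateEmbeddings(); batch-size resolution and error-driven batch shrinking come from getBatchSize() and processWithAdaptiveBatch(). A minimal sketch with placeholder names and values:

    // Illustrative subclass only; it exercises methods defined in the hunk above.
    class DummyEmbeddingProvider extends BaseEmbeddingProvider {
        name = "dummy";

        async generateEmbeddings(text: string): Promise<Float32Array> {
            // Stand-in "embedding": a zero vector of the configured dimension.
            return new Float32Array(this.config.dimension);
        }
    }

    const provider = new DummyEmbeddingProvider({ model: "dummy", dimension: 384, normalize: true });
    // Batch size comes from the 'embeddingBatchSize' option (or BATCH_SIZE.DEFAULT) and is halved
    // automatically whenever processWithAdaptiveBatch() sees a batch-size style error.
    const vectors = await provider.generateBatchEmbeddings(["first note", "second note"]);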
@@ -36,6 +36,14 @@ export interface NoteEmbeddingContext {
     templateTitles?: string[];
 }
+
+/**
+ * Information about an embedding model's capabilities
+ */
+export interface EmbeddingModelInfo {
+    dimension: number;
+    contextWindow: number;
+}
+
 /**
  * Configuration for how embeddings should be generated
  */
@@ -46,6 +54,8 @@ export interface EmbeddingConfig {
     normalize?: boolean;
     batchSize?: number;
     contextWindowSize?: number;
+    apiKey?: string;
+    baseUrl?: string;
 }
 
 /**
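Note (not part of the diff): a config object satisfying the extended EmbeddingConfig could look like the following sketch. The model and dimension fields are inferred from how the providers read config.model and config.dimension, and the concrete values are placeholders; the dimension is typically overwritten by getModelInfo() after initialize():

    import type { EmbeddingConfig } from "./embeddings_interface.js"; // path assumed

    const exampleConfig: EmbeddingConfig = {
        model: "nomic-embed-text",
        dimension: 768,                    // refined after provider initialization
        normalize: true,
        batchSize: 1,
        contextWindowSize: 32768,
        baseUrl: "http://localhost:11434"  // placeholder
    };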
@@ -1,25 +1,117 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
-interface AnthropicEmbeddingConfig extends EmbeddingConfig {
-    apiKey: string;
-    baseUrl: string;
-}
+// Anthropic model context window sizes - as of current API version
+const ANTHROPIC_MODEL_CONTEXT_WINDOWS: Record<string, number> = {
+    "claude-3-opus-20240229": 200000,
+    "claude-3-sonnet-20240229": 180000,
+    "claude-3-haiku-20240307": 48000,
+    "claude-2.1": 200000,
+    "claude-2.0": 100000,
+    "claude-instant-1.2": 100000,
+    "default": 100000
+};
 
 /**
- * Anthropic (Claude) embedding provider implementation
+ * Anthropic embedding provider implementation
  */
 export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
     name = "anthropic";
-    private apiKey: string;
-    private baseUrl: string;
 
-    constructor(config: AnthropicEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.apiKey = config.apiKey;
-        this.baseUrl = config.baseUrl;
+    }
+
+    /**
+     * Initialize the provider by detecting model capabilities
+     */
+    async initialize(): Promise<void> {
+        const modelName = this.config.model || "claude-3-haiku-20240307";
+        try {
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`Anthropic model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+        } catch (error: any) {
+            log.error(`Error initializing Anthropic provider: ${error.message}`);
+        }
+    }
+
+    /**
+     * Try to determine Anthropic model capabilities
+     * Note: Anthropic doesn't have a public endpoint for model metadata, so we use a combination
+     * of known values and detection by test embeddings
+     */
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
+        // Anthropic doesn't have a model info endpoint, but we can look up known context sizes
+        // and detect embedding dimensions by making a test request
+        try {
+            // Get context window size from our local registry of known models
+            const modelBase = Object.keys(ANTHROPIC_MODEL_CONTEXT_WINDOWS).find(
+                model => modelName.startsWith(model)
+            ) || "default";
+
+            const contextWindow = ANTHROPIC_MODEL_CONTEXT_WINDOWS[modelBase];
+
+            // For embedding dimension, we'll return null and let getModelInfo detect it
+            return {
+                dimension: 0, // Will be detected by test embedding
+                contextWindow
+            };
+        } catch (error) {
+            log.info(`Could not determine capabilities for Anthropic model ${modelName}: ${error}`);
+            return null;
+        }
+    }
+
+    /**
+     * Get model information including embedding dimensions
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to determine model capabilities
+        const capabilities = await this.fetchModelCapabilities(modelName);
+        const contextWindow = capabilities?.contextWindow || LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC;
+
+        // For Anthropic, we need to detect embedding dimension with a test call
+        try {
+            // Detect dimension with a test embedding
+            const testEmbedding = await this.generateEmbeddings("Test");
+            const dimension = testEmbedding.length;
+
+            const modelInfo: EmbeddingModelInfo = {
+                dimension,
+                contextWindow
+            };
+
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            log.info(`Detected Anthropic model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
+            return modelInfo;
+        } catch (error: any) {
+            // If detection fails, use defaults
+            const dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.ANTHROPIC.DEFAULT;
+
+            log.info(`Using default parameters for Anthropic model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            return modelInfo;
+        }
     }
 
     /**
@@ -27,11 +119,23 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
+            // Get model info to check context window
+            const modelName = this.config.model || "claude-3-haiku-20240307";
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Trim text if it might exceed context window (rough character estimate)
+            const charLimit = modelInfo.contextWindow * 4; // Rough estimate: avg 4 chars per token
+            const trimmedText = text.length > charLimit ? text.substring(0, charLimit) : text;
+
             const response = await axios.post(
                 `${this.baseUrl}/embeddings`,
                 {
-                    model: this.config.model || "claude-3-haiku-20240307",
-                    input: text,
+                    model: modelName,
+                    input: trimmedText,
                     encoding_format: "float"
                 },
                 {
@@ -44,8 +148,7 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
             );
 
             if (response.data && response.data.embedding) {
-                const embedding = response.data.embedding;
-                return new Float32Array(embedding);
+                return new Float32Array(response.data.embedding);
             } else {
                 throw new Error("Unexpected response structure from Anthropic API");
             }
@@ -56,23 +159,60 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for Anthropic
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || error?.response?.data?.error?.message || '';
+        const anthropicBatchSizeErrorPatterns = [
+            'batch size', 'too many inputs', 'context length exceeded',
+            'token limit', 'rate limit', 'limit exceeded',
+            'too long', 'request too large', 'content too large'
+        ];
+
+        return anthropicBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
     /**
      * Generate embeddings for multiple texts in a single batch
      *
      * Note: Anthropic doesn't currently support batch embedding, so we process each text individually
+     * but using the adaptive batch processor to handle errors and retries
      */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
                     const results: Float32Array[] = [];
 
-        for (const text of texts) {
+                    // For Anthropic, we have to process one at a time
+                    for (const text of batch) {
+                        // Skip empty texts
+                        if (!text.trim()) {
+                            results.push(new Float32Array(this.config.dimension));
+                            continue;
+                        }
+
                         const embedding = await this.generateEmbeddings(text);
                         results.push(embedding);
                     }
 
                     return results;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Anthropic batch embedding error: ${errorMessage}`);
+            throw new Error(`Anthropic batch embedding error: ${errorMessage}`);
+        }
     }
 }
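Note (not part of the diff): a usage sketch for the Anthropic provider; every value is a placeholder, and the baseUrl only has to point at whatever embeddings endpoint the configuration supplies:

    // Illustrative only.
    const anthropic = new AnthropicEmbeddingProvider({
        model: "claude-3-haiku-20240307",
        dimension: 1024,                       // overwritten by getModelInfo() during initialize()
        normalize: true,
        apiKey: process.env.ANTHROPIC_API_KEY ?? "",
        baseUrl: "https://example.invalid/v1"  // placeholder endpoint
    });

    await anthropic.initialize();  // looks up the context window, detects dimension via a test embedding
    const vector = await anthropic.generateEmbeddings("Hello Trilium");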
@@ -1,30 +1,17 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
-interface OllamaEmbeddingConfig extends EmbeddingConfig {
-    baseUrl: string;
-}
-
-// Model-specific embedding dimensions
-interface EmbeddingModelInfo {
-    dimension: number;
-    contextWindow: number;
-}
-
 /**
  * Ollama embedding provider implementation
  */
 export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     name = "ollama";
-    private baseUrl: string;
-    // Cache for model dimensions to avoid repeated API calls
-    private modelInfoCache = new Map<string, EmbeddingModelInfo>();
 
-    constructor(config: OllamaEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.baseUrl = config.baseUrl;
     }
 
     /**
@@ -33,43 +20,129 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     async initialize(): Promise<void> {
         const modelName = this.config.model || "llama3";
         try {
-            await this.getModelInfo(modelName);
-            log.info(`Ollama embedding provider initialized with model ${modelName}`);
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`Ollama model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
         } catch (error: any) {
-            log.error(`Failed to initialize Ollama embedding provider: ${error.message}`);
-            // Still continue with default dimensions
+            log.error(`Error initializing Ollama provider: ${error.message}`);
         }
     }
 
     /**
-     * Get model information including embedding dimensions
+     * Fetch detailed model information from Ollama API
+     * @param modelName The name of the model to fetch information for
+     */
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
+        try {
+            // First try the /api/show endpoint which has detailed model information
+            const showResponse = await axios.get(
+                `${this.baseUrl}/api/show`,
+                {
+                    params: { name: modelName },
+                    headers: { "Content-Type": "application/json" },
+                    timeout: 10000
+                }
+            );
+
+            if (showResponse.data && showResponse.data.parameters) {
+                const params = showResponse.data.parameters;
+                // Extract context length from parameters (different models might use different parameter names)
+                const contextWindow = params.context_length ||
+                    params.num_ctx ||
+                    params.context_window ||
+                    (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
+
+                // Some models might provide embedding dimensions
+                const embeddingDimension = params.embedding_length || params.dim || null;
+
+                log.info(`Fetched Ollama model info from API for ${modelName}: context window ${contextWindow}`);
+
+                return {
+                    dimension: embeddingDimension || 0, // We'll detect this separately if not provided
+                    contextWindow: contextWindow
+                };
+            }
+        } catch (error: any) {
+            log.info(`Could not fetch model info from Ollama show API: ${error.message}. Will try embedding test.`);
+            // We'll fall back to embedding test if this fails
+        }
+
+        return null;
+    }
+
+    /**
+     * Get model information by probing the API
      */
     async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
         // Check cache first
         if (this.modelInfoCache.has(modelName)) {
-            return this.modelInfoCache.get(modelName)!;
+            return this.modelInfoCache.get(modelName);
         }
 
-        // Default dimensions for common embedding models
-        const defaultDimensions: Record<string, number> = {
-            "nomic-embed-text": 768,
-            "mxbai-embed-large": 1024,
-            "llama3": 4096,
-            "all-minilm": 384,
-            "default": 4096
-        };
-
-        // Default context windows
-        const defaultContextWindows: Record<string, number> = {
-            "nomic-embed-text": 8192,
-            "mxbai-embed-large": 8192,
-            "llama3": 8192,
-            "all-minilm": 4096,
-            "default": 4096
-        };
-
-        try {
-            // Try to detect if this is an embedding model
+        // Try to fetch model capabilities from API
+        const apiModelInfo = await this.fetchModelCapabilities(modelName);
+        if (apiModelInfo) {
+            // If we have context window but no embedding dimension, we need to detect the dimension
+            if (apiModelInfo.contextWindow && !apiModelInfo.dimension) {
+                try {
+                    // Detect dimension with a test embedding
+                    const dimension = await this.detectEmbeddingDimension(modelName);
+                    apiModelInfo.dimension = dimension;
+                } catch (error) {
+                    // If dimension detection fails, fall back to defaults
+                    const baseModelName = modelName.split(':')[0];
+                    apiModelInfo.dimension = (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>)[baseModelName] ||
+                        (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>).default;
+                }
+            }
+
+            // Cache and return the API-provided info
+            this.modelInfoCache.set(modelName, apiModelInfo);
+            this.config.dimension = apiModelInfo.dimension;
+            return apiModelInfo;
+        }
+
+        // If API info fetch fails, fall back to test embedding
+        try {
+            const dimension = await this.detectEmbeddingDimension(modelName);
+            const baseModelName = modelName.split(':')[0];
+            const contextWindow = (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>)[baseModelName] ||
+                (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            log.info(`Detected Ollama model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
+            return modelInfo;
+        } catch (error: any) {
+            log.error(`Error detecting Ollama model capabilities: ${error.message}`);
+
+            // If all detection fails, use defaults based on model name
+            const baseModelName = modelName.split(':')[0];
+            const dimension = (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>)[baseModelName] ||
+                (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>).default;
+            const contextWindow = (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>)[baseModelName] ||
+                (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
+
+            log.info(`Using default parameters for model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            return modelInfo;
+        }
+    }
+
+    /**
+     * Detect embedding dimension by making a test API call
+     */
+    private async detectEmbeddingDimension(modelName: string): Promise<number> {
         const testResponse = await axios.post(
             `${this.baseUrl}/api/embeddings`,
             {
@@ -82,46 +155,11 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
             }
         );
 
-        let dimension = 0;
-        let contextWindow = 0;
-
         if (testResponse.data && Array.isArray(testResponse.data.embedding)) {
-            dimension = testResponse.data.embedding.length;
-
-            // Set context window based on model name if we have it
-            const baseModelName = modelName.split(':')[0];
-            contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
-
-            log.info(`Detected Ollama model ${modelName} with dimension ${dimension}`);
+            return testResponse.data.embedding.length;
         } else {
             throw new Error("Could not detect embedding dimensions");
         }
-
-        const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
-        this.modelInfoCache.set(modelName, modelInfo);
-
-        // Update the provider config dimension
-        this.config.dimension = dimension;
-
-        return modelInfo;
-        } catch (error: any) {
-            log.error(`Error detecting Ollama model capabilities: ${error.message}`);
-
-            // If detection fails, use defaults based on model name
-            const baseModelName = modelName.split(':')[0];
-            const dimension = defaultDimensions[baseModelName] || defaultDimensions.default;
-            const contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
-
-            log.info(`Using default dimension ${dimension} for model ${modelName}`);
-
-            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
-            this.modelInfoCache.set(modelName, modelInfo);
-
-            // Update the provider config dimension
-            this.config.dimension = dimension;
-
-            return modelInfo;
-        }
     }
 
     /**
@@ -136,6 +174,10 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
             const modelName = this.config.model || "llama3";
 
             // Ensure we have model info
@@ -173,29 +215,60 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for Ollama
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || '';
+        const ollamaBatchSizeErrorPatterns = [
+            'context length', 'token limit', 'out of memory',
+            'too large', 'overloaded', 'prompt too long',
+            'too many tokens', 'maximum size'
+        ];
+
+        return ollamaBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
     /**
      * Generate embeddings for multiple texts
      *
     * Note: Ollama API doesn't support batch embedding, so we process them sequentially
+     * but using the adaptive batch processor to handle rate limits and retries
      */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
                     const results: Float32Array[] = [];
 
-        for (const text of texts) {
-            try {
+                    // For Ollama, we have to process one at a time
+                    for (const text of batch) {
+                        // Skip empty texts
+                        if (!text.trim()) {
+                            results.push(new Float32Array(this.config.dimension));
+                            continue;
+                        }
+
                         const embedding = await this.generateEmbeddings(text);
                         results.push(embedding);
-            } catch (error: any) {
-                const errorMessage = error.response?.data?.error?.message || error.message || "Unknown error";
+                    }
+
+                    return results;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
             log.error(`Ollama batch embedding error: ${errorMessage}`);
             throw new Error(`Ollama batch embedding error: ${errorMessage}`);
         }
-        }
-
-        return results;
     }
 }
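Note (not part of the diff): getModelInfo() above resolves Ollama model parameters in this order: cached info, /api/show parameters, a test embedding, then the LLM_CONSTANTS defaults. A usage sketch with placeholder values:

    // Illustrative only; the URL and model tag are placeholders.
    const ollama = new OllamaEmbeddingProvider({
        model: "nomic-embed-text",
        dimension: 768,                    // refined by getModelInfo() during initialize()
        normalize: true,
        baseUrl: "http://localhost:11434"
    });

    await ollama.initialize();
    const info = await ollama.getModelInfo("nomic-embed-text");
    // If the API reports nothing, info.contextWindow falls back to
    // LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS["nomic-embed-text"] (32768).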
@@ -1,25 +1,165 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
-interface OpenAIEmbeddingConfig extends EmbeddingConfig {
-    apiKey: string;
-    baseUrl: string;
-}
-
 /**
  * OpenAI embedding provider implementation
  */
 export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     name = "openai";
-    private apiKey: string;
-    private baseUrl: string;
 
-    constructor(config: OpenAIEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.apiKey = config.apiKey;
-        this.baseUrl = config.baseUrl;
+    }
+
+    /**
+     * Initialize the provider by detecting model capabilities
+     */
+    async initialize(): Promise<void> {
+        const modelName = this.config.model || "text-embedding-3-small";
+        try {
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`OpenAI model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+        } catch (error: any) {
+            log.error(`Error initializing OpenAI provider: ${error.message}`);
+        }
+    }
+
+    /**
+     * Fetch model information from the OpenAI API
+     */
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
+        if (!this.apiKey) {
+            return null;
+        }
+
+        try {
+            // First try to get model details from the models API
+            const response = await axios.get(
+                `${this.baseUrl}/models/${modelName}`,
+                {
+                    headers: {
+                        "Authorization": `Bearer ${this.apiKey}`,
+                        "Content-Type": "application/json"
+                    },
+                    timeout: 10000
+                }
+            );
+
+            if (response.data) {
+                // Different model families may have different ways of exposing context window
+                let contextWindow = 0;
+                let dimension = 0;
+
+                // Extract context window if available
+                if (response.data.context_window) {
+                    contextWindow = response.data.context_window;
+                } else if (response.data.limits && response.data.limits.context_window) {
+                    contextWindow = response.data.limits.context_window;
+                } else if (response.data.limits && response.data.limits.context_length) {
+                    contextWindow = response.data.limits.context_length;
+                }
+
+                // Extract embedding dimensions if available
+                if (response.data.dimensions) {
+                    dimension = response.data.dimensions;
+                } else if (response.data.embedding_dimension) {
+                    dimension = response.data.embedding_dimension;
+                }
+
+                // If we didn't get all the info, use defaults for missing values
+                if (!contextWindow) {
+                    // Set default context window based on model name patterns
+                    if (modelName.includes('ada') || modelName.includes('embedding-ada')) {
+                        contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+                    } else if (modelName.includes('davinci')) {
+                        contextWindow = 8192;
+                    } else if (modelName.includes('embedding-3')) {
+                        contextWindow = 8191;
+                    } else {
+                        contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+                    }
+                }
+
+                if (!dimension) {
+                    // Set default dimensions based on model name patterns
+                    if (modelName.includes('ada') || modelName.includes('embedding-ada')) {
+                        dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.ADA;
+                    } else if (modelName.includes('embedding-3-small')) {
+                        dimension = 1536;
+                    } else if (modelName.includes('embedding-3-large')) {
+                        dimension = 3072;
+                    } else {
+                        dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.DEFAULT;
+                    }
+                }
+
+                log.info(`Fetched OpenAI model info for ${modelName}: context window ${contextWindow}, dimension ${dimension}`);
+
+                return {
+                    dimension,
+                    contextWindow
+                };
+            }
+        } catch (error: any) {
+            log.info(`Could not fetch model info from OpenAI API: ${error.message}. Will try embedding test.`);
+        }
+
+        return null;
+    }
+
+    /**
+     * Get model information including embedding dimensions
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to fetch model capabilities from API
+        const apiModelInfo = await this.fetchModelCapabilities(modelName);
+        if (apiModelInfo) {
+            // Cache and return the API-provided info
+            this.modelInfoCache.set(modelName, apiModelInfo);
+            this.config.dimension = apiModelInfo.dimension;
+            return apiModelInfo;
+        }
+
+        // If API info fetch fails, try to detect embedding dimension with a test call
+        try {
+            const testEmbedding = await this.generateEmbeddings("Test");
+            const dimension = testEmbedding.length;
+
+            // Use default context window
+            let contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            log.info(`Detected OpenAI model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
+            return modelInfo;
+        } catch (error: any) {
+            // If detection fails, use defaults
+            const dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.DEFAULT;
+            const contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI;
+
+            log.info(`Using default parameters for OpenAI model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            return modelInfo;
+        }
     }
 
     /**
@@ -27,6 +167,10 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
             const response = await axios.post(
                 `${this.baseUrl}/embeddings`,
                 {
@@ -43,8 +187,7 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
             );
 
             if (response.data && response.data.data && response.data.data[0] && response.data.data[0].embedding) {
-                const embedding = response.data.data[0].embedding;
-                return new Float32Array(embedding);
+                return new Float32Array(response.data.data[0].embedding);
             } else {
                 throw new Error("Unexpected response structure from OpenAI API");
             }
@@ -56,24 +199,33 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     }
 
     /**
-     * Generate embeddings for multiple texts in a single batch
+     * More specific implementation of batch size error detection for OpenAI
      */
-    async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || error?.response?.data?.error?.message || '';
+        const openAIBatchSizeErrorPatterns = [
+            'batch size', 'too many inputs', 'context length exceeded',
+            'maximum context length', 'token limit', 'rate limit exceeded',
+            'tokens in the messages', 'reduce the length', 'too long'
+        ];
+
+        return openAIBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
+    /**
+     * Custom implementation for batched OpenAI embeddings
+     */
+    async generateBatchEmbeddingsWithAPI(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
-        const batchSize = this.config.batchSize || 10;
-        const results: Float32Array[] = [];
-
-        // Process in batches to avoid API limits
-        for (let i = 0; i < texts.length; i += batchSize) {
-            const batch = texts.slice(i, i + batchSize);
-            try {
         const response = await axios.post(
             `${this.baseUrl}/embeddings`,
             {
-                input: batch,
+                input: texts,
                 model: this.config.model || "text-embedding-3-small",
                 encoding_format: "float"
             },
@@ -91,17 +243,49 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
                 .sort((a: any, b: any) => a.index - b.index)
                 .map((item: any) => new Float32Array(item.embedding));
 
-                results.push(...sortedEmbeddings);
+            return sortedEmbeddings;
         } else {
             throw new Error("Unexpected response structure from OpenAI API");
         }
-            } catch (error: any) {
-                const errorMessage = error.response?.data?.error?.message || error.message || "Unknown error";
+    }
+
+    /**
+     * Generate embeddings for multiple texts in a single batch
+     * OpenAI API supports batch embedding, so we implement a custom version
+     */
+    async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
+        if (texts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    // Filter out empty texts and use the API batch functionality
+                    const filteredBatch = batch.filter(text => text.trim().length > 0);
+
+                    if (filteredBatch.length === 0) {
+                        // If all texts are empty after filtering, return empty embeddings
+                        return batch.map(() => new Float32Array(this.config.dimension));
+                    }
+
+                    if (filteredBatch.length === 1) {
+                        // If only one text, use the single embedding endpoint
+                        const embedding = await this.generateEmbeddings(filteredBatch[0]);
+                        return [embedding];
+                    }
+
+                    // Use the batch API endpoint
+                    return this.generateBatchEmbeddingsWithAPI(filteredBatch);
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
             log.error(`OpenAI batch embedding error: ${errorMessage}`);
             throw new Error(`OpenAI batch embedding error: ${errorMessage}`);
         }
     }
-
-        return results;
-    }
 }
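Note (not part of the diff): unlike the Anthropic and Ollama providers, the OpenAI path sends a real batch request through generateBatchEmbeddingsWithAPI(), and processWithAdaptiveBatch() halves the batch and retries when isBatchSizeError() matches. A usage sketch with placeholder values:

    // Illustrative only.
    const openai = new OpenAIEmbeddingProvider({
        model: "text-embedding-3-small",
        dimension: 1536,
        normalize: true,
        apiKey: process.env.OPENAI_API_KEY ?? "",
        baseUrl: "https://api.openai.com/v1"
    });

    await openai.initialize();
    const noteTexts = ["first note body", "second note body", "third note body"];
    const embeddings = await openai.generateBatchEmbeddings(noteTexts); // one Float32Array per input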