From 2899707e64d6a627b1f7c347e901a6ed381df81e Mon Sep 17 00:00:00 2001
From: perf3ct
Date: Fri, 28 Mar 2025 21:47:28 +0000
Subject: [PATCH] Better use of interfaces, reducing usage of "any"

---
 src/routes/api/llm.ts                         |  17 ++-
 .../agent_tools/contextual_thinking_tool.ts   |   2 +-
 src/services/llm/context/content_chunking.ts  |  14 +-
 .../llm/context/modules/context_service.ts    |  62 +++++---
 .../llm/context/modules/query_enhancer.ts     |  17 ++-
 src/services/llm/context_service.ts           |   3 +-
 .../llm/embeddings/base_embeddings.ts         | 144 ++++++++++++++----
 .../llm/interfaces/context_interfaces.ts      |   2 +-
 .../llm/interfaces/embedding_interfaces.ts    | 108 +++++++++++++
 .../llm/interfaces/error_interfaces.ts        |  78 ++++++++++
 .../llm/providers/anthropic_service.ts        |  39 +++--
 src/services/llm/providers/ollama_service.ts  |  52 +++----
 12 files changed, 427 insertions(+), 111 deletions(-)
 create mode 100644 src/services/llm/interfaces/embedding_interfaces.ts
 create mode 100644 src/services/llm/interfaces/error_interfaces.ts

diff --git a/src/routes/api/llm.ts b/src/routes/api/llm.ts
index 3bec82c44..693f5e839 100644
--- a/src/routes/api/llm.ts
+++ b/src/routes/api/llm.ts
@@ -931,7 +931,16 @@ async function sendMessage(req: Request, res: Response) {
 
             // Get the generated context
             const context = results.context;
-            sourceNotes = results.notes;
+            // Convert from NoteSearchResult to NoteSource
+            sourceNotes = results.sources.map(source => ({
+                noteId: source.noteId,
+                title: source.title,
+                content: source.content || undefined, // Convert null to undefined
+                similarity: source.similarity
+            }));
+
+            // Build context from relevant notes
+            const contextFromNotes = buildContextFromNotes(sourceNotes, messageContent);
 
             // Add system message with the context
             const contextMessage: Message = {
@@ -1063,8 +1072,7 @@ async function sendMessage(req: Request, res: Response) {
                 sources: sourceNotes.map(note => ({
                     noteId: note.noteId,
                     title: note.title,
-                    similarity: note.similarity,
-                    branchId: note.branchId
+                    similarity: note.similarity
                 }))
             };
         }
@@ -1198,8 +1206,7 @@ async function sendMessage(req: Request, res: Response) {
                 sources: sourceNotes.map(note => ({
                     noteId: note.noteId,
                     title: note.title,
-                    similarity: note.similarity,
-                    branchId: note.branchId
+                    similarity: note.similarity
                 }))
             };
         }
diff --git a/src/services/llm/agent_tools/contextual_thinking_tool.ts b/src/services/llm/agent_tools/contextual_thinking_tool.ts
index 113f0e165..de48778e6 100644
--- a/src/services/llm/agent_tools/contextual_thinking_tool.ts
+++ b/src/services/llm/agent_tools/contextual_thinking_tool.ts
@@ -28,7 +28,7 @@ export interface ThinkingStep {
     sources?: string[];
     parentId?: string;
     children?: string[];
-    metadata?: Record<string, any>;
+    metadata?: Record<string, unknown>;
 }
 
 /**
diff --git a/src/services/llm/context/content_chunking.ts b/src/services/llm/context/content_chunking.ts
index 8727001cc..c083dac48 100644
--- a/src/services/llm/context/content_chunking.ts
+++ b/src/services/llm/context/content_chunking.ts
@@ -12,7 +12,7 @@ export interface ContentChunk {
     noteId?: string;
     title?: string;
     path?: string;
-    metadata?: Record<string, any>;
+    metadata?: Record<string, unknown>;
 }
 
 /**
@@ -43,7 +43,7 @@ export interface ChunkOptions {
     /**
      * Additional information to include in chunk metadata
     */
-    metadata?: Record<string, any>;
+    metadata?: Record<string, unknown>;
 }
 
 /**
@@ -52,7 +52,7 @@ export interface ChunkOptions {
 async function getDefaultChunkOptions(): Promise<Required<ChunkOptions>> {
     // Import constants dynamically to avoid circular dependencies
     const { LLM_CONSTANTS } = await import('../../../routes/api/llm.js');
-    
+
     return {
         maxChunkSize: LLM_CONSTANTS.CHUNKING.DEFAULT_SIZE,
         overlapSize: LLM_CONSTANTS.CHUNKING.DEFAULT_OVERLAP,
@@ -293,3 +293,11 @@ export async function semanticChunking(
 
     return chunks;
 }
+
+export interface NoteChunk {
+    noteId: string;
+    title: string;
+    content: string;
+    type?: string;
+    metadata?: Record<string, unknown>;
+}
diff --git a/src/services/llm/context/modules/context_service.ts b/src/services/llm/context/modules/context_service.ts
index 87d97c19a..c831af597 100644
--- a/src/services/llm/context/modules/context_service.ts
+++ b/src/services/llm/context/modules/context_service.ts
@@ -8,6 +8,8 @@ import aiServiceManager from '../../ai_service_manager.js';
 import { ContextExtractor } from '../index.js';
 import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
 import becca from '../../../../becca/becca.js';
+import type { NoteSearchResult } from '../../interfaces/context_interfaces.js';
+import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
 
 /**
  * Main context service that integrates all context-related functionality
@@ -73,10 +75,10 @@ export class ContextService {
      */
     async processQuery(
         userQuestion: string,
-        llmService: any,
+        llmService: LLMServiceInterface,
         contextNoteId: string | null = null,
         showThinking: boolean = false
-    ) {
+    ): Promise<{ context: string; sources: NoteSearchResult[]; thinking?: string }> {
         log.info(`Processing query with: question="${userQuestion.substring(0, 50)}...", noteId=${contextNoteId}, showThinking=${showThinking}`);
 
         if (!this.initialized) {
@@ -87,8 +89,8 @@ export class ContextService {
                 // Return a fallback response if initialization fails
                 return {
                     context: CONTEXT_PROMPTS.NO_NOTES_CONTEXT,
-                    notes: [],
-                    queries: [userQuestion]
+                    sources: [],
+                    thinking: undefined
                 };
             }
         }
@@ -105,10 +107,10 @@ export class ContextService {
             log.info(`Generated search queries: ${JSON.stringify(searchQueries)}`);
 
             // Step 2: Find relevant notes using multi-query approach
-            let relevantNotes: any[] = [];
+            let relevantNotes: NoteSearchResult[] = [];
             try {
                 // Find notes for each query and combine results
-                const allResults: Map<string, any> = new Map();
+                const allResults: Map<string, NoteSearchResult> = new Map();
 
                 for (const query of searchQueries) {
                     const results = await semanticSearch.findRelevantNotes(
@@ -124,7 +126,7 @@ export class ContextService {
                         } else {
                             // If note already exists, update similarity to max of both values
                             const existing = allResults.get(result.noteId);
-                            if (result.similarity > existing.similarity) {
+                            if (existing && result.similarity > existing.similarity) {
                                 existing.similarity = result.similarity;
                                 allResults.set(result.noteId, existing);
                             }
@@ -186,15 +188,15 @@ export class ContextService {
 
             return {
                 context: enhancedContext,
-                notes: relevantNotes,
-                queries: searchQueries
+                sources: relevantNotes,
+                thinking: showThinking ? this.summarizeContextStructure(enhancedContext) : undefined
             };
         } catch (error) {
             log.error(`Error processing query: ${error}`);
             return {
                 context: CONTEXT_PROMPTS.NO_NOTES_CONTEXT,
-                notes: [],
-                queries: [userQuestion]
+                sources: [],
+                thinking: undefined
             };
         }
     }
@@ -212,7 +214,7 @@ export class ContextService {
         noteId: string,
         query: string,
         showThinking: boolean = false,
-        relevantNotes: Array<any> = []
+        relevantNotes: NoteSearchResult[] = []
     ): Promise<string> {
         try {
             log.info(`Building enhanced agent tools context for query: "${query.substring(0, 50)}...", noteId=${noteId}, showThinking=${showThinking}`);
@@ -391,7 +393,7 @@ export class ContextService {
 
             // Combine the notes from both searches - the initial relevantNotes and from vector search
             // Start with a Map to deduplicate by noteId
-            const allNotes = new Map<string, any>();
+            const allNotes = new Map<string, NoteSearchResult>();
 
             // Add notes from the initial search in processQuery (relevantNotes parameter)
             if (relevantNotes && relevantNotes.length > 0) {
@@ -409,7 +411,10 @@ export class ContextService {
                 log.info(`Adding ${vectorSearchNotes.length} notes from vector search to combined results`);
                 for (const note of vectorSearchNotes) {
                     // If note already exists, keep the one with higher similarity
-                    if (!allNotes.has(note.noteId) || note.similarity > allNotes.get(note.noteId).similarity) {
+                    const existing = allNotes.get(note.noteId);
+                    if (existing && note.similarity > existing.similarity) {
+                        existing.similarity = note.similarity;
+                    } else {
                         allNotes.set(note.noteId, note);
                     }
                 }
@@ -831,7 +836,7 @@ export class ContextService {
             }
 
             // Get embeddings for the query and all chunks
-            const queryEmbedding = await provider.createEmbedding(query);
+            const queryEmbedding = await provider.generateEmbeddings(query);
 
             // Process chunks in smaller batches to avoid overwhelming the provider
             const batchSize = 5;
@@ -840,7 +845,7 @@ export class ContextService {
             for (let i = 0; i < chunks.length; i += batchSize) {
                 const batch = chunks.slice(i, i + batchSize);
                 const batchEmbeddings = await Promise.all(
-                    batch.map(chunk => provider.createEmbedding(chunk))
+                    batch.map(chunk => provider.generateEmbeddings(chunk))
                 );
                 chunkEmbeddings.push(...batchEmbeddings);
             }
@@ -848,7 +853,8 @@ export class ContextService {
 
             // Calculate similarity between query and each chunk
             const similarities: Array<{index: number, similarity: number, content: string}> = chunkEmbeddings.map((embedding, index) => {
-                const similarity = provider.calculateSimilarity(queryEmbedding, embedding);
+                // Calculate cosine similarity manually since the method doesn't exist
+                const similarity = this.calculateCosineSimilarity(queryEmbedding, embedding);
                 return { index, similarity, content: chunks[index] };
             });
 
@@ -891,6 +897,28 @@ export class ContextService {
             return content.substring(0, maxChars) + '...';
         }
     }
+
+    /**
+     * Calculate cosine similarity between two vectors
+     * @param vec1 - First vector
+     * @param vec2 - Second vector
+     * @returns Cosine similarity between the two vectors
+     */
+    private calculateCosineSimilarity(vec1: number[], vec2: number[]): number {
+        let dotProduct = 0;
+        let norm1 = 0;
+        let norm2 = 0;
+
+        for (let i = 0; i < vec1.length; i++) {
+            dotProduct += vec1[i] * vec2[i];
+            norm1 += vec1[i] * vec1[i];
+            norm2 += vec2[i] * vec2[i];
+        }
+
+        const magnitude = Math.sqrt(norm1) * Math.sqrt(norm2);
+        if (magnitude === 0) return 0;
+        return dotProduct / magnitude;
+    }
 }
 
 // Export singleton instance
diff --git a/src/services/llm/context/modules/query_enhancer.ts b/src/services/llm/context/modules/query_enhancer.ts
index 72a2a6639..56453675c 100644
--- a/src/services/llm/context/modules/query_enhancer.ts
+++ b/src/services/llm/context/modules/query_enhancer.ts
@@ -2,11 +2,13 @@ import log from '../../../log.js';
 import cacheManager from './cache_manager.js';
 import type { Message } from '../../ai_interface.js';
 import { CONTEXT_PROMPTS } from '../../constants/llm_prompt_constants.js';
+import type { LLMServiceInterface } from '../../interfaces/agent_tool_interfaces.js';
+import type { IQueryEnhancer } from '../../interfaces/context_interfaces.js';
 
 /**
  * Provides utilities for enhancing queries and generating search queries
  */
-export class QueryEnhancer {
+export class QueryEnhancer implements IQueryEnhancer {
     // Use the centralized query enhancer prompt
     private metaPrompt = CONTEXT_PROMPTS.QUERY_ENHANCER;
 
@@ -17,11 +19,15 @@ export class QueryEnhancer {
      * @param llmService - The LLM service to use for generating queries
      * @returns Array of search queries
     */
-    async generateSearchQueries(userQuestion: string, llmService: any): Promise<string[]> {
+    async generateSearchQueries(userQuestion: string, llmService: LLMServiceInterface): Promise<string[]> {
+        if (!userQuestion || userQuestion.trim() === '') {
+            return []; // Return empty array for empty input
+        }
+
         try {
-            // Check cache first
-            const cached = cacheManager.getQueryResults(`searchQueries:${userQuestion}`);
-            if (cached) {
+            // Check cache with proper type checking
+            const cached = cacheManager.getQueryResults<string[]>(`searchQueries:${userQuestion}`);
+            if (cached && Array.isArray(cached)) {
                 return cached;
             }
 
@@ -120,7 +126,6 @@ export class QueryEnhancer {
         } catch (error: unknown) {
             const errorMessage = error instanceof Error ? error.message : String(error);
             log.error(`Error generating search queries: ${errorMessage}`);
-            // Fallback to just using the original question
             return [userQuestion];
         }
     }
diff --git a/src/services/llm/context_service.ts b/src/services/llm/context_service.ts
index 65b7d175f..4370cf731 100644
--- a/src/services/llm/context_service.ts
+++ b/src/services/llm/context_service.ts
@@ -8,6 +8,7 @@ import log from '../log.js';
 import contextService from './context/modules/context_service.js';
 import { ContextExtractor } from './context/index.js';
+import type { NoteSearchResult } from './interfaces/context_interfaces.js';
 
 /**
  * Main Context Service for Trilium Notes
@@ -84,7 +85,7 @@ class TriliumContextService {
      * @param query - The original user query
      * @returns Formatted context string
     */
-    async buildContextFromNotes(sources: any[], query: string): Promise<string> {
+    async buildContextFromNotes(sources: NoteSearchResult[], query: string): Promise<string> {
         const provider = await (await import('./context/modules/provider_manager.js')).default.getPreferredEmbeddingProvider();
         const providerId = provider?.name || 'default';
         return (await import('./context/modules/context_formatter.js')).default.buildContextFromNotes(sources, query, providerId);
diff --git a/src/services/llm/embeddings/base_embeddings.ts b/src/services/llm/embeddings/base_embeddings.ts
index 3e1c81143..b8c17220f 100644
--- a/src/services/llm/embeddings/base_embeddings.ts
+++ b/src/services/llm/embeddings/base_embeddings.ts
@@ -1,23 +1,48 @@
-import type { EmbeddingProvider, EmbeddingConfig, NoteEmbeddingContext } from './embeddings_interface.js';
 import { NormalizationStatus } from './embeddings_interface.js';
+import type { NoteEmbeddingContext } from './embeddings_interface.js';
 import log from "../../log.js";
 import { LLM_CONSTANTS } from "../../../routes/api/llm.js";
 import options from "../../options.js";
+import { isBatchSizeError as checkBatchSizeError } from '../interfaces/error_interfaces.js';
+import type { EmbeddingModelInfo } from '../interfaces/embedding_interfaces.js';
+
+export interface EmbeddingConfig {
+    model: string;
+    dimension: number;
+    type: string;
+    apiKey?: string;
+    baseUrl?: string;
+    batchSize?: number;
+    contextWidth?: number;
+    normalizationStatus?: NormalizationStatus;
+}
 
 /**
- * Base class that implements common functionality for embedding providers
+ * Base class for embedding providers that implements common functionality
 */
-export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
-    name: string = "base";
-    protected config: EmbeddingConfig;
+export abstract class BaseEmbeddingProvider {
+    protected model: string;
+    protected dimension: number;
+    protected type: string;
+    protected maxBatchSize: number = 100;
     protected apiKey?: string;
     protected baseUrl: string;
-    protected modelInfoCache = new Map<string, any>();
+    protected name: string = 'base';
+    protected modelInfoCache = new Map<string, EmbeddingModelInfo>();
+    protected config: EmbeddingConfig;
 
     constructor(config: EmbeddingConfig) {
-        this.config = config;
+        this.model = config.model;
+        this.dimension = config.dimension;
+        this.type = config.type;
         this.apiKey = config.apiKey;
-        this.baseUrl = config.baseUrl || "";
+        this.baseUrl = config.baseUrl || '';
+        this.config = config;
+
+        // If batch size is specified, use it as maxBatchSize
+        if (config.batchSize) {
+            this.maxBatchSize = config.batchSize;
+        }
     }
 
     getConfig(): EmbeddingConfig {
@@ -79,12 +104,12 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
      * Process a batch of texts with adaptive handling
      * This method will try to process the batch and reduce batch size if encountering errors
     */
-    protected async processWithAdaptiveBatch<T>(
+    protected async processWithAdaptiveBatch<T, R>(
         items: T[],
-        processFn: (batch: T[]) => Promise<any[]>,
-        isBatchSizeError: (error: any) => boolean
-    ): Promise<any[]> {
-        const results: any[] = [];
+        processFn: (batch: T[]) => Promise<R[]>,
+        isBatchSizeError: (error: unknown) => boolean
+    ): Promise<R[]> {
+        const results: R[] = [];
         const failures: { index: number, error: string }[] = [];
         let currentBatchSize = await this.getBatchSize();
         let lastError: Error | null = null;
@@ -99,9 +124,9 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
                 results.push(...batchResults);
                 i += batch.length;
             }
-            catch (error: any) {
-                lastError = error;
-                const errorMessage = error.message || 'Unknown error';
+            catch (error) {
+                lastError = error as Error;
+                const errorMessage = (lastError as Error).message || 'Unknown error';
 
                 // Check if this is a batch size related error
                 if (isBatchSizeError(error) && currentBatchSize > 1) {
@@ -142,17 +167,8 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
      * Detect if an error is related to batch size limits
      * Override in provider-specific implementations
     */
-    protected isBatchSizeError(error: any): boolean {
-        const errorMessage = error?.message || '';
-        const batchSizeErrorPatterns = [
-            'batch size', 'too many items', 'too many inputs',
-            'input too large', 'payload too large', 'context length',
-            'token limit', 'rate limit', 'request too large'
-        ];
-
-        return batchSizeErrorPatterns.some(pattern =>
-            errorMessage.toLowerCase().includes(pattern.toLowerCase())
-        );
+    protected isBatchSizeError(error: unknown): boolean {
+        return checkBatchSizeError(error);
     }
 
     /**
@@ -173,11 +189,11 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
                     );
                     return batchResults;
                 },
-                this.isBatchSizeError
+                this.isBatchSizeError.bind(this)
             );
         }
-        catch (error: any) {
-            const errorMessage = error.message || "Unknown error";
+        catch (error) {
+            const errorMessage = (error as Error).message || "Unknown error";
             log.error(`Batch embedding error for provider ${this.name}: ${errorMessage}`);
             throw new Error(`${this.name} batch embedding error: ${errorMessage}`);
         }
@@ -208,11 +224,11 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
                     );
                     return batchResults;
                 },
-                this.isBatchSizeError
+                this.isBatchSizeError.bind(this)
             );
         }
-        catch (error: any) {
-            const errorMessage = error.message || "Unknown error";
+        catch (error) {
+            const errorMessage = (error as Error).message || "Unknown error";
             log.error(`Batch note embedding error for provider ${this.name}: ${errorMessage}`);
             throw new Error(`${this.name} batch note embedding error: ${errorMessage}`);
         }
@@ -357,4 +373,66 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
 
         return result;
     }
+
+    /**
+     * Process a batch of items with automatic retries and batch size adjustment
+     */
+    protected async processBatchWithRetries<T>(
+        items: T[],
+        processFn: (batch: T[]) => Promise<Float32Array[]>,
+        isBatchSizeError: (error: unknown) => boolean = this.isBatchSizeError.bind(this)
+    ): Promise<Float32Array[]> {
+        const results: Float32Array[] = [];
+        const failures: { index: number, error: string }[] = [];
+        let currentBatchSize = await this.getBatchSize();
+        let lastError: Error | null = null;
+
+        // Process items in batches
+        for (let i = 0; i < items.length;) {
+            const batch = items.slice(i, i + currentBatchSize);
+
+            try {
+                // Process the current batch
+                const batchResults = await processFn(batch);
+                results.push(...batchResults);
+                i += batch.length;
+            }
+            catch (error) {
+                lastError = error as Error;
+                const errorMessage = lastError.message || 'Unknown error';
+
+                // Check if this is a batch size related error
+                if (isBatchSizeError(error) && currentBatchSize > 1) {
+                    // Reduce batch size and retry
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Batch size error detected, reducing batch size from ${currentBatchSize} to ${newBatchSize}: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+                else if (currentBatchSize === 1) {
+                    // If we're already at batch size 1, we can't reduce further, so log the error and skip this item
+                    console.error(`Error processing item at index ${i} with batch size 1: ${errorMessage}`);
+                    failures.push({ index: i, error: errorMessage });
+                    i++; // Move to the next item
+                }
+                else {
+                    // For other errors, retry with a smaller batch size as a precaution
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Error processing batch, reducing batch size from ${currentBatchSize} to ${newBatchSize} as a precaution: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+            }
+        }
+
+        // If all items failed and we have a last error, throw it
+        if (results.length === 0 && failures.length > 0 && lastError) {
+            throw lastError;
+        }
+
+        // If some items failed but others succeeded, log the summary
+        if (failures.length > 0) {
+            console.warn(`Processed ${results.length} items successfully, but ${failures.length} items failed`);
+        }
+
+        return results;
+    }
 }
diff --git a/src/services/llm/interfaces/context_interfaces.ts b/src/services/llm/interfaces/context_interfaces.ts
index 9d6fdef49..b82b6a77b 100644
--- a/src/services/llm/interfaces/context_interfaces.ts
+++ b/src/services/llm/interfaces/context_interfaces.ts
@@ -32,7 +32,7 @@ export interface ICacheManager {
 export interface NoteSearchResult {
     noteId: string;
     title: string;
-    content?: string;
+    content?: string | null;
     type?: string;
     mime?: string;
     similarity: number;
diff --git a/src/services/llm/interfaces/embedding_interfaces.ts b/src/services/llm/interfaces/embedding_interfaces.ts
new file mode 100644
index 000000000..eaecd4cf9
--- /dev/null
+++ b/src/services/llm/interfaces/embedding_interfaces.ts
@@ -0,0 +1,108 @@
+/**
+ * Interface for embedding provider configuration
+ */
+export interface EmbeddingProviderConfig {
+    name: string;
+    model: string;
+    dimension: number;
+    type: 'float32' | 'int8' | 'uint8' | 'float16';
+    enabled?: boolean;
+    priority?: number;
+    baseUrl?: string;
+    apiKey?: string;
+    contextWidth?: number;
+    batchSize?: number;
+}
+
+/**
+ * Interface for embedding model information
+ */
+export interface EmbeddingModelInfo {
+    name: string;
+    dimension: number;
+    contextWidth?: number;
+    maxBatchSize?: number;
+    tokenizer?: string;
+    type: 'float32' | 'int8' | 'uint8' | 'float16';
+}
+
+/**
+ * Interface for embedding provider
+ */
+export interface EmbeddingProvider {
+    getName(): string;
+    getModel(): string;
+    getDimension(): number;
+    getType(): 'float32' | 'int8' | 'uint8' | 'float16';
+    isEnabled(): boolean;
+    getPriority(): number;
+    getMaxBatchSize(): number;
+    generateEmbedding(text: string): Promise<Float32Array>;
+    generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]>;
+    initialize(): Promise<void>;
+}
+
+/**
+ * Interface for embedding process result
+ */
+export interface EmbeddingProcessResult {
+    noteId: string;
+    title: string;
+    success: boolean;
+    message?: string;
+    error?: Error;
+    chunks?: number;
+}
+
+/**
+ * Interface for embedding queue item
+ */
+export interface EmbeddingQueueItem {
+    id: number;
+    noteId: string;
+    status: 'pending' | 'processing' | 'completed' | 'failed' | 'retrying';
+    provider: string;
+    model: string;
+    dimension: number;
+    type: string;
+    attempts: number;
+    lastAttempt: string | null;
+    dateCreated: string;
+    dateCompleted: string | null;
+    error: string | null;
+    chunks: number;
+}
+
+/**
+ * Interface for embedding batch processing
+ */
+export interface EmbeddingBatch {
+    texts: string[];
+    noteIds: string[];
+    indexes: number[];
+}
+
+/**
+ * Interface for embedding search result
+ */
+export interface EmbeddingSearchResult {
+    noteId: string;
+    similarity: number;
+    title?: string;
+    content?: string;
+    parentId?: string;
+    parentTitle?: string;
+    dateCreated?: string;
+    dateModified?: string;
+}
+
+/**
+ * Interface for embedding chunk
+ */
+export interface EmbeddingChunk {
+    id: number;
+    noteId: string;
+    content: string;
+    embedding: Float32Array | Int8Array | Uint8Array;
+    metadata?: Record<string, unknown>;
+}
diff --git a/src/services/llm/interfaces/error_interfaces.ts b/src/services/llm/interfaces/error_interfaces.ts
new file mode 100644
index 000000000..542b497f5
--- /dev/null
+++ b/src/services/llm/interfaces/error_interfaces.ts
@@ -0,0 +1,78 @@
+/**
+ * Standard error interface for LLM services
+ */
+export interface LLMServiceError extends Error {
+    message: string;
+    name: string;
+    code?: string;
+    status?: number;
+    cause?: unknown;
+    stack?: string;
+}
+
+/**
+ * Provider-specific error interface for OpenAI
+ */
+export interface OpenAIError extends LLMServiceError {
+    status: number;
+    headers?: Record<string, string>;
+    type?: string;
+    code?: string;
+    param?: string;
+}
+
+/**
+ * Provider-specific error interface for Anthropic
+ */
+export interface AnthropicError extends LLMServiceError {
+    status: number;
+    type?: string;
+}
+
+/**
+ * Provider-specific error interface for Ollama
+ */
+export interface OllamaError extends LLMServiceError {
+    code?: string;
+}
+
+/**
+ * Embedding-specific error interface
+ */
+export interface EmbeddingError extends LLMServiceError {
+    provider: string;
+    model?: string;
+    batchSize?: number;
+    isRetryable: boolean;
+}
+
+/**
+ * Guard function to check if an error is a specific type of error
+ */
+export function isLLMServiceError(error: unknown): error is LLMServiceError {
+    return (
+        typeof error === 'object' &&
+        error !== null &&
+        'message' in error &&
+        typeof (error as LLMServiceError).message === 'string'
+    );
+}
+
+/**
+ * Guard function to check if an error is a batch size error
+ */
+export function isBatchSizeError(error: unknown): boolean {
+    if (!isLLMServiceError(error)) {
+        return false;
+    }
+
+    const errorMessage = error.message.toLowerCase();
+    return (
+        errorMessage.includes('batch size') ||
+        errorMessage.includes('too many items') ||
+        errorMessage.includes('too many inputs') ||
+        errorMessage.includes('context length') ||
+        errorMessage.includes('token limit') ||
+        (error.code !== undefined && ['context_length_exceeded', 'token_limit_exceeded'].includes(error.code))
+    );
+}
diff --git a/src/services/llm/providers/anthropic_service.ts b/src/services/llm/providers/anthropic_service.ts
index 83d2f9895..409987ad1 100644
--- a/src/services/llm/providers/anthropic_service.ts
+++ b/src/services/llm/providers/anthropic_service.ts
@@ -3,6 +3,11 @@ import { BaseAIService } from '../base_ai_service.js';
 import type { ChatCompletionOptions, ChatResponse, Message } from '../ai_interface.js';
 import { PROVIDER_CONSTANTS } from '../constants/provider_constants.js';
 
+interface AnthropicMessage {
+    role: string;
+    content: string;
+}
+
 export class AnthropicService extends BaseAIService {
     // Map of simplified model names to full model names with versions
     private static MODEL_MAPPING: Record<string, string> = {
@@ -87,25 +92,31 @@ export class AnthropicService extends BaseAIService {
         }
     }
 
-    private formatMessages(messages: Message[], systemPrompt: string): { messages: any[], system: string } {
-        // Extract system messages
-        const systemMessages = messages.filter(m => m.role === 'system');
-        const nonSystemMessages = messages.filter(m => m.role !== 'system');
+    /**
+     * Format messages for the Anthropic API
+     */
+    private formatMessages(messages: Message[], systemPrompt: string): { messages: AnthropicMessage[], system: string } {
+        const formattedMessages: AnthropicMessage[] = [];
 
-        // Combine all system messages with our default
-        const combinedSystemPrompt = [systemPrompt]
-            .concat(systemMessages.map(m => m.content))
-            .join('\n\n');
+        // Extract the system message if present
+        let sysPrompt = systemPrompt;
 
-        // Format remaining messages for Anthropic's API
-        const formattedMessages = nonSystemMessages.map(m => ({
-            role: m.role === 'user' ? 'user' : 'assistant',
-            content: m.content
-        }));
+        // Process each message
+        for (const msg of messages) {
+            if (msg.role === 'system') {
+                // Anthropic handles system messages separately
+                sysPrompt = msg.content;
+            } else {
+                formattedMessages.push({
+                    role: msg.role,
+                    content: msg.content
+                });
+            }
+        }
 
         return {
             messages: formattedMessages,
-            system: combinedSystemPrompt
+            system: sysPrompt
         };
     }
 }
diff --git a/src/services/llm/providers/ollama_service.ts b/src/services/llm/providers/ollama_service.ts
index 6a18fd51e..61d40db39 100644
--- a/src/services/llm/providers/ollama_service.ts
+++ b/src/services/llm/providers/ollama_service.ts
@@ -3,6 +3,11 @@ import { BaseAIService } from '../base_ai_service.js';
 import type { ChatCompletionOptions, ChatResponse, Message } from '../ai_interface.js';
 import { PROVIDER_CONSTANTS } from '../constants/provider_constants.js';
 
+interface OllamaMessage {
+    role: string;
+    content: string;
+}
+
 export class OllamaService extends BaseAIService {
     constructor() {
         super('Ollama');
     }
@@ -282,42 +287,29 @@ export class OllamaService extends BaseAIService {
         }
     }
 
-    private formatMessages(messages: Message[], systemPrompt: string): any[] {
-        console.log("Input messages for formatting:", JSON.stringify(messages, null, 2));
+    /**
+     * Format messages for the Ollama API
+     */
+    private formatMessages(messages: Message[], systemPrompt: string): OllamaMessage[] {
+        const formattedMessages: OllamaMessage[] = [];
 
-        // Check if there are any messages with empty content
-        const emptyMessages = messages.filter(msg => !msg.content || msg.content === "Empty message");
-        if (emptyMessages.length > 0) {
-            console.warn("Found messages with empty content:", emptyMessages);
-        }
-
-        // Add system message if it doesn't exist
-        const hasSystemMessage = messages.some(m => m.role === 'system');
-        let resultMessages = [...messages];
-
-        if (!hasSystemMessage && systemPrompt) {
-            resultMessages.unshift({
+        // Add system message if provided
+        if (systemPrompt) {
+            formattedMessages.push({
                 role: 'system',
                 content: systemPrompt
             });
         }
 
-        // Validate each message has content
-        resultMessages = resultMessages.map(msg => {
-            // Ensure each message has a valid content
-            if (!msg.content || typeof msg.content !== 'string') {
-                console.warn(`Message with role ${msg.role} has invalid content:`, msg.content);
-                return {
-                    ...msg,
-                    content: msg.content || "Empty message"
-                };
-            }
-            return msg;
-        });
+        // Add all messages
+        for (const msg of messages) {
+            // Ollama's API accepts 'user', 'assistant', and 'system' roles
+            formattedMessages.push({
+                role: msg.role,
+                content: msg.content
+            });
+        }
 
-        console.log("Formatted messages for Ollama:", JSON.stringify(resultMessages, null, 2));
-
-        // Ollama uses the same format as OpenAI for messages
-        return resultMessages;
+        return formattedMessages;
     }
 }
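
Reviewer note (illustrative, not part of the patch): a minimal sketch of how the new isBatchSizeError guard from error_interfaces.ts might be exercised by a caller that halves a failing batch, in the spirit of processWithAdaptiveBatch above. The embedBatch function and its error message are hypothetical stand-ins, not code from this repository.

import { isBatchSizeError } from './src/services/llm/interfaces/error_interfaces.js';

// Hypothetical embedding call that rejects when the batch is too large.
async function embedBatch(texts: string[]): Promise<Float32Array[]> {
    if (texts.length > 8) {
        throw new Error('too many inputs: reduce batch size');
    }
    return texts.map(() => new Float32Array(4));
}

// Split-and-retry loop: shrink the batch only for batch-size errors, rethrow anything else.
async function embedAll(texts: string[]): Promise<Float32Array[]> {
    let batchSize = texts.length || 1;
    const results: Float32Array[] = [];
    for (let i = 0; i < texts.length;) {
        const batch = texts.slice(i, i + batchSize);
        try {
            results.push(...await embedBatch(batch));
            i += batch.length;
        } catch (error) {
            // The guard narrows `unknown` before the message is inspected.
            if (isBatchSizeError(error) && batchSize > 1) {
                batchSize = Math.max(1, Math.floor(batchSize / 2));
            } else {
                throw error;
            }
        }
    }
    return results;
}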