diff --git a/package-lock.json b/package-lock.json
index 9e0539ff3..aa1e86826 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -69,6 +69,7 @@
         "normalize-strings": "1.1.1",
         "normalize.css": "8.0.1",
         "ollama": "0.5.14",
+        "openai": "4.93.0",
         "rand-token": "1.0.1",
         "safe-compare": "1.1.4",
         "sanitize-filename": "1.6.3",
@@ -16035,6 +16036,51 @@
       "dev": true,
       "license": "MIT"
     },
+    "node_modules/openai": {
+      "version": "4.93.0",
+      "resolved": "https://registry.npmjs.org/openai/-/openai-4.93.0.tgz",
+      "integrity": "sha512-2kONcISbThKLfm7T9paVzg+QCE1FOZtNMMUfXyXckUAoXRRS/mTP89JSDHPMp8uM5s0bz28RISbvQjArD6mgUQ==",
+      "license": "Apache-2.0",
+      "dependencies": {
+        "@types/node": "^18.11.18",
+        "@types/node-fetch": "^2.6.4",
+        "abort-controller": "^3.0.0",
+        "agentkeepalive": "^4.2.1",
+        "form-data-encoder": "1.7.2",
+        "formdata-node": "^4.3.2",
+        "node-fetch": "^2.6.7"
+      },
+      "bin": {
+        "openai": "bin/cli"
+      },
+      "peerDependencies": {
+        "ws": "^8.18.0",
+        "zod": "^3.23.8"
+      },
+      "peerDependenciesMeta": {
+        "ws": {
+          "optional": true
+        },
+        "zod": {
+          "optional": true
+        }
+      }
+    },
+    "node_modules/openai/node_modules/@types/node": {
+      "version": "18.19.86",
+      "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.86.tgz",
+      "integrity": "sha512-fifKayi175wLyKyc5qUfyENhQ1dCNI1UNjp653d8kuYcPQN5JhX3dGuP/XmvPTg/xRBn1VTLpbmi+H/Mr7tLfQ==",
+      "license": "MIT",
+      "dependencies": {
+        "undici-types": "~5.26.4"
+      }
+    },
+    "node_modules/openai/node_modules/undici-types": {
+      "version": "5.26.5",
+      "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
+      "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
+      "license": "MIT"
+    },
     "node_modules/openapi-types": {
       "version": "12.1.3",
       "resolved": "https://registry.npmjs.org/openapi-types/-/openapi-types-12.1.3.tgz",
diff --git a/package.json b/package.json
index 7d86e2cc6..8100c770f 100644
--- a/package.json
+++ b/package.json
@@ -131,6 +131,7 @@
         "normalize-strings": "1.1.1",
         "normalize.css": "8.0.1",
         "ollama": "0.5.14",
+        "openai": "4.93.0",
         "rand-token": "1.0.1",
         "safe-compare": "1.1.4",
         "sanitize-filename": "1.6.3",
diff --git a/src/routes/api/openai.ts b/src/routes/api/openai.ts
index 220264bfe..c78f183cd 100644
--- a/src/routes/api/openai.ts
+++ b/src/routes/api/openai.ts
@@ -1,7 +1,7 @@
-import axios from 'axios';
 import options from "../../services/options.js";
 import log from "../../services/log.js";
 import type { Request, Response } from "express";
+import OpenAI from "openai";
 
 /**
  * @swagger
@@ -69,39 +69,39 @@ async function listModels(req: Request, res: Response) {
             throw new Error('OpenAI API key is not configured');
         }
 
-        // Call OpenAI API to get models
-        const response = await axios.get(`${openaiBaseUrl}/models`, {
-            headers: {
-                'Content-Type': 'application/json',
-                'Authorization': `Bearer ${apiKey}`
-            },
-            timeout: 10000
+        // Initialize OpenAI client with the API key and base URL
+        const openai = new OpenAI({
+            apiKey,
+            baseURL: openaiBaseUrl
         });
 
+        // Call OpenAI API to get models using the SDK
+        const response = await openai.models.list();
+
         // Filter and categorize models
-        const allModels = response.data.data || [];
+        const allModels = response.data || [];
 
         // Separate models into chat models and embedding models
         const chatModels = allModels
-            .filter((model: any) =>
+            .filter((model) =>
                 // Include GPT models for chat
                 model.id.includes('gpt') ||
                 // Include Claude models via Azure OpenAI
                 model.id.includes('claude')
            )
-            .map((model: any) => ({
+            .map((model) => ({
                 id: model.id,
                 name: model.id,
                 type: 'chat'
             }));
 
         const embeddingModels = allModels
-            .filter((model: any) =>
+            .filter((model) =>
                 // Only include embedding-specific models
                 model.id.includes('embedding') ||
                 model.id.includes('embed')
             )
-            .map((model: any) => ({
+            .map((model) => ({
                 id: model.id,
                 name: model.id,
                 type: 'embedding'
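
Editor's note: the route above replaces a raw `GET /models` call with the SDK's paginated `models.list()`, whose `.data` property holds the `Model` objects directly. A minimal standalone sketch of that call shape — the helper name and filtering predicate are illustrative, not part of the diff:

```typescript
import OpenAI from "openai";

// Hypothetical helper mirroring how the route uses openai.models.list().
// models.list() returns a paginated page object; .data is the Model[] array.
async function listChatModels(apiKey: string, baseURL?: string) {
    const openai = new OpenAI({ apiKey, baseURL });
    const page = await openai.models.list();

    return page.data
        .filter((model) => model.id.includes("gpt") || model.id.includes("claude"))
        .map((model) => ({ id: model.id, name: model.id, type: "chat" as const }));
}
```
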
diff --git a/src/services/llm/embeddings/providers/openai.ts b/src/services/llm/embeddings/providers/openai.ts
index 69ed111c1..c48c0bf44 100644
--- a/src/services/llm/embeddings/providers/openai.ts
+++ b/src/services/llm/embeddings/providers/openai.ts
@@ -4,15 +4,30 @@ import type { EmbeddingConfig } from "../embeddings_interface.js";
 import { NormalizationStatus } from "../embeddings_interface.js";
 import { LLM_CONSTANTS } from "../../constants/provider_constants.js";
 import type { EmbeddingModelInfo } from "../../interfaces/embedding_interfaces.js";
+import OpenAI from "openai";
 
 /**
- * OpenAI embedding provider implementation
+ * OpenAI embedding provider implementation using the official SDK
  */
 export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     name = "openai";
+    private client: OpenAI | null = null;
 
     constructor(config: EmbeddingConfig) {
         super(config);
+        this.initClient();
+    }
+
+    /**
+     * Initialize the OpenAI client
+     */
+    private initClient() {
+        if (this.apiKey) {
+            this.client = new OpenAI({
+                apiKey: this.apiKey,
+                baseURL: this.baseUrl
+            });
+        }
     }
 
     /**
@@ -21,6 +36,11 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     async initialize(): Promise<void> {
         const modelName = this.config.model || "text-embedding-3-small";
         try {
+            // Initialize client if needed
+            if (!this.client && this.apiKey) {
+                this.initClient();
+            }
+
             // Detect model capabilities
             const modelInfo = await this.getModelInfo(modelName);
 
@@ -37,46 +57,35 @@
     /**
      * Fetch model information from the OpenAI API
      */
    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
-        if (!this.apiKey) {
+        if (!this.client) {
             return null;
         }
 
         try {
-            // First try to get model details from the models API
-            const response = await fetch(`${this.baseUrl}/models/${modelName}`, {
-                method: 'GET',
-                headers: {
-                    "Authorization": `Bearer ${this.apiKey}`,
-                    "Content-Type": "application/json"
-                },
-                signal: AbortSignal.timeout(10000)
-            });
-
-            if (!response.ok) {
-                throw new Error(`HTTP error! status: ${response.status}`);
-            }
+            // Get model details using the SDK
+            const model = await this.client.models.retrieve(modelName);
 
-            const data = await response.json();
-
-            if (data) {
+            if (model) {
                 // Different model families may have different ways of exposing context window
                 let contextWindow = 0;
                 let dimension = 0;
 
-                // Extract context window if available
-                if (data.context_window) {
-                    contextWindow = data.context_window;
-                } else if (data.limits && data.limits.context_window) {
-                    contextWindow = data.limits.context_window;
-                } else if (data.limits && data.limits.context_length) {
-                    contextWindow = data.limits.context_length;
+                // Extract context window if available from the response
+                const modelData = model as any;
+
+                if (modelData.context_window) {
+                    contextWindow = modelData.context_window;
+                } else if (modelData.limits && modelData.limits.context_window) {
+                    contextWindow = modelData.limits.context_window;
+                } else if (modelData.limits && modelData.limits.context_length) {
+                    contextWindow = modelData.limits.context_length;
                 }
 
                 // Extract embedding dimensions if available
-                if (data.dimensions) {
-                    dimension = data.dimensions;
-                } else if (data.embedding_dimension) {
-                    dimension = data.embedding_dimension;
+                if (modelData.dimensions) {
+                    dimension = modelData.dimensions;
+                } else if (modelData.embedding_dimension) {
+                    dimension = modelData.embedding_dimension;
                 }
 
                 // If we didn't get all the info, use defaults for missing values
@@ -188,27 +197,21 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
                 return new Float32Array(this.config.dimension);
             }
 
-            const response = await fetch(`${this.baseUrl}/embeddings`, {
-                method: 'POST',
-                headers: {
-                    "Content-Type": "application/json",
-                    "Authorization": `Bearer ${this.apiKey}`
-                },
-                body: JSON.stringify({
-                    input: text,
-                    model: this.config.model || "text-embedding-3-small",
-                    encoding_format: "float"
-                })
-            });
-
-            if (!response.ok) {
-                throw new Error(`HTTP error! status: ${response.status}`);
+            if (!this.client) {
+                this.initClient();
+                if (!this.client) {
+                    throw new Error("OpenAI client initialization failed");
+                }
             }
+
+            const response = await this.client.embeddings.create({
+                model: this.config.model || "text-embedding-3-small",
+                input: text,
+                encoding_format: "float"
+            });
 
-            const data = await response.json();
-
-            if (data && data.data && data.data[0] && data.data[0].embedding) {
-                return new Float32Array(data.data[0].embedding);
+            if (response && response.data && response.data[0] && response.data[0].embedding) {
+                return new Float32Array(response.data[0].embedding);
             } else {
                 throw new Error("Unexpected response structure from OpenAI API");
             }
@@ -243,30 +246,24 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
                 return [];
             }
 
-            const response = await fetch(`${this.baseUrl}/embeddings`, {
-                method: 'POST',
-                headers: {
-                    "Content-Type": "application/json",
-                    "Authorization": `Bearer ${this.apiKey}`
-                },
-                body: JSON.stringify({
-                    input: texts,
-                    model: this.config.model || "text-embedding-3-small",
-                    encoding_format: "float"
-                })
-            });
-
-            if (!response.ok) {
-                throw new Error(`HTTP error! status: ${response.status}`);
+            if (!this.client) {
+                this.initClient();
+                if (!this.client) {
+                    throw new Error("OpenAI client initialization failed");
+                }
             }
+
+            const response = await this.client.embeddings.create({
+                model: this.config.model || "text-embedding-3-small",
+                input: texts,
+                encoding_format: "float"
+            });
 
-            const data = await response.json();
-
-            if (data && data.data) {
+            if (response && response.data) {
                 // Sort the embeddings by index to ensure they match the input order
-                const sortedEmbeddings = data.data
-                    .sort((a: any, b: any) => a.index - b.index)
-                    .map((item: any) => new Float32Array(item.embedding));
+                const sortedEmbeddings = response.data
+                    .sort((a, b) => a.index - b.index)
+                    .map(item => new Float32Array(item.embedding));
 
                 return sortedEmbeddings;
             } else {
diff --git a/src/services/llm/interfaces/message_formatter.ts b/src/services/llm/interfaces/message_formatter.ts
index 634c70cd8..3ec387d0a 100644
--- a/src/services/llm/interfaces/message_formatter.ts
+++ b/src/services/llm/interfaces/message_formatter.ts
@@ -1,7 +1,6 @@
 import type { Message } from "../ai_interface.js";
 // These imports need to be added for the factory to work
 import { OpenAIMessageFormatter } from "../formatters/openai_formatter.js";
-import { AnthropicMessageFormatter } from "../formatters/anthropic_formatter.js";
 import { OllamaMessageFormatter } from "../formatters/ollama_formatter.js";
 
 /**
@@ -76,7 +75,8 @@ export class MessageFormatterFactory {
                 this.formatters[providerKey] = new OpenAIMessageFormatter();
                 break;
             case 'anthropic':
-                this.formatters[providerKey] = new AnthropicMessageFormatter();
+                console.warn('Anthropic formatter not available, using OpenAI formatter as fallback');
+                this.formatters[providerKey] = new OpenAIMessageFormatter();
                 break;
             case 'ollama':
                 this.formatters[providerKey] = new OllamaMessageFormatter();
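
Editor's note: the service diff below creates one SDK client per service instance and reuses it across requests instead of issuing raw `fetch` calls. A minimal sketch of that lazy, cached-client pattern, assuming the same constructor options the diff uses (the class name here is illustrative):

```typescript
import OpenAI from "openai";

// Sketch of the lazy-initialization pattern adopted in OpenAIService below.
class CachedOpenAIClient {
    private openai: OpenAI | null = null;

    getClient(apiKey: string, baseURL?: string): OpenAI {
        // Create the client once and reuse it on every subsequent call;
        // the SDK manages connection keep-alive internally.
        if (!this.openai) {
            this.openai = new OpenAI({ apiKey, baseURL });
        }
        return this.openai;
    }
}
```
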
diff --git a/src/services/llm/providers/openai_service.ts b/src/services/llm/providers/openai_service.ts
index 35e159a2f..15e68ca1d 100644
--- a/src/services/llm/providers/openai_service.ts
+++ b/src/services/llm/providers/openai_service.ts
@@ -1,11 +1,12 @@
 import options from '../../options.js';
 import { BaseAIService } from '../base_ai_service.js';
 import type { ChatCompletionOptions, ChatResponse, Message } from '../ai_interface.js';
-import { PROVIDER_CONSTANTS } from '../constants/provider_constants.js';
-import type { OpenAIOptions } from './provider_options.js';
 import { getOpenAIOptions } from './providers.js';
+import OpenAI from 'openai';
 
 export class OpenAIService extends BaseAIService {
+    private openai: OpenAI | null = null;
+
     constructor() {
         super('OpenAI');
     }
@@ -14,6 +15,16 @@ export class OpenAIService extends BaseAIService {
         return super.isAvailable() && !!options.getOption('openaiApiKey');
     }
 
+    private getClient(apiKey: string, baseUrl?: string): OpenAI {
+        if (!this.openai) {
+            this.openai = new OpenAI({
+                apiKey,
+                baseURL: baseUrl
+            });
+        }
+        return this.openai;
+    }
+
     async generateChatCompletion(messages: Message[], opts: ChatCompletionOptions = {}): Promise<ChatResponse> {
         if (!this.isAvailable()) {
             throw new Error('OpenAI service is not available. Check API key and AI settings.');
         }
@@ -21,6 +32,9 @@ export class OpenAIService extends BaseAIService {
 
         // Get provider-specific options from the central provider manager
         const providerOptions = getOpenAIOptions(opts);
+
+        // Initialize the OpenAI client
+        const client = this.getClient(providerOptions.apiKey, providerOptions.baseUrl);
 
         const systemPrompt = this.getSystemPrompt(providerOptions.systemPrompt || options.getOption('aiSystemPrompt'));
 
@@ -31,20 +45,10 @@ export class OpenAIService extends BaseAIService {
             : [{ role: 'system', content: systemPrompt }, ...messages];
 
         try {
-            // Fix endpoint construction - ensure we don't double up on /v1
-            const normalizedBaseUrl = providerOptions.baseUrl.replace(/\/+$/, '');
-            const endpoint = normalizedBaseUrl.includes('/v1')
-                ? `${normalizedBaseUrl}/chat/completions`
-                : `${normalizedBaseUrl}/v1/chat/completions`;
-
-            // Create request body directly from provider options
-            const requestBody: any = {
+            // Create params object for the OpenAI SDK
+            const params: OpenAI.Chat.ChatCompletionCreateParams = {
                 model: providerOptions.model,
-                messages: messagesWithSystem,
-            };
-
-            // Extract API parameters from provider options
-            const apiParams = {
+                messages: messagesWithSystem as OpenAI.Chat.ChatCompletionMessageParam[],
                 temperature: providerOptions.temperature,
                 max_tokens: providerOptions.max_tokens,
                 stream: providerOptions.stream,
                 top_p: providerOptions.top_p,
                 frequency_penalty: providerOptions.frequency_penalty,
                 presence_penalty: providerOptions.presence_penalty
             };
 
@@ -53,51 +57,138 @@ export class OpenAIService extends BaseAIService {
-
-
-            // Merge API parameters, filtering out undefined values
-            Object.entries(apiParams).forEach(([key, value]) => {
-                if (value !== undefined) {
-                    requestBody[key] = value;
-                }
-            });
-
             // Add tools if enabled
             if (providerOptions.enableTools && providerOptions.tools && providerOptions.tools.length > 0) {
-                requestBody.tools = providerOptions.tools;
+                params.tools = providerOptions.tools as OpenAI.Chat.ChatCompletionTool[];
             }
 
             if (providerOptions.tool_choice) {
-                requestBody.tool_choice = providerOptions.tool_choice;
+                params.tool_choice = providerOptions.tool_choice as OpenAI.Chat.ChatCompletionToolChoiceOption;
             }
 
-            const response = await fetch(endpoint, {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json',
-                    'Authorization': `Bearer ${providerOptions.apiKey}`
-                },
-                body: JSON.stringify(requestBody)
-            });
+            // If streaming is requested
+            if (providerOptions.stream) {
+                params.stream = true;
+
+                const stream = await client.chat.completions.create(params);
+                let fullText = '';
+
+                // If a direct callback is provided, use it
+                if (providerOptions.streamCallback) {
+                    // Process the stream with the callback
+                    try {
+                        // The stream is an AsyncIterable
+                        if (Symbol.asyncIterator in stream) {
+                            for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
+                                const content = chunk.choices[0]?.delta?.content || '';
+                                if (content) {
+                                    fullText += content;
+                                    await providerOptions.streamCallback(content, false, chunk);
+                                }
+
+                                // If this is the last chunk
+                                if (chunk.choices[0]?.finish_reason) {
+                                    await providerOptions.streamCallback('', true, chunk);
+                                }
+                            }
+                        } else {
+                            console.error('Stream is not iterable, falling back to non-streaming response');
+
+                            // If we get a non-streaming response somehow
+                            if ('choices' in stream) {
+                                const content = stream.choices[0]?.message?.content || '';
+                                fullText = content;
+                                if (providerOptions.streamCallback) {
+                                    await providerOptions.streamCallback(content, true, stream);
+                                }
+                            }
+                        }
+                    } catch (error) {
+                        console.error('Error processing stream:', error);
+                        throw error;
+                    }
+
+                    return {
+                        text: fullText,
+                        model: params.model,
+                        provider: this.getName(),
+                        usage: {} // Usage stats aren't available with streaming
+                    };
+                } else {
+                    // Use the more flexible stream interface
+                    return {
+                        text: '', // Initial empty text, will be filled by stream processing
+                        model: params.model,
+                        provider: this.getName(),
+                        usage: {}, // Usage stats aren't available with streaming
+                        stream: async (callback) => {
+                            let completeText = '';
+
+                            try {
+                                // The stream is an AsyncIterable
+                                if (Symbol.asyncIterator in stream) {
+                                    for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
+                                        const content = chunk.choices[0]?.delta?.content || '';
+                                        const isDone = !!chunk.choices[0]?.finish_reason;
+
+                                        if (content) {
+                                            completeText += content;
+                                        }
+
+                                        // Call the provided callback with the StreamChunk interface
+                                        await callback({
+                                            text: content,
+                                            done: isDone
+                                        });
+
+                                        if (isDone) {
+                                            break;
+                                        }
+                                    }
+                                } else {
+                                    console.warn('Stream is not iterable, falling back to non-streaming response');
+
+                                    // If we get a non-streaming response somehow
+                                    if ('choices' in stream) {
+                                        const content = stream.choices[0]?.message?.content || '';
+                                        completeText = content;
+                                        await callback({
+                                            text: content,
+                                            done: true
+                                        });
+                                    }
+                                }
+                            } catch (error) {
+                                console.error('Error processing stream:', error);
+                                throw error;
+                            }
+
+                            return completeText;
+                        }
+                    };
+                }
+            } else {
+                // Non-streaming response
+                params.stream = false;
+
+                const completion = await client.chat.completions.create(params);
+
+                if (!('choices' in completion)) {
+                    throw new Error('Unexpected response format from OpenAI API');
+                }
 
-            if (!response.ok) {
-                const errorBody = await response.text();
-                throw new Error(`OpenAI API error: ${response.status} ${response.statusText} - ${errorBody}`);
+                return {
+                    text: completion.choices[0].message.content || '',
+                    model: completion.model,
+                    provider: this.getName(),
+                    usage: {
+                        promptTokens: completion.usage?.prompt_tokens,
+                        completionTokens: completion.usage?.completion_tokens,
+                        totalTokens: completion.usage?.total_tokens
+                    },
+                    tool_calls: completion.choices[0].message.tool_calls
+                };
             }
-
-            const data = await response.json();
-
-            return {
-                text: data.choices[0].message.content,
-                model: data.model,
-                provider: this.getName(),
-                usage: {
-                    promptTokens: data.usage?.prompt_tokens,
-                    completionTokens: data.usage?.completion_tokens,
-                    totalTokens: data.usage?.total_tokens
-                },
-                tool_calls: data.choices[0].message.tool_calls
-            };
         } catch (error) {
             console.error('OpenAI service error:', error);
             throw error;
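
Editor's note: the options diff below threads the same `streamCallback` signature through all three provider option types. A small sketch of a callback matching that signature — the console sink is illustrative:

```typescript
// Matches the streamCallback shape added to the provider option interfaces:
// (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void
const streamCallback = async (
    text: string,
    isDone: boolean,
    originalChunk?: any // carries the raw SDK chunk for callers that need it
): Promise<void> => {
    if (text) {
        process.stdout.write(text); // each delta is appended as it arrives
    }
    if (isDone) {
        process.stdout.write("\n"); // the final invocation signals completion
    }
};
```
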
diff --git a/src/services/llm/providers/provider_options.ts b/src/services/llm/providers/provider_options.ts
index c765c2738..b83c6f99a 100644
--- a/src/services/llm/providers/provider_options.ts
+++ b/src/services/llm/providers/provider_options.ts
@@ -53,6 +53,8 @@ export interface OpenAIOptions extends ProviderConfig {
 
     // Internal control flags (not sent directly to API)
     enableTools?: boolean;
+    // Streaming callback handler
+    streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }
 
 /**
@@ -76,6 +78,8 @@ export interface AnthropicOptions extends ProviderConfig {
 
     // Internal parameters (not sent directly to API)
     formattedMessages?: { messages: any[], system: string };
+    // Streaming callback handler
+    streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }
 
 /**
@@ -105,6 +109,8 @@ export interface OllamaOptions extends ProviderConfig {
     preserveSystemPrompt?: boolean;
     expectsJsonResponse?: boolean;
     toolExecutionStatus?: any[];
+    // Streaming callback handler
+    streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }
 
 /**
@@ -134,6 +140,10 @@ export function createOpenAIOptions(
         // Internal configuration
         systemPrompt: opts.systemPrompt,
         enableTools: opts.enableTools,
+        // Pass through streaming callback
+        streamCallback: opts.streamCallback,
+        // Include provider metadata
+        providerMetadata: opts.providerMetadata,
     };
 }
 
@@ -164,6 +174,10 @@ export function createAnthropicOptions(
 
         // Internal configuration
         systemPrompt: opts.systemPrompt,
+        // Pass through streaming callback
+        streamCallback: opts.streamCallback,
+        // Include provider metadata
+        providerMetadata: opts.providerMetadata,
     };
 }
 
@@ -198,5 +212,9 @@ export function createOllamaOptions(
         preserveSystemPrompt: opts.preserveSystemPrompt,
         expectsJsonResponse: opts.expectsJsonResponse,
         toolExecutionStatus: opts.toolExecutionStatus,
+        // Pass through streaming callback
+        streamCallback: opts.streamCallback,
+        // Include provider metadata
+        providerMetadata: opts.providerMetadata,
     };
 }
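
Editor's note: taken together, a hypothetical end-to-end call through the new streaming path might look like the sketch below. It assumes `ChatCompletionOptions` exposes the new `stream`/`streamCallback` fields as wired up in this diff; the message content and the `main` wrapper are illustrative.

```typescript
import { OpenAIService } from "./src/services/llm/providers/openai_service.js";

async function main() {
    const service = new OpenAIService();

    // Deltas are pushed through streamCallback as they arrive; the promise
    // resolves with the aggregated response once streaming completes.
    const response = await service.generateChatCompletion(
        [{ role: "user", content: "Summarize my last note." }],
        {
            stream: true,
            streamCallback: async (text, isDone) => {
                if (text) process.stdout.write(text);
                if (isDone) process.stdout.write("\n[stream complete]\n");
            }
        }
    );

    console.log(`model=${response.model} provider=${response.provider}`);
}

main().catch(console.error);
```
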