rip out openai custom implementation in favor of sdk

Author: perf3ct
Date:   2025-04-09 21:16:29 +00:00
Parent: f71351db6a
Commit: 6fe2b87901
7 changed files with 288 additions and 135 deletions
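
The substance of the change in one sketch: hand-rolled HTTP calls are replaced by the official openai npm package, which owns auth headers, endpoint construction, and retries. A minimal before/after, with the model name and key handling chosen for illustration only (not from this commit):

    import OpenAI from "openai";

    // Before (removed in this commit): manual endpoint joining and auth headers
    // const res = await fetch(`${baseUrl}/v1/chat/completions`, {
    //     method: 'POST',
    //     headers: { 'Authorization': `Bearer ${apiKey}` },
    //     body: JSON.stringify({ model, messages })
    // });

    // After: the SDK client encapsulates the transport
    const client = new OpenAI({
        apiKey: process.env.OPENAI_API_KEY,  // illustrative; the app reads this from options
        baseURL: "https://api.openai.com/v1" // optional; defaults to the official endpoint
    });

    const completion = await client.chat.completions.create({
        model: "gpt-4o-mini",                // illustrative model name
        messages: [{ role: "user", content: "Hello" }]
    });
    console.log(completion.choices[0].message.content);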

File: package-lock.json (generated, 46 lines changed)

@@ -69,6 +69,7 @@
     "normalize-strings": "1.1.1",
     "normalize.css": "8.0.1",
     "ollama": "0.5.14",
+    "openai": "4.93.0",
     "rand-token": "1.0.1",
     "safe-compare": "1.1.4",
     "sanitize-filename": "1.6.3",
@@ -16035,6 +16036,51 @@
       "dev": true,
       "license": "MIT"
     },
+    "node_modules/openai": {
+      "version": "4.93.0",
+      "resolved": "https://registry.npmjs.org/openai/-/openai-4.93.0.tgz",
+      "integrity": "sha512-2kONcISbThKLfm7T9paVzg+QCE1FOZtNMMUfXyXckUAoXRRS/mTP89JSDHPMp8uM5s0bz28RISbvQjArD6mgUQ==",
+      "license": "Apache-2.0",
+      "dependencies": {
+        "@types/node": "^18.11.18",
+        "@types/node-fetch": "^2.6.4",
+        "abort-controller": "^3.0.0",
+        "agentkeepalive": "^4.2.1",
+        "form-data-encoder": "1.7.2",
+        "formdata-node": "^4.3.2",
+        "node-fetch": "^2.6.7"
+      },
+      "bin": {
+        "openai": "bin/cli"
+      },
+      "peerDependencies": {
+        "ws": "^8.18.0",
+        "zod": "^3.23.8"
+      },
+      "peerDependenciesMeta": {
+        "ws": {
+          "optional": true
+        },
+        "zod": {
+          "optional": true
+        }
+      }
+    },
+    "node_modules/openai/node_modules/@types/node": {
+      "version": "18.19.86",
+      "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.86.tgz",
+      "integrity": "sha512-fifKayi175wLyKyc5qUfyENhQ1dCNI1UNjp653d8kuYcPQN5JhX3dGuP/XmvPTg/xRBn1VTLpbmi+H/Mr7tLfQ==",
+      "license": "MIT",
+      "dependencies": {
+        "undici-types": "~5.26.4"
+      }
+    },
+    "node_modules/openai/node_modules/undici-types": {
+      "version": "5.26.5",
+      "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
+      "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
+      "license": "MIT"
+    },
     "node_modules/openapi-types": {
       "version": "12.1.3",
       "resolved": "https://registry.npmjs.org/openapi-types/-/openapi-types-12.1.3.tgz",

File: package.json

@@ -131,6 +131,7 @@
     "normalize-strings": "1.1.1",
     "normalize.css": "8.0.1",
     "ollama": "0.5.14",
+    "openai": "4.93.0",
    "rand-token": "1.0.1",
     "safe-compare": "1.1.4",
     "sanitize-filename": "1.6.3",

File: OpenAI model-listing route handler

@ -1,7 +1,7 @@
import axios from 'axios';
import options from "../../services/options.js"; import options from "../../services/options.js";
import log from "../../services/log.js"; import log from "../../services/log.js";
import type { Request, Response } from "express"; import type { Request, Response } from "express";
import OpenAI from "openai";
/** /**
* @swagger * @swagger
@ -69,39 +69,39 @@ async function listModels(req: Request, res: Response) {
throw new Error('OpenAI API key is not configured'); throw new Error('OpenAI API key is not configured');
} }
// Call OpenAI API to get models // Initialize OpenAI client with the API key and base URL
const response = await axios.get(`${openaiBaseUrl}/models`, { const openai = new OpenAI({
headers: { apiKey,
'Content-Type': 'application/json', baseURL: openaiBaseUrl
'Authorization': `Bearer ${apiKey}`
},
timeout: 10000
}); });
// Call OpenAI API to get models using the SDK
const response = await openai.models.list();
// Filter and categorize models // Filter and categorize models
const allModels = response.data.data || []; const allModels = response.data || [];
// Separate models into chat models and embedding models // Separate models into chat models and embedding models
const chatModels = allModels const chatModels = allModels
.filter((model: any) => .filter((model) =>
// Include GPT models for chat // Include GPT models for chat
model.id.includes('gpt') || model.id.includes('gpt') ||
// Include Claude models via Azure OpenAI // Include Claude models via Azure OpenAI
model.id.includes('claude') model.id.includes('claude')
) )
.map((model: any) => ({ .map((model) => ({
id: model.id, id: model.id,
name: model.id, name: model.id,
type: 'chat' type: 'chat'
})); }));
const embeddingModels = allModels const embeddingModels = allModels
.filter((model: any) => .filter((model) =>
// Only include embedding-specific models // Only include embedding-specific models
model.id.includes('embedding') || model.id.includes('embedding') ||
model.id.includes('embed') model.id.includes('embed')
) )
.map((model: any) => ({ .map((model) => ({
id: model.id, id: model.id,
name: model.id, name: model.id,
type: 'embedding' type: 'embedding'
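
For context, the same listing flow as a standalone sketch: models.list() returns a paginated result whose .data array replaces the raw response.data.data the axios call used to unwrap by hand.

    import OpenAI from "openai";

    const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

    const response = await openai.models.list();

    // The SDK already unwraps the HTTP envelope: response.data is Model[]
    const chatModels = response.data
        .filter((model) => model.id.includes('gpt') || model.id.includes('claude'))
        .map((model) => ({ id: model.id, name: model.id, type: 'chat' }));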

File: OpenAI embedding provider

@@ -4,15 +4,30 @@ import type { EmbeddingConfig } from "../embeddings_interface.js";
 import { NormalizationStatus } from "../embeddings_interface.js";
 import { LLM_CONSTANTS } from "../../constants/provider_constants.js";
 import type { EmbeddingModelInfo } from "../../interfaces/embedding_interfaces.js";
+import OpenAI from "openai";

 /**
- * OpenAI embedding provider implementation
+ * OpenAI embedding provider implementation using the official SDK
  */
 export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     name = "openai";
+    private client: OpenAI | null = null;

     constructor(config: EmbeddingConfig) {
         super(config);
+        this.initClient();
+    }
+
+    /**
+     * Initialize the OpenAI client
+     */
+    private initClient() {
+        if (this.apiKey) {
+            this.client = new OpenAI({
+                apiKey: this.apiKey,
+                baseURL: this.baseUrl
+            });
+        }
     }

     /**
@@ -21,6 +36,11 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
     async initialize(): Promise<void> {
         const modelName = this.config.model || "text-embedding-3-small";
         try {
+            // Initialize client if needed
+            if (!this.client && this.apiKey) {
+                this.initClient();
+            }
+
             // Detect model capabilities
             const modelInfo = await this.getModelInfo(modelName);
@@ -37,46 +57,35 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
      * Fetch model information from the OpenAI API
      */
     private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
-        if (!this.apiKey) {
+        if (!this.client) {
             return null;
         }

         try {
-            // First try to get model details from the models API
-            const response = await fetch(`${this.baseUrl}/models/${modelName}`, {
-                method: 'GET',
-                headers: {
-                    "Authorization": `Bearer ${this.apiKey}`,
-                    "Content-Type": "application/json"
-                },
-                signal: AbortSignal.timeout(10000)
-            });
-
-            if (!response.ok) {
-                throw new Error(`HTTP error! status: ${response.status}`);
-            }
-
-            const data = await response.json();
+            // Get model details using the SDK
+            const model = await this.client.models.retrieve(modelName);

-            if (data) {
+            if (model) {
                 // Different model families may have different ways of exposing context window
                 let contextWindow = 0;
                 let dimension = 0;

-                // Extract context window if available
-                if (data.context_window) {
-                    contextWindow = data.context_window;
-                } else if (data.limits && data.limits.context_window) {
-                    contextWindow = data.limits.context_window;
-                } else if (data.limits && data.limits.context_length) {
-                    contextWindow = data.limits.context_length;
+                // Extract context window if available from the response
+                const modelData = model as any;
+
+                if (modelData.context_window) {
+                    contextWindow = modelData.context_window;
+                } else if (modelData.limits && modelData.limits.context_window) {
+                    contextWindow = modelData.limits.context_window;
+                } else if (modelData.limits && modelData.limits.context_length) {
+                    contextWindow = modelData.limits.context_length;
                 }

                 // Extract embedding dimensions if available
-                if (data.dimensions) {
-                    dimension = data.dimensions;
-                } else if (data.embedding_dimension) {
-                    dimension = data.embedding_dimension;
+                if (modelData.dimensions) {
+                    dimension = modelData.dimensions;
+                } else if (modelData.embedding_dimension) {
+                    dimension = modelData.embedding_dimension;
                 }

                 // If we didn't get all the info, use defaults for missing values
@@ -188,27 +197,21 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
             return new Float32Array(this.config.dimension);
         }

-        const response = await fetch(`${this.baseUrl}/embeddings`, {
-            method: 'POST',
-            headers: {
-                "Content-Type": "application/json",
-                "Authorization": `Bearer ${this.apiKey}`
-            },
-            body: JSON.stringify({
-                input: text,
-                model: this.config.model || "text-embedding-3-small",
-                encoding_format: "float"
-            })
-        });
-
-        if (!response.ok) {
-            throw new Error(`HTTP error! status: ${response.status}`);
+        if (!this.client) {
+            this.initClient();
+            if (!this.client) {
+                throw new Error("OpenAI client initialization failed");
+            }
         }

-        const data = await response.json();
+        const response = await this.client.embeddings.create({
+            model: this.config.model || "text-embedding-3-small",
+            input: text,
+            encoding_format: "float"
+        });

-        if (data && data.data && data.data[0] && data.data[0].embedding) {
-            return new Float32Array(data.data[0].embedding);
+        if (response && response.data && response.data[0] && response.data[0].embedding) {
+            return new Float32Array(response.data[0].embedding);
         } else {
             throw new Error("Unexpected response structure from OpenAI API");
         }
@@ -243,30 +246,24 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
             return [];
         }

-        const response = await fetch(`${this.baseUrl}/embeddings`, {
-            method: 'POST',
-            headers: {
-                "Content-Type": "application/json",
-                "Authorization": `Bearer ${this.apiKey}`
-            },
-            body: JSON.stringify({
-                input: texts,
-                model: this.config.model || "text-embedding-3-small",
-                encoding_format: "float"
-            })
-        });
-
-        if (!response.ok) {
-            throw new Error(`HTTP error! status: ${response.status}`);
+        if (!this.client) {
+            this.initClient();
+            if (!this.client) {
+                throw new Error("OpenAI client initialization failed");
+            }
         }

-        const data = await response.json();
+        const response = await this.client.embeddings.create({
+            model: this.config.model || "text-embedding-3-small",
+            input: texts,
+            encoding_format: "float"
+        });

-        if (data && data.data) {
+        if (response && response.data) {
             // Sort the embeddings by index to ensure they match the input order
-            const sortedEmbeddings = data.data
-                .sort((a: any, b: any) => a.index - b.index)
-                .map((item: any) => new Float32Array(item.embedding));
+            const sortedEmbeddings = response.data
+                .sort((a, b) => a.index - b.index)
+                .map(item => new Float32Array(item.embedding));

             return sortedEmbeddings;
         } else {
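
Both embedding paths now reduce to client.embeddings.create. A self-contained sketch of the single and batch calls (the model name matches the provider's default; inputs are illustrative):

    import OpenAI from "openai";

    const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

    // Single text: with encoding_format "float", data[0].embedding is a number[]
    const single = await client.embeddings.create({
        model: "text-embedding-3-small",
        input: "hello world",
        encoding_format: "float"
    });
    const vector = new Float32Array(single.data[0].embedding);

    // Batch: one result per input; sort by index to restore input order
    const batch = await client.embeddings.create({
        model: "text-embedding-3-small",
        input: ["first text", "second text"],
        encoding_format: "float"
    });
    const vectors = batch.data
        .sort((a, b) => a.index - b.index)
        .map((item) => new Float32Array(item.embedding));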

File: message formatter factory

@@ -1,7 +1,6 @@
 import type { Message } from "../ai_interface.js";
 // These imports need to be added for the factory to work
 import { OpenAIMessageFormatter } from "../formatters/openai_formatter.js";
-import { AnthropicMessageFormatter } from "../formatters/anthropic_formatter.js";
 import { OllamaMessageFormatter } from "../formatters/ollama_formatter.js";

 /**
@@ -76,7 +75,8 @@ export class MessageFormatterFactory {
                 this.formatters[providerKey] = new OpenAIMessageFormatter();
                 break;
             case 'anthropic':
-                this.formatters[providerKey] = new AnthropicMessageFormatter();
+                console.warn('Anthropic formatter not available, using OpenAI formatter as fallback');
+                this.formatters[providerKey] = new OpenAIMessageFormatter();
                 break;
             case 'ollama':
                 this.formatters[providerKey] = new OllamaMessageFormatter();

File: OpenAI chat service (OpenAIService)

@@ -1,11 +1,12 @@
 import options from '../../options.js';
 import { BaseAIService } from '../base_ai_service.js';
 import type { ChatCompletionOptions, ChatResponse, Message } from '../ai_interface.js';
-import { PROVIDER_CONSTANTS } from '../constants/provider_constants.js';
-import type { OpenAIOptions } from './provider_options.js';
 import { getOpenAIOptions } from './providers.js';
+import OpenAI from 'openai';

 export class OpenAIService extends BaseAIService {
+    private openai: OpenAI | null = null;
+
     constructor() {
         super('OpenAI');
     }
@@ -14,6 +15,16 @@ export class OpenAIService extends BaseAIService {
         return super.isAvailable() && !!options.getOption('openaiApiKey');
     }

+    private getClient(apiKey: string, baseUrl?: string): OpenAI {
+        if (!this.openai) {
+            this.openai = new OpenAI({
+                apiKey,
+                baseURL: baseUrl
+            });
+        }
+        return this.openai;
+    }
+
     async generateChatCompletion(messages: Message[], opts: ChatCompletionOptions = {}): Promise<ChatResponse> {
         if (!this.isAvailable()) {
             throw new Error('OpenAI service is not available. Check API key and AI settings.');
@@ -21,6 +32,9 @@ export class OpenAIService extends BaseAIService {
         // Get provider-specific options from the central provider manager
         const providerOptions = getOpenAIOptions(opts);

+        // Initialize the OpenAI client
+        const client = this.getClient(providerOptions.apiKey, providerOptions.baseUrl);
+
         const systemPrompt = this.getSystemPrompt(providerOptions.systemPrompt || options.getOption('aiSystemPrompt'));
@@ -31,20 +45,10 @@ export class OpenAIService extends BaseAIService {
             : [{ role: 'system', content: systemPrompt }, ...messages];

         try {
-            // Fix endpoint construction - ensure we don't double up on /v1
-            const normalizedBaseUrl = providerOptions.baseUrl.replace(/\/+$/, '');
-            const endpoint = normalizedBaseUrl.includes('/v1')
-                ? `${normalizedBaseUrl}/chat/completions`
-                : `${normalizedBaseUrl}/v1/chat/completions`;
-
-            // Create request body directly from provider options
-            const requestBody: any = {
+            // Create params object for the OpenAI SDK
+            const params: OpenAI.Chat.ChatCompletionCreateParams = {
                 model: providerOptions.model,
-                messages: messagesWithSystem,
-            };
-
-            // Extract API parameters from provider options
-            const apiParams = {
+                messages: messagesWithSystem as OpenAI.Chat.ChatCompletionMessageParam[],
                 temperature: providerOptions.temperature,
                 max_tokens: providerOptions.max_tokens,
                 stream: providerOptions.stream,
@@ -53,51 +57,138 @@ export class OpenAIService extends BaseAIService {
                 presence_penalty: providerOptions.presence_penalty
             };

-            // Merge API parameters, filtering out undefined values
-            Object.entries(apiParams).forEach(([key, value]) => {
-                if (value !== undefined) {
-                    requestBody[key] = value;
-                }
-            });
-
             // Add tools if enabled
             if (providerOptions.enableTools && providerOptions.tools && providerOptions.tools.length > 0) {
-                requestBody.tools = providerOptions.tools;
+                params.tools = providerOptions.tools as OpenAI.Chat.ChatCompletionTool[];
             }

             if (providerOptions.tool_choice) {
-                requestBody.tool_choice = providerOptions.tool_choice;
+                params.tool_choice = providerOptions.tool_choice as OpenAI.Chat.ChatCompletionToolChoiceOption;
             }

-            const response = await fetch(endpoint, {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json',
-                    'Authorization': `Bearer ${providerOptions.apiKey}`
-                },
-                body: JSON.stringify(requestBody)
-            });
+            // If streaming is requested
+            if (providerOptions.stream) {
+                params.stream = true;
+
+                const stream = await client.chat.completions.create(params);
+                let fullText = '';
+
+                // If a direct callback is provided, use it
+                if (providerOptions.streamCallback) {
+                    // Process the stream with the callback
+                    try {
+                        // The stream is an AsyncIterable
+                        if (Symbol.asyncIterator in stream) {
+                            for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
+                                const content = chunk.choices[0]?.delta?.content || '';
+                                if (content) {
+                                    fullText += content;
+                                    await providerOptions.streamCallback(content, false, chunk);
+                                }
+
+                                // If this is the last chunk
+                                if (chunk.choices[0]?.finish_reason) {
+                                    await providerOptions.streamCallback('', true, chunk);
+                                }
+                            }
+                        } else {
+                            console.error('Stream is not iterable, falling back to non-streaming response');
+
+                            // If we get a non-streaming response somehow
+                            if ('choices' in stream) {
+                                const content = stream.choices[0]?.message?.content || '';
+                                fullText = content;
+                                if (providerOptions.streamCallback) {
+                                    await providerOptions.streamCallback(content, true, stream);
+                                }
+                            }
+                        }
+                    } catch (error) {
+                        console.error('Error processing stream:', error);
+                        throw error;
+                    }
+
+                    return {
+                        text: fullText,
+                        model: params.model,
+                        provider: this.getName(),
+                        usage: {} // Usage stats aren't available with streaming
+                    };
+                } else {
+                    // Use the more flexible stream interface
+                    return {
+                        text: '', // Initial empty text, will be filled by stream processing
+                        model: params.model,
+                        provider: this.getName(),
+                        usage: {}, // Usage stats aren't available with streaming
+                        stream: async (callback) => {
+                            let completeText = '';
+
+                            try {
+                                // The stream is an AsyncIterable
+                                if (Symbol.asyncIterator in stream) {
+                                    for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
+                                        const content = chunk.choices[0]?.delta?.content || '';
+                                        const isDone = !!chunk.choices[0]?.finish_reason;
+
+                                        if (content) {
+                                            completeText += content;
+                                        }
+
+                                        // Call the provided callback with the StreamChunk interface
+                                        await callback({
+                                            text: content,
+                                            done: isDone
+                                        });
+
+                                        if (isDone) {
+                                            break;
+                                        }
+                                    }
+                                } else {
+                                    console.warn('Stream is not iterable, falling back to non-streaming response');
+
+                                    // If we get a non-streaming response somehow
+                                    if ('choices' in stream) {
+                                        const content = stream.choices[0]?.message?.content || '';
+                                        completeText = content;
+                                        await callback({
+                                            text: content,
+                                            done: true
+                                        });
+                                    }
+                                }
+                            } catch (error) {
+                                console.error('Error processing stream:', error);
+                                throw error;
+                            }

-            if (!response.ok) {
-                const errorBody = await response.text();
-                throw new Error(`OpenAI API error: ${response.status} ${response.statusText} - ${errorBody}`);
-            }
+                            return completeText;
+                        }
+                    };
+                }
+            } else {
+                // Non-streaming response
+                params.stream = false;

-            const data = await response.json();
+                const completion = await client.chat.completions.create(params);

-            return {
-                text: data.choices[0].message.content,
-                model: data.model,
-                provider: this.getName(),
-                usage: {
-                    promptTokens: data.usage?.prompt_tokens,
-                    completionTokens: data.usage?.completion_tokens,
-                    totalTokens: data.usage?.total_tokens
-                },
-                tool_calls: data.choices[0].message.tool_calls
-            };
+                if (!('choices' in completion)) {
+                    throw new Error('Unexpected response format from OpenAI API');
+                }
+
+                return {
+                    text: completion.choices[0].message.content || '',
+                    model: completion.model,
+                    provider: this.getName(),
+                    usage: {
+                        promptTokens: completion.usage?.prompt_tokens,
+                        completionTokens: completion.usage?.completion_tokens,
+                        totalTokens: completion.usage?.total_tokens
+                    },
+                    tool_calls: completion.choices[0].message.tool_calls
+                };
+            }
         } catch (error) {
             console.error('OpenAI service error:', error);
             throw error;
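
The streaming branch leans on the SDK returning an AsyncIterable of ChatCompletionChunk objects when stream: true is set. A stripped-down sketch of that consumption pattern outside the service (model name illustrative):

    import OpenAI from "openai";

    const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

    const stream = await client.chat.completions.create({
        model: "gpt-4o-mini",   // illustrative
        messages: [{ role: "user", content: "Stream a haiku" }],
        stream: true
    });

    // Deltas arrive in choices[0].delta.content; the final chunk carries
    // a finish_reason and (usually) no content
    let fullText = "";
    for await (const chunk of stream) {
        const content = chunk.choices[0]?.delta?.content ?? "";
        fullText += content;
        if (chunk.choices[0]?.finish_reason) {
            break;
        }
    }
    console.log(fullText);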

File: provider_options

@@ -53,6 +53,8 @@ export interface OpenAIOptions extends ProviderConfig {
     // Internal control flags (not sent directly to API)
     enableTools?: boolean;
+    // Streaming callback handler
+    streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }

 /**
@@ -76,6 +78,8 @@ export interface AnthropicOptions extends ProviderConfig {
     // Internal parameters (not sent directly to API)
     formattedMessages?: { messages: any[], system: string };
+    // Streaming callback handler
+    streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }

 /**
@@ -105,6 +109,8 @@ export interface OllamaOptions extends ProviderConfig {
     preserveSystemPrompt?: boolean;
     expectsJsonResponse?: boolean;
     toolExecutionStatus?: any[];
+    // Streaming callback handler
+    streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }

 /**
@@ -134,6 +140,10 @@ export function createOpenAIOptions(
         // Internal configuration
         systemPrompt: opts.systemPrompt,
         enableTools: opts.enableTools,
+        // Pass through streaming callback
+        streamCallback: opts.streamCallback,
+        // Include provider metadata
+        providerMetadata: opts.providerMetadata,
     };
 }

@@ -164,6 +174,10 @@ export function createAnthropicOptions(
         // Internal configuration
         systemPrompt: opts.systemPrompt,
+        // Pass through streaming callback
+        streamCallback: opts.streamCallback,
+        // Include provider metadata
+        providerMetadata: opts.providerMetadata,
     };
 }

@@ -198,5 +212,9 @@ export function createOllamaOptions(
         preserveSystemPrompt: opts.preserveSystemPrompt,
         expectsJsonResponse: opts.expectsJsonResponse,
         toolExecutionStatus: opts.toolExecutionStatus,
+        // Pass through streaming callback
+        streamCallback: opts.streamCallback,
+        // Include provider metadata
+        providerMetadata: opts.providerMetadata,
     };
 }
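
A hedged usage sketch of the new streamCallback option; the service instance and surrounding wiring are hypothetical, but the callback signature matches the one added to the interfaces above:

    const service = new OpenAIService();   // hypothetical direct construction

    const response = await service.generateChatCompletion(messages, {
        stream: true,
        // Called once per delta; isDone is true on the final chunk
        streamCallback: async (text, isDone, chunk) => {
            process.stdout.write(text);
            if (isDone) {
                console.log('\n[stream complete]');
            }
        }
    });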