mirror of https://github.com/TriliumNext/Notes.git
synced 2025-07-27 18:12:29 +08:00
rip out openai custom implementation in favor of sdk
parent f71351db6a
commit 6fe2b87901
package-lock.json | 46 (generated)
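
Before the file-by-file diff: the change swaps hand-rolled axios/fetch calls against the OpenAI REST endpoints for the official `openai` npm package. A minimal standalone sketch of the new pattern (not the commit's code; the env-var lookup is an assumption for illustration):

```ts
import OpenAI from "openai";

// Construct one client; the SDK then handles auth headers, retries,
// timeouts, and JSON parsing that the old fetch/axios code did by hand.
const client = new OpenAI({
    apiKey: process.env.OPENAI_API_KEY,   // assumed source of the key
    baseURL: "https://api.openai.com/v1"  // the SDK's default, shown for clarity
});

async function main() {
    const models = await client.models.list();
    console.log(models.data.map((m) => m.id));
}

main().catch(console.error);
```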
@@ -69,6 +69,7 @@
 "normalize-strings": "1.1.1",
 "normalize.css": "8.0.1",
 "ollama": "0.5.14",
+"openai": "4.93.0",
 "rand-token": "1.0.1",
 "safe-compare": "1.1.4",
 "sanitize-filename": "1.6.3",
@@ -16035,6 +16036,51 @@
 "dev": true,
 "license": "MIT"
 },
+"node_modules/openai": {
+"version": "4.93.0",
+"resolved": "https://registry.npmjs.org/openai/-/openai-4.93.0.tgz",
+"integrity": "sha512-2kONcISbThKLfm7T9paVzg+QCE1FOZtNMMUfXyXckUAoXRRS/mTP89JSDHPMp8uM5s0bz28RISbvQjArD6mgUQ==",
+"license": "Apache-2.0",
+"dependencies": {
+"@types/node": "^18.11.18",
+"@types/node-fetch": "^2.6.4",
+"abort-controller": "^3.0.0",
+"agentkeepalive": "^4.2.1",
+"form-data-encoder": "1.7.2",
+"formdata-node": "^4.3.2",
+"node-fetch": "^2.6.7"
+},
+"bin": {
+"openai": "bin/cli"
+},
+"peerDependencies": {
+"ws": "^8.18.0",
+"zod": "^3.23.8"
+},
+"peerDependenciesMeta": {
+"ws": {
+"optional": true
+},
+"zod": {
+"optional": true
+}
+}
+},
+"node_modules/openai/node_modules/@types/node": {
+"version": "18.19.86",
+"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.86.tgz",
+"integrity": "sha512-fifKayi175wLyKyc5qUfyENhQ1dCNI1UNjp653d8kuYcPQN5JhX3dGuP/XmvPTg/xRBn1VTLpbmi+H/Mr7tLfQ==",
+"license": "MIT",
+"dependencies": {
+"undici-types": "~5.26.4"
+}
+},
+"node_modules/openai/node_modules/undici-types": {
+"version": "5.26.5",
+"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
+"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
+"license": "MIT"
+},
 "node_modules/openapi-types": {
 "version": "12.1.3",
 "resolved": "https://registry.npmjs.org/openapi-types/-/openapi-types-12.1.3.tgz",
@@ -131,6 +131,7 @@
 "normalize-strings": "1.1.1",
 "normalize.css": "8.0.1",
 "ollama": "0.5.14",
+"openai": "4.93.0",
 "rand-token": "1.0.1",
 "safe-compare": "1.1.4",
 "sanitize-filename": "1.6.3",
@@ -1,7 +1,7 @@
-import axios from 'axios';
 import options from "../../services/options.js";
 import log from "../../services/log.js";
 import type { Request, Response } from "express";
+import OpenAI from "openai";
 
 /**
  * @swagger
@@ -69,39 +69,39 @@ async function listModels(req: Request, res: Response) {
 throw new Error('OpenAI API key is not configured');
 }
 
-// Call OpenAI API to get models
-const response = await axios.get(`${openaiBaseUrl}/models`, {
-headers: {
-'Content-Type': 'application/json',
-'Authorization': `Bearer ${apiKey}`
-},
-timeout: 10000
-});
+// Initialize OpenAI client with the API key and base URL
+const openai = new OpenAI({
+apiKey,
+baseURL: openaiBaseUrl
+});
+
+// Call OpenAI API to get models using the SDK
+const response = await openai.models.list();
 
 // Filter and categorize models
-const allModels = response.data.data || [];
+const allModels = response.data || [];
 
 // Separate models into chat models and embedding models
 const chatModels = allModels
-.filter((model: any) =>
+.filter((model) =>
 // Include GPT models for chat
 model.id.includes('gpt') ||
 // Include Claude models via Azure OpenAI
 model.id.includes('claude')
 )
-.map((model: any) => ({
+.map((model) => ({
 id: model.id,
 name: model.id,
 type: 'chat'
 }));
 
 const embeddingModels = allModels
-.filter((model: any) =>
+.filter((model) =>
 // Only include embedding-specific models
 model.id.includes('embedding') ||
 model.id.includes('embed')
 )
-.map((model: any) => ({
+.map((model) => ({
 id: model.id,
 name: model.id,
 type: 'embedding'
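
Worth noting about the hunk above: list endpoints in the v4 SDK are also async-iterable, so the same filter could be written with auto-pagination instead of a single read of `response.data`. A sketch under that assumption:

```ts
import OpenAI from "openai";

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function listChatModels() {
    const chatModels: { id: string; name: string; type: string }[] = [];
    // Iterating the page object fetches any follow-up pages transparently.
    for await (const model of openai.models.list()) {
        if (model.id.includes("gpt") || model.id.includes("claude")) {
            chatModels.push({ id: model.id, name: model.id, type: "chat" });
        }
    }
    return chatModels;
}
```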
@@ -4,15 +4,30 @@ import type { EmbeddingConfig } from "../embeddings_interface.js";
 import { NormalizationStatus } from "../embeddings_interface.js";
 import { LLM_CONSTANTS } from "../../constants/provider_constants.js";
 import type { EmbeddingModelInfo } from "../../interfaces/embedding_interfaces.js";
+import OpenAI from "openai";
 
 /**
- * OpenAI embedding provider implementation
+ * OpenAI embedding provider implementation using the official SDK
  */
 export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
 name = "openai";
+private client: OpenAI | null = null;
 
 constructor(config: EmbeddingConfig) {
 super(config);
+this.initClient();
+}
+
+/**
+ * Initialize the OpenAI client
+ */
+private initClient() {
+if (this.apiKey) {
+this.client = new OpenAI({
+apiKey: this.apiKey,
+baseURL: this.baseUrl
+});
+}
 }
 
 /**
@@ -21,6 +36,11 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
 async initialize(): Promise<void> {
 const modelName = this.config.model || "text-embedding-3-small";
 try {
+// Initialize client if needed
+if (!this.client && this.apiKey) {
+this.initClient();
+}
+
 // Detect model capabilities
 const modelInfo = await this.getModelInfo(modelName);
 
@@ -37,46 +57,35 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
 * Fetch model information from the OpenAI API
 */
 private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
-if (!this.apiKey) {
+if (!this.client) {
 return null;
 }
 
 try {
-// First try to get model details from the models API
-const response = await fetch(`${this.baseUrl}/models/${modelName}`, {
-method: 'GET',
-headers: {
-"Authorization": `Bearer ${this.apiKey}`,
-"Content-Type": "application/json"
-},
-signal: AbortSignal.timeout(10000)
-});
-
-if (!response.ok) {
-throw new Error(`HTTP error! status: ${response.status}`);
-}
-
-const data = await response.json();
-
-if (data) {
+// Get model details using the SDK
+const model = await this.client.models.retrieve(modelName);
+
+if (model) {
 // Different model families may have different ways of exposing context window
 let contextWindow = 0;
 let dimension = 0;
 
-// Extract context window if available
-if (data.context_window) {
-contextWindow = data.context_window;
-} else if (data.limits && data.limits.context_window) {
-contextWindow = data.limits.context_window;
-} else if (data.limits && data.limits.context_length) {
-contextWindow = data.limits.context_length;
+// Extract context window if available from the response
+const modelData = model as any;
+
+if (modelData.context_window) {
+contextWindow = modelData.context_window;
+} else if (modelData.limits && modelData.limits.context_window) {
+contextWindow = modelData.limits.context_window;
+} else if (modelData.limits && modelData.limits.context_length) {
+contextWindow = modelData.limits.context_length;
 }
 
 // Extract embedding dimensions if available
-if (data.dimensions) {
-dimension = data.dimensions;
-} else if (data.embedding_dimension) {
-dimension = data.embedding_dimension;
+if (modelData.dimensions) {
+dimension = modelData.dimensions;
+} else if (modelData.embedding_dimension) {
+dimension = modelData.embedding_dimension;
 }
 
 // If we didn't get all the info, use defaults for missing values
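
The `model as any` cast in the hunk above exists because, as of the 4.x SDK, the `Model` type declares only `id`, `created`, `object`, and `owned_by`; fields such as `context_window` are not part of the official response and will only appear if an OpenAI-compatible gateway returns them. A sketch of a slightly tighter variant (the extra-field shape is hypothetical):

```ts
import OpenAI from "openai";

// Hypothetical shape for extras some OpenAI-compatible backends return.
interface ExtendedModel {
    context_window?: number;
    dimensions?: number;
}

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function probeModel(name: string) {
    const model = await client.models.retrieve(name);
    // Narrow through unknown rather than any, keeping the escape hatch explicit.
    const extra = model as unknown as ExtendedModel;
    return {
        contextWindow: extra.context_window ?? 0,
        dimension: extra.dimensions ?? 0
    };
}
```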
@@ -188,27 +197,21 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
 return new Float32Array(this.config.dimension);
 }
 
-const response = await fetch(`${this.baseUrl}/embeddings`, {
-method: 'POST',
-headers: {
-"Content-Type": "application/json",
-"Authorization": `Bearer ${this.apiKey}`
-},
-body: JSON.stringify({
-input: text,
-model: this.config.model || "text-embedding-3-small",
-encoding_format: "float"
-})
-});
-
-if (!response.ok) {
-throw new Error(`HTTP error! status: ${response.status}`);
+if (!this.client) {
+this.initClient();
+if (!this.client) {
+throw new Error("OpenAI client initialization failed");
+}
 }
 
-const data = await response.json();
+const response = await this.client.embeddings.create({
+model: this.config.model || "text-embedding-3-small",
+input: text,
+encoding_format: "float"
+});
 
-if (data && data.data && data.data[0] && data.data[0].embedding) {
-return new Float32Array(data.data[0].embedding);
+if (response && response.data && response.data[0] && response.data[0].embedding) {
+return new Float32Array(response.data[0].embedding);
 } else {
 throw new Error("Unexpected response structure from OpenAI API");
 }
@@ -243,30 +246,24 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider {
 return [];
 }
 
-const response = await fetch(`${this.baseUrl}/embeddings`, {
-method: 'POST',
-headers: {
-"Content-Type": "application/json",
-"Authorization": `Bearer ${this.apiKey}`
-},
-body: JSON.stringify({
-input: texts,
-model: this.config.model || "text-embedding-3-small",
-encoding_format: "float"
-})
-});
-
-if (!response.ok) {
-throw new Error(`HTTP error! status: ${response.status}`);
+if (!this.client) {
+this.initClient();
+if (!this.client) {
+throw new Error("OpenAI client initialization failed");
+}
 }
 
-const data = await response.json();
+const response = await this.client.embeddings.create({
+model: this.config.model || "text-embedding-3-small",
+input: texts,
+encoding_format: "float"
+});
 
-if (data && data.data) {
+if (response && response.data) {
 // Sort the embeddings by index to ensure they match the input order
-const sortedEmbeddings = data.data
-.sort((a: any, b: any) => a.index - b.index)
-.map((item: any) => new Float32Array(item.embedding));
+const sortedEmbeddings = response.data
+.sort((a, b) => a.index - b.index)
+.map(item => new Float32Array(item.embedding));
 
 return sortedEmbeddings;
 } else {
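
For reference, the batch call from the hunk above in isolation; a sketch using the same SDK method (input texts are illustrative):

```ts
import OpenAI from "openai";

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function embedBatch(texts: string[]): Promise<Float32Array[]> {
    const response = await client.embeddings.create({
        model: "text-embedding-3-small",
        input: texts,
        encoding_format: "float"
    });
    // Each result item carries the index of its input, so sorting by
    // index restores input order before converting to Float32Array.
    return response.data
        .sort((a, b) => a.index - b.index)
        .map((item) => new Float32Array(item.embedding));
}
```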
@@ -1,7 +1,6 @@
 import type { Message } from "../ai_interface.js";
 // These imports need to be added for the factory to work
 import { OpenAIMessageFormatter } from "../formatters/openai_formatter.js";
-import { AnthropicMessageFormatter } from "../formatters/anthropic_formatter.js";
 import { OllamaMessageFormatter } from "../formatters/ollama_formatter.js";
 
 /**
@@ -76,7 +75,8 @@ export class MessageFormatterFactory {
 this.formatters[providerKey] = new OpenAIMessageFormatter();
 break;
 case 'anthropic':
-this.formatters[providerKey] = new AnthropicMessageFormatter();
+console.warn('Anthropic formatter not available, using OpenAI formatter as fallback');
+this.formatters[providerKey] = new OpenAIMessageFormatter();
 break;
 case 'ollama':
 this.formatters[providerKey] = new OllamaMessageFormatter();
@@ -1,11 +1,12 @@
 import options from '../../options.js';
 import { BaseAIService } from '../base_ai_service.js';
 import type { ChatCompletionOptions, ChatResponse, Message } from '../ai_interface.js';
-import { PROVIDER_CONSTANTS } from '../constants/provider_constants.js';
-import type { OpenAIOptions } from './provider_options.js';
 import { getOpenAIOptions } from './providers.js';
+import OpenAI from 'openai';
 
 export class OpenAIService extends BaseAIService {
+private openai: OpenAI | null = null;
+
 constructor() {
 super('OpenAI');
 }
@@ -14,6 +15,16 @@ export class OpenAIService extends BaseAIService {
 return super.isAvailable() && !!options.getOption('openaiApiKey');
 }
 
+private getClient(apiKey: string, baseUrl?: string): OpenAI {
+if (!this.openai) {
+this.openai = new OpenAI({
+apiKey,
+baseURL: baseUrl
+});
+}
+return this.openai;
+}
+
 async generateChatCompletion(messages: Message[], opts: ChatCompletionOptions = {}): Promise<ChatResponse> {
 if (!this.isAvailable()) {
 throw new Error('OpenAI service is not available. Check API key and AI settings.');
@@ -21,6 +32,9 @@ export class OpenAIService extends BaseAIService {
 
 // Get provider-specific options from the central provider manager
 const providerOptions = getOpenAIOptions(opts);
+
+// Initialize the OpenAI client
+const client = this.getClient(providerOptions.apiKey, providerOptions.baseUrl);
 
 const systemPrompt = this.getSystemPrompt(providerOptions.systemPrompt || options.getOption('aiSystemPrompt'));
 
|
|||||||
: [{ role: 'system', content: systemPrompt }, ...messages];
|
: [{ role: 'system', content: systemPrompt }, ...messages];
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Fix endpoint construction - ensure we don't double up on /v1
|
// Create params object for the OpenAI SDK
|
||||||
const normalizedBaseUrl = providerOptions.baseUrl.replace(/\/+$/, '');
|
const params: OpenAI.Chat.ChatCompletionCreateParams = {
|
||||||
const endpoint = normalizedBaseUrl.includes('/v1')
|
|
||||||
? `${normalizedBaseUrl}/chat/completions`
|
|
||||||
: `${normalizedBaseUrl}/v1/chat/completions`;
|
|
||||||
|
|
||||||
// Create request body directly from provider options
|
|
||||||
const requestBody: any = {
|
|
||||||
model: providerOptions.model,
|
model: providerOptions.model,
|
||||||
messages: messagesWithSystem,
|
messages: messagesWithSystem as OpenAI.Chat.ChatCompletionMessageParam[],
|
||||||
};
|
|
||||||
|
|
||||||
// Extract API parameters from provider options
|
|
||||||
const apiParams = {
|
|
||||||
temperature: providerOptions.temperature,
|
temperature: providerOptions.temperature,
|
||||||
max_tokens: providerOptions.max_tokens,
|
max_tokens: providerOptions.max_tokens,
|
||||||
stream: providerOptions.stream,
|
stream: providerOptions.stream,
|
||||||
@ -53,51 +57,138 @@ export class OpenAIService extends BaseAIService {
|
|||||||
presence_penalty: providerOptions.presence_penalty
|
presence_penalty: providerOptions.presence_penalty
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Merge API parameters, filtering out undefined values
|
|
||||||
Object.entries(apiParams).forEach(([key, value]) => {
|
|
||||||
if (value !== undefined) {
|
|
||||||
requestBody[key] = value;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add tools if enabled
|
// Add tools if enabled
|
||||||
if (providerOptions.enableTools && providerOptions.tools && providerOptions.tools.length > 0) {
|
if (providerOptions.enableTools && providerOptions.tools && providerOptions.tools.length > 0) {
|
||||||
requestBody.tools = providerOptions.tools;
|
params.tools = providerOptions.tools as OpenAI.Chat.ChatCompletionTool[];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (providerOptions.tool_choice) {
|
if (providerOptions.tool_choice) {
|
||||||
requestBody.tool_choice = providerOptions.tool_choice;
|
params.tool_choice = providerOptions.tool_choice as OpenAI.Chat.ChatCompletionToolChoiceOption;
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
// If streaming is requested
|
||||||
method: 'POST',
|
if (providerOptions.stream) {
|
||||||
headers: {
|
params.stream = true;
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${providerOptions.apiKey}`
|
const stream = await client.chat.completions.create(params);
|
||||||
},
|
let fullText = '';
|
||||||
body: JSON.stringify(requestBody)
|
|
||||||
});
|
// If a direct callback is provided, use it
|
||||||
|
if (providerOptions.streamCallback) {
|
||||||
|
// Process the stream with the callback
|
||||||
|
try {
|
||||||
|
// The stream is an AsyncIterable
|
||||||
|
if (Symbol.asyncIterator in stream) {
|
||||||
|
for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
|
||||||
|
const content = chunk.choices[0]?.delta?.content || '';
|
||||||
|
if (content) {
|
||||||
|
fullText += content;
|
||||||
|
await providerOptions.streamCallback(content, false, chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is the last chunk
|
||||||
|
if (chunk.choices[0]?.finish_reason) {
|
||||||
|
await providerOptions.streamCallback('', true, chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error('Stream is not iterable, falling back to non-streaming response');
|
||||||
|
|
||||||
|
// If we get a non-streaming response somehow
|
||||||
|
if ('choices' in stream) {
|
||||||
|
const content = stream.choices[0]?.message?.content || '';
|
||||||
|
fullText = content;
|
||||||
|
if (providerOptions.streamCallback) {
|
||||||
|
await providerOptions.streamCallback(content, true, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error processing stream:', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
text: fullText,
|
||||||
|
model: params.model,
|
||||||
|
provider: this.getName(),
|
||||||
|
usage: {} // Usage stats aren't available with streaming
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
// Use the more flexible stream interface
|
||||||
|
return {
|
||||||
|
text: '', // Initial empty text, will be filled by stream processing
|
||||||
|
model: params.model,
|
||||||
|
provider: this.getName(),
|
||||||
|
usage: {}, // Usage stats aren't available with streaming
|
||||||
|
stream: async (callback) => {
|
||||||
|
let completeText = '';
|
||||||
|
|
||||||
|
try {
|
||||||
|
// The stream is an AsyncIterable
|
||||||
|
if (Symbol.asyncIterator in stream) {
|
||||||
|
for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
|
||||||
|
const content = chunk.choices[0]?.delta?.content || '';
|
||||||
|
const isDone = !!chunk.choices[0]?.finish_reason;
|
||||||
|
|
||||||
|
if (content) {
|
||||||
|
completeText += content;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the provided callback with the StreamChunk interface
|
||||||
|
await callback({
|
||||||
|
text: content,
|
||||||
|
done: isDone
|
||||||
|
});
|
||||||
|
|
||||||
|
if (isDone) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.warn('Stream is not iterable, falling back to non-streaming response');
|
||||||
|
|
||||||
|
// If we get a non-streaming response somehow
|
||||||
|
if ('choices' in stream) {
|
||||||
|
const content = stream.choices[0]?.message?.content || '';
|
||||||
|
completeText = content;
|
||||||
|
await callback({
|
||||||
|
text: content,
|
||||||
|
done: true
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error processing stream:', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return completeText;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-streaming response
|
||||||
|
params.stream = false;
|
||||||
|
|
||||||
|
const completion = await client.chat.completions.create(params);
|
||||||
|
|
||||||
|
if (!('choices' in completion)) {
|
||||||
|
throw new Error('Unexpected response format from OpenAI API');
|
||||||
|
}
|
||||||
|
|
||||||
if (!response.ok) {
|
return {
|
||||||
const errorBody = await response.text();
|
text: completion.choices[0].message.content || '',
|
||||||
throw new Error(`OpenAI API error: ${response.status} ${response.statusText} - ${errorBody}`);
|
model: completion.model,
|
||||||
|
provider: this.getName(),
|
||||||
|
usage: {
|
||||||
|
promptTokens: completion.usage?.prompt_tokens,
|
||||||
|
completionTokens: completion.usage?.completion_tokens,
|
||||||
|
totalTokens: completion.usage?.total_tokens
|
||||||
|
},
|
||||||
|
tool_calls: completion.choices[0].message.tool_calls
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
return {
|
|
||||||
text: data.choices[0].message.content,
|
|
||||||
model: data.model,
|
|
||||||
provider: this.getName(),
|
|
||||||
usage: {
|
|
||||||
promptTokens: data.usage?.prompt_tokens,
|
|
||||||
completionTokens: data.usage?.completion_tokens,
|
|
||||||
totalTokens: data.usage?.total_tokens
|
|
||||||
},
|
|
||||||
tool_calls: data.choices[0].message.tool_calls
|
|
||||||
};
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('OpenAI service error:', error);
|
console.error('OpenAI service error:', error);
|
||||||
throw error;
|
throw error;
|
||||||
|
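
The core of the streaming branch above, reduced to a standalone sketch (the model name is an arbitrary example): with `stream: true`, `create()` resolves to an async-iterable of `ChatCompletionChunk`s rather than a single completion.

```ts
import OpenAI from "openai";

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function streamReply(prompt: string): Promise<string> {
    const stream = await client.chat.completions.create({
        model: "gpt-4o-mini", // arbitrary example model
        messages: [{ role: "user", content: prompt }],
        stream: true
    });

    let fullText = "";
    for await (const chunk of stream) {
        // Each chunk carries a content delta; finish_reason marks the end.
        fullText += chunk.choices[0]?.delta?.content ?? "";
        if (chunk.choices[0]?.finish_reason) break;
    }
    return fullText;
}
```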
@@ -53,6 +53,8 @@ export interface OpenAIOptions extends ProviderConfig {
 
 // Internal control flags (not sent directly to API)
 enableTools?: boolean;
+// Streaming callback handler
+streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }
 
 /**
@@ -76,6 +78,8 @@ export interface AnthropicOptions extends ProviderConfig {
 
 // Internal parameters (not sent directly to API)
 formattedMessages?: { messages: any[], system: string };
+// Streaming callback handler
+streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }
 
 /**
@@ -105,6 +109,8 @@ export interface OllamaOptions extends ProviderConfig {
 preserveSystemPrompt?: boolean;
 expectsJsonResponse?: boolean;
 toolExecutionStatus?: any[];
+// Streaming callback handler
+streamCallback?: (text: string, isDone: boolean, originalChunk?: any) => Promise<void> | void;
 }
 
 /**
@@ -134,6 +140,10 @@ export function createOpenAIOptions(
 // Internal configuration
 systemPrompt: opts.systemPrompt,
 enableTools: opts.enableTools,
+// Pass through streaming callback
+streamCallback: opts.streamCallback,
+// Include provider metadata
+providerMetadata: opts.providerMetadata,
 };
 }
 
@@ -164,6 +174,10 @@ export function createAnthropicOptions(
 
 // Internal configuration
 systemPrompt: opts.systemPrompt,
+// Pass through streaming callback
+streamCallback: opts.streamCallback,
+// Include provider metadata
+providerMetadata: opts.providerMetadata,
 };
 }
 
@@ -198,5 +212,9 @@ export function createOllamaOptions(
 preserveSystemPrompt: opts.preserveSystemPrompt,
 expectsJsonResponse: opts.expectsJsonResponse,
 toolExecutionStatus: opts.toolExecutionStatus,
+// Pass through streaming callback
+streamCallback: opts.streamCallback,
+// Include provider metadata
+providerMetadata: opts.providerMetadata,
 };
 }
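
A sketch of how a caller might wire the new `streamCallback` option (assuming, as the `createOpenAIOptions` hunk suggests, that `ChatCompletionOptions` passes the field through; `service` and `messages` are hypothetical):

```ts
// Hypothetical call site; `service` is an OpenAIService instance and
// `messages` a Message[] built elsewhere.
const response = await service.generateChatCompletion(messages, {
    stream: true,
    streamCallback: (text, isDone) => {
        // Write each delta as it arrives; terminate the line when done.
        if (text) process.stdout.write(text);
        if (isDone) process.stdout.write("\n");
    }
});
```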