mirror of
https://github.com/TriliumNext/Notes.git
synced 2025-09-02 21:42:15 +08:00
add additional options for ollama embeddings
This commit is contained in:
parent
ea6f9c8e18
commit
d3013c925e
@ -4,6 +4,19 @@ import type { FilterOptionsByType, OptionMap } from "../../../../../services/opt
|
||||
import server from "../../../services/server.js";
|
||||
import toastService from "../../../services/toast.js";
|
||||
|
||||
// Interface for the Ollama model response
|
||||
interface OllamaModelResponse {
|
||||
success: boolean;
|
||||
models: Array<{
|
||||
name: string;
|
||||
model: string;
|
||||
details?: {
|
||||
family?: string;
|
||||
parameter_size?: string;
|
||||
}
|
||||
}>;
|
||||
}
|
||||
|
||||
export default class AiSettingsWidget extends OptionsWidget {
|
||||
doRender() {
|
||||
this.$widget = $(`
|
||||
@ -102,16 +115,27 @@ export default class AiSettingsWidget extends OptionsWidget {
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label>${t("ai_llm.base_url")}</label>
|
||||
<label>${t("ai_llm.ollama_url")}</label>
|
||||
<input class="ollama-base-url form-control" type="text">
|
||||
<div class="help-text">${t("ai_llm.ollama_url_description")}</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label>${t("ai_llm.default_model")}</label>
|
||||
<label>${t("ai_llm.ollama_model")}</label>
|
||||
<input class="ollama-default-model form-control" type="text">
|
||||
<div class="help-text">${t("ai_llm.ollama_model_description")}</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label>${t("ai_llm.ollama_embedding_model")}</label>
|
||||
<select class="ollama-embedding-model form-control">
|
||||
<option value="nomic-embed-text">nomic-embed-text (recommended)</option>
|
||||
<option value="mxbai-embed-large">mxbai-embed-large</option>
|
||||
<option value="llama3">llama3</option>
|
||||
</select>
|
||||
<div class="help-text">${t("ai_llm.ollama_embedding_model_description")}</div>
|
||||
<button class="btn btn-sm btn-outline-secondary refresh-models">${t("ai_llm.refresh_models")}</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<hr />
|
||||
@ -220,6 +244,59 @@ export default class AiSettingsWidget extends OptionsWidget {
|
||||
await this.updateOption('ollamaDefaultModel', $ollamaDefaultModel.val() as string);
|
||||
});
|
||||
|
||||
const $ollamaEmbeddingModel = this.$widget.find('.ollama-embedding-model');
|
||||
$ollamaEmbeddingModel.on('change', async () => {
|
||||
await this.updateOption('ollamaEmbeddingModel', $ollamaEmbeddingModel.val() as string);
|
||||
});
|
||||
|
||||
const $refreshModels = this.$widget.find('.refresh-models');
|
||||
$refreshModels.on('click', async () => {
|
||||
$refreshModels.prop('disabled', true);
|
||||
$refreshModels.text(t("ai_llm.refresh_models"));
|
||||
|
||||
try {
|
||||
const ollamaBaseUrl = this.$widget.find('.ollama-base-url').val() as string;
|
||||
const response = await server.post<OllamaModelResponse>('ollama/list-models', { baseUrl: ollamaBaseUrl });
|
||||
|
||||
if (response && response.models) {
|
||||
const $embedModelSelect = this.$widget.find('.ollama-embedding-model');
|
||||
const currentValue = $embedModelSelect.val();
|
||||
|
||||
// Clear existing options
|
||||
$embedModelSelect.empty();
|
||||
|
||||
// Add embedding-specific models first
|
||||
const embeddingModels = response.models.filter(model =>
|
||||
model.name.includes('embed') || model.name.includes('bert'));
|
||||
|
||||
embeddingModels.forEach(model => {
|
||||
$embedModelSelect.append(`<option value="${model.name}">${model.name}</option>`);
|
||||
});
|
||||
|
||||
// Add separator
|
||||
$embedModelSelect.append(`<option disabled>───────────</option>`);
|
||||
|
||||
// Add other models (LLMs can also generate embeddings)
|
||||
const otherModels = response.models.filter(model =>
|
||||
!model.name.includes('embed') && !model.name.includes('bert'));
|
||||
|
||||
otherModels.forEach(model => {
|
||||
$embedModelSelect.append(`<option value="${model.name}">${model.name}</option>`);
|
||||
});
|
||||
|
||||
// Restore previous selection if possible
|
||||
if (currentValue) {
|
||||
$embedModelSelect.val(currentValue);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error refreshing Ollama models:", error);
|
||||
} finally {
|
||||
$refreshModels.prop('disabled', false);
|
||||
$refreshModels.text(t("ai_llm.refresh_models"));
|
||||
}
|
||||
});
|
||||
|
||||
// Embedding options event handlers
|
||||
const $embeddingAutoUpdateEnabled = this.$widget.find('.embedding-auto-update-enabled');
|
||||
$embeddingAutoUpdateEnabled.on('change', async () => {
|
||||
@ -290,6 +367,7 @@ export default class AiSettingsWidget extends OptionsWidget {
|
||||
|
||||
this.$widget.find('.ollama-base-url').val(options.ollamaBaseUrl);
|
||||
this.$widget.find('.ollama-default-model').val(options.ollamaDefaultModel);
|
||||
this.$widget.find('.ollama-embedding-model').val(options.ollamaEmbeddingModel || 'nomic-embed-text');
|
||||
|
||||
// Load embedding options
|
||||
this.setCheckboxState(this.$widget.find('.embedding-auto-update-enabled'), options.embeddingAutoUpdateEnabled);
|
||||
|
@ -1144,8 +1144,14 @@
|
||||
"ollama_configuration": "Ollama Configuration",
|
||||
"enable_ollama": "Enable Ollama",
|
||||
"enable_ollama_description": "Enable Ollama for local AI model usage",
|
||||
"ollama_url": "Ollama URL",
|
||||
"ollama_url_description": "Default: http://localhost:11434",
|
||||
"ollama_model": "Ollama Model",
|
||||
"ollama_model_description": "Examples: llama3, mistral, phi3",
|
||||
"ollama_embedding_model": "Embedding Model",
|
||||
"ollama_embedding_model_description": "Specialized model for generating embeddings (vector representations)",
|
||||
"refresh_models": "Refresh Models",
|
||||
"refreshing_models": "Refreshing...",
|
||||
"embedding_configuration": "Embeddings Configuration",
|
||||
"enable_auto_update_embeddings": "Auto-update Embeddings",
|
||||
"enable_auto_update_embeddings_description": "Automatically update embeddings when notes are modified",
|
||||
|
40
src/routes/api/ollama.ts
Normal file
40
src/routes/api/ollama.ts
Normal file
@ -0,0 +1,40 @@
|
||||
import axios from 'axios';
|
||||
import options from "../../services/options.js";
|
||||
import log from "../../services/log.js";
|
||||
import type { Request, Response } from "express";
|
||||
|
||||
/**
|
||||
* List available models from Ollama
|
||||
*/
|
||||
async function listModels(req: Request, res: Response) {
|
||||
try {
|
||||
const { baseUrl } = req.body;
|
||||
|
||||
// Use provided base URL or default from options
|
||||
const ollamaBaseUrl = baseUrl || await options.getOption('ollamaBaseUrl') || 'http://localhost:11434';
|
||||
|
||||
// Call Ollama API to get models
|
||||
const response = await axios.get(`${ollamaBaseUrl}/api/tags`, {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
timeout: 10000
|
||||
});
|
||||
|
||||
// Return the models list
|
||||
return res.send({
|
||||
success: true,
|
||||
models: response.data.models || []
|
||||
});
|
||||
} catch (error: any) {
|
||||
log.error(`Error listing Ollama models: ${error.message || 'Unknown error'}`);
|
||||
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || 'Failed to list Ollama models',
|
||||
error: error.toString()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export default {
|
||||
listModels
|
||||
};
|
@ -61,6 +61,7 @@ import relationMapApiRoute from "./api/relation-map.js";
|
||||
import otherRoute from "./api/other.js";
|
||||
import shareRoutes from "../share/routes.js";
|
||||
import embeddingsRoute from "./api/embeddings.js";
|
||||
import ollamaRoute from "./api/ollama.js";
|
||||
|
||||
import etapiAuthRoutes from "../etapi/auth.js";
|
||||
import etapiAppInfoRoutes from "../etapi/app_info.js";
|
||||
@ -378,6 +379,9 @@ function register(app: express.Application) {
|
||||
route(PST, "/api/embeddings/reprocess", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.reprocessAllNotes, apiResultHandler);
|
||||
route(GET, "/api/embeddings/queue-status", [auth.checkApiAuth], embeddingsRoute.getQueueStatus, apiResultHandler);
|
||||
|
||||
// Ollama API endpoints
|
||||
route(PST, "/api/ollama/list-models", [auth.checkApiAuth, csrfMiddleware], ollamaRoute.listModels, apiResultHandler);
|
||||
|
||||
// API Documentation
|
||||
apiDocsRoute.register(app);
|
||||
|
||||
|
@ -7,7 +7,52 @@ import type { EmbeddingProvider, EmbeddingConfig } from "./embeddings_interface.
|
||||
import { OpenAIEmbeddingProvider } from "./providers/openai.js";
|
||||
import { OllamaEmbeddingProvider } from "./providers/ollama.js";
|
||||
import { AnthropicEmbeddingProvider } from "./providers/anthropic.js";
|
||||
import { LocalEmbeddingProvider } from "./providers/local.js";
|
||||
|
||||
/**
|
||||
* Simple local embedding provider implementation
|
||||
* This avoids the need to import a separate file which might not exist
|
||||
*/
|
||||
class SimpleLocalEmbeddingProvider implements EmbeddingProvider {
|
||||
name = "local";
|
||||
config: EmbeddingConfig;
|
||||
|
||||
constructor(config: EmbeddingConfig) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
getConfig(): EmbeddingConfig {
|
||||
return this.config;
|
||||
}
|
||||
|
||||
async generateEmbeddings(text: string): Promise<Float32Array> {
|
||||
// Create deterministic embeddings based on text content
|
||||
const result = new Float32Array(this.config.dimension || 384);
|
||||
|
||||
// Simple hash-based approach
|
||||
for (let i = 0; i < result.length; i++) {
|
||||
// Use character codes and position to generate values between -1 and 1
|
||||
const charSum = Array.from(text).reduce((sum, char, idx) =>
|
||||
sum + char.charCodeAt(0) * Math.sin(idx * 0.1), 0);
|
||||
result[i] = Math.sin(i * 0.1 + charSum * 0.01);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
|
||||
return Promise.all(texts.map(text => this.generateEmbeddings(text)));
|
||||
}
|
||||
|
||||
async generateNoteEmbeddings(context: any): Promise<Float32Array> {
|
||||
// Combine text from context
|
||||
const text = (context.title || "") + " " + (context.content || "");
|
||||
return this.generateEmbeddings(text);
|
||||
}
|
||||
|
||||
async generateBatchNoteEmbeddings(contexts: any[]): Promise<Float32Array[]> {
|
||||
return Promise.all(contexts.map(context => this.generateNoteEmbeddings(context)));
|
||||
}
|
||||
}
|
||||
|
||||
const providers = new Map<string, EmbeddingProvider>();
|
||||
|
||||
@ -236,15 +281,25 @@ export async function initializeDefaultProviders() {
|
||||
|
||||
// Register Ollama provider if enabled
|
||||
if (await options.getOptionBool('ollamaEnabled')) {
|
||||
const ollamaModel = await options.getOption('ollamaDefaultModel') || 'llama3';
|
||||
const ollamaBaseUrl = await options.getOption('ollamaBaseUrl') || 'http://localhost:11434';
|
||||
|
||||
registerEmbeddingProvider(new OllamaEmbeddingProvider({
|
||||
model: ollamaModel,
|
||||
dimension: 4096, // Typical for Ollama models
|
||||
// Use specific embedding models if available
|
||||
const embeddingModel = await options.getOption('ollamaEmbeddingModel') || 'nomic-embed-text';
|
||||
|
||||
try {
|
||||
// Create provider with initial dimension to be updated during initialization
|
||||
const ollamaProvider = new OllamaEmbeddingProvider({
|
||||
model: embeddingModel,
|
||||
dimension: 768, // Initial value, will be updated during initialization
|
||||
type: 'float32',
|
||||
baseUrl: ollamaBaseUrl
|
||||
}));
|
||||
});
|
||||
|
||||
// Register the provider
|
||||
registerEmbeddingProvider(ollamaProvider);
|
||||
|
||||
// Initialize the provider to detect model capabilities
|
||||
await ollamaProvider.initialize();
|
||||
|
||||
// Create Ollama provider config if it doesn't exist
|
||||
const existingOllama = await sql.getRow(
|
||||
@ -254,15 +309,18 @@ export async function initializeDefaultProviders() {
|
||||
|
||||
if (!existingOllama) {
|
||||
await createEmbeddingProviderConfig('ollama', {
|
||||
model: ollamaModel,
|
||||
dimension: 4096,
|
||||
model: embeddingModel,
|
||||
dimension: ollamaProvider.getDimension(),
|
||||
type: 'float32'
|
||||
}, true, 50);
|
||||
}
|
||||
} catch (error: any) {
|
||||
log.error(`Error initializing Ollama embedding provider: ${error.message || 'Unknown error'}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Always register local provider as fallback
|
||||
registerEmbeddingProvider(new LocalEmbeddingProvider({
|
||||
registerEmbeddingProvider(new SimpleLocalEmbeddingProvider({
|
||||
model: 'local',
|
||||
dimension: 384,
|
||||
type: 'float32'
|
||||
|
@ -7,33 +7,156 @@ interface OllamaEmbeddingConfig extends EmbeddingConfig {
|
||||
baseUrl: string;
|
||||
}
|
||||
|
||||
// Model-specific embedding dimensions
|
||||
interface EmbeddingModelInfo {
|
||||
dimension: number;
|
||||
contextWindow: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ollama embedding provider implementation
|
||||
*/
|
||||
export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
|
||||
name = "ollama";
|
||||
private baseUrl: string;
|
||||
// Cache for model dimensions to avoid repeated API calls
|
||||
private modelInfoCache = new Map<string, EmbeddingModelInfo>();
|
||||
|
||||
constructor(config: OllamaEmbeddingConfig) {
|
||||
super(config);
|
||||
this.baseUrl = config.baseUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the provider by detecting model capabilities
|
||||
*/
|
||||
async initialize(): Promise<void> {
|
||||
const modelName = this.config.model || "llama3";
|
||||
try {
|
||||
await this.getModelInfo(modelName);
|
||||
log.info(`Ollama embedding provider initialized with model ${modelName}`);
|
||||
} catch (error: any) {
|
||||
log.error(`Failed to initialize Ollama embedding provider: ${error.message}`);
|
||||
// Still continue with default dimensions
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model information including embedding dimensions
|
||||
*/
|
||||
async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
|
||||
// Check cache first
|
||||
if (this.modelInfoCache.has(modelName)) {
|
||||
return this.modelInfoCache.get(modelName)!;
|
||||
}
|
||||
|
||||
// Default dimensions for common embedding models
|
||||
const defaultDimensions: Record<string, number> = {
|
||||
"nomic-embed-text": 768,
|
||||
"mxbai-embed-large": 1024,
|
||||
"llama3": 4096,
|
||||
"all-minilm": 384,
|
||||
"default": 4096
|
||||
};
|
||||
|
||||
// Default context windows
|
||||
const defaultContextWindows: Record<string, number> = {
|
||||
"nomic-embed-text": 8192,
|
||||
"mxbai-embed-large": 8192,
|
||||
"llama3": 8192,
|
||||
"all-minilm": 4096,
|
||||
"default": 4096
|
||||
};
|
||||
|
||||
try {
|
||||
// Try to detect if this is an embedding model
|
||||
const testResponse = await axios.post(
|
||||
`${this.baseUrl}/api/embeddings`,
|
||||
{
|
||||
model: modelName,
|
||||
prompt: "Test"
|
||||
},
|
||||
{
|
||||
headers: { "Content-Type": "application/json" },
|
||||
timeout: 10000
|
||||
}
|
||||
);
|
||||
|
||||
let dimension = 0;
|
||||
let contextWindow = 0;
|
||||
|
||||
if (testResponse.data && Array.isArray(testResponse.data.embedding)) {
|
||||
dimension = testResponse.data.embedding.length;
|
||||
|
||||
// Set context window based on model name if we have it
|
||||
const baseModelName = modelName.split(':')[0];
|
||||
contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
|
||||
|
||||
log.info(`Detected Ollama model ${modelName} with dimension ${dimension}`);
|
||||
} else {
|
||||
throw new Error("Could not detect embedding dimensions");
|
||||
}
|
||||
|
||||
const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
|
||||
this.modelInfoCache.set(modelName, modelInfo);
|
||||
|
||||
// Update the provider config dimension
|
||||
this.config.dimension = dimension;
|
||||
|
||||
return modelInfo;
|
||||
} catch (error: any) {
|
||||
log.error(`Error detecting Ollama model capabilities: ${error.message}`);
|
||||
|
||||
// If detection fails, use defaults based on model name
|
||||
const baseModelName = modelName.split(':')[0];
|
||||
const dimension = defaultDimensions[baseModelName] || defaultDimensions.default;
|
||||
const contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
|
||||
|
||||
log.info(`Using default dimension ${dimension} for model ${modelName}`);
|
||||
|
||||
const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
|
||||
this.modelInfoCache.set(modelName, modelInfo);
|
||||
|
||||
// Update the provider config dimension
|
||||
this.config.dimension = dimension;
|
||||
|
||||
return modelInfo;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current embedding dimension
|
||||
*/
|
||||
getDimension(): number {
|
||||
return this.config.dimension;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for a single text
|
||||
*/
|
||||
async generateEmbeddings(text: string): Promise<Float32Array> {
|
||||
try {
|
||||
const modelName = this.config.model || "llama3";
|
||||
|
||||
// Ensure we have model info
|
||||
const modelInfo = await this.getModelInfo(modelName);
|
||||
|
||||
// Trim text if it might exceed context window (rough character estimate)
|
||||
// This is a simplistic approach - ideally we'd count tokens properly
|
||||
const charLimit = modelInfo.contextWindow * 4; // Rough estimate: avg 4 chars per token
|
||||
const trimmedText = text.length > charLimit ? text.substring(0, charLimit) : text;
|
||||
|
||||
const response = await axios.post(
|
||||
`${this.baseUrl}/api/embeddings`,
|
||||
{
|
||||
model: this.config.model || "llama3",
|
||||
prompt: text
|
||||
model: modelName,
|
||||
prompt: trimmedText
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
},
|
||||
timeout: 30000 // Longer timeout for larger texts
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -260,7 +260,21 @@ const defaultOptions: DefaultOption[] = [
|
||||
|
||||
// Share settings
|
||||
{ name: "redirectBareDomain", value: "false", isSynced: true },
|
||||
{ name: "showLoginInShareTheme", value: "false", isSynced: true }
|
||||
{ name: "showLoginInShareTheme", value: "false", isSynced: true },
|
||||
|
||||
// AI Options
|
||||
{ name: "aiEnabled", value: "false", isSynced: true },
|
||||
{ name: "openaiApiKey", value: "", isSynced: false },
|
||||
{ name: "openaiDefaultModel", value: "gpt-3.5-turbo", isSynced: true },
|
||||
{ name: "openaiBaseUrl", value: "https://api.openai.com/v1", isSynced: true },
|
||||
{ name: "anthropicApiKey", value: "", isSynced: false },
|
||||
{ name: "anthropicDefaultModel", value: "claude-3-haiku-20240307", isSynced: true },
|
||||
{ name: "anthropicBaseUrl", value: "https://api.anthropic.com/v1", isSynced: true },
|
||||
{ name: "ollamaEnabled", value: "false", isSynced: true },
|
||||
{ name: "ollamaDefaultModel", value: "llama3", isSynced: true },
|
||||
{ name: "ollamaBaseUrl", value: "http://localhost:11434", isSynced: true },
|
||||
{ name: "ollamaEmbeddingModel", value: "nomic-embed-text", isSynced: true },
|
||||
{ name: "embeddingAutoUpdate", value: "true", isSynced: true },
|
||||
];
|
||||
|
||||
/**
|
||||
|
@ -57,6 +57,7 @@ export interface OptionDefinitions extends KeyboardShortcutsOptions<KeyboardActi
|
||||
ollamaEnabled: boolean;
|
||||
ollamaBaseUrl: string;
|
||||
ollamaDefaultModel: string;
|
||||
ollamaEmbeddingModel: string;
|
||||
aiProviderPrecedence: string;
|
||||
aiTemperature: string;
|
||||
aiSystemPrompt: string;
|
||||
@ -66,6 +67,7 @@ export interface OptionDefinitions extends KeyboardShortcutsOptions<KeyboardActi
|
||||
embeddingUpdateInterval: number;
|
||||
embeddingBatchSize: number;
|
||||
embeddingDefaultDimension: number;
|
||||
embeddingAutoUpdate: boolean;
|
||||
|
||||
lastSyncedPull: number;
|
||||
lastSyncedPush: number;
|
||||
|
Loading…
x
Reference in New Issue
Block a user