mirror of
https://github.com/TriliumNext/Notes.git
synced 2025-08-12 20:02:28 +08:00
I can create embeddings now?
This commit is contained in:
parent
6ace4d5692
commit
0daa9e717f
@ -17,62 +17,55 @@ async function findSimilarNotes(req: Request, res: Response) {
|
|||||||
const threshold = parseFloat(req.query.threshold as string || '0.7');
|
const threshold = parseFloat(req.query.threshold as string || '0.7');
|
||||||
|
|
||||||
if (!noteId) {
|
if (!noteId) {
|
||||||
return res.status(400).send({
|
return [400, {
|
||||||
success: false,
|
success: false,
|
||||||
message: "Note ID is required"
|
message: "Note ID is required"
|
||||||
});
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
const embedding = await vectorStore.getEmbeddingForNote(noteId, providerId, modelId);
|
||||||
const embedding = await vectorStore.getEmbeddingForNote(noteId, providerId, modelId);
|
|
||||||
|
|
||||||
if (!embedding) {
|
if (!embedding) {
|
||||||
// If no embedding exists for this note yet, generate one
|
// If no embedding exists for this note yet, generate one
|
||||||
const note = becca.getNote(noteId);
|
const note = becca.getNote(noteId);
|
||||||
if (!note) {
|
if (!note) {
|
||||||
return res.status(404).send({
|
return [404, {
|
||||||
success: false,
|
success: false,
|
||||||
message: "Note not found"
|
message: "Note not found"
|
||||||
});
|
}];
|
||||||
}
|
|
||||||
|
|
||||||
const context = await vectorStore.getNoteEmbeddingContext(noteId);
|
|
||||||
const provider = providerManager.getEmbeddingProvider(providerId);
|
|
||||||
|
|
||||||
if (!provider) {
|
|
||||||
return res.status(400).send({
|
|
||||||
success: false,
|
|
||||||
message: `Embedding provider '${providerId}' not found`
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const newEmbedding = await provider.generateNoteEmbeddings(context);
|
|
||||||
await vectorStore.storeNoteEmbedding(noteId, providerId, modelId, newEmbedding);
|
|
||||||
|
|
||||||
const similarNotes = await vectorStore.findSimilarNotes(
|
|
||||||
newEmbedding, providerId, modelId, limit, threshold
|
|
||||||
);
|
|
||||||
|
|
||||||
return res.send({
|
|
||||||
success: true,
|
|
||||||
similarNotes
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const context = await vectorStore.getNoteEmbeddingContext(noteId);
|
||||||
|
const provider = providerManager.getEmbeddingProvider(providerId);
|
||||||
|
|
||||||
|
if (!provider) {
|
||||||
|
return [400, {
|
||||||
|
success: false,
|
||||||
|
message: `Embedding provider '${providerId}' not found`
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
const newEmbedding = await provider.generateNoteEmbeddings(context);
|
||||||
|
await vectorStore.storeNoteEmbedding(noteId, providerId, modelId, newEmbedding);
|
||||||
|
|
||||||
const similarNotes = await vectorStore.findSimilarNotes(
|
const similarNotes = await vectorStore.findSimilarNotes(
|
||||||
embedding.embedding, providerId, modelId, limit, threshold
|
newEmbedding, providerId, modelId, limit, threshold
|
||||||
);
|
);
|
||||||
|
|
||||||
return res.send({
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
similarNotes
|
similarNotes
|
||||||
});
|
};
|
||||||
} catch (error: any) {
|
|
||||||
return res.status(500).send({
|
|
||||||
success: false,
|
|
||||||
message: error.message || "Unknown error"
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const similarNotes = await vectorStore.findSimilarNotes(
|
||||||
|
embedding.embedding, providerId, modelId, limit, threshold
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
similarNotes
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -86,58 +79,45 @@ async function searchByText(req: Request, res: Response) {
|
|||||||
const threshold = parseFloat(req.query.threshold as string || '0.7');
|
const threshold = parseFloat(req.query.threshold as string || '0.7');
|
||||||
|
|
||||||
if (!text) {
|
if (!text) {
|
||||||
return res.status(400).send({
|
return [400, {
|
||||||
success: false,
|
success: false,
|
||||||
message: "Search text is required"
|
message: "Search text is required"
|
||||||
});
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
const provider = providerManager.getEmbeddingProvider(providerId);
|
||||||
const provider = providerManager.getEmbeddingProvider(providerId);
|
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
return res.status(400).send({
|
return [400, {
|
||||||
success: false,
|
|
||||||
message: `Embedding provider '${providerId}' not found`
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate embedding for the search text
|
|
||||||
const embedding = await provider.generateEmbeddings(text);
|
|
||||||
|
|
||||||
// Find similar notes
|
|
||||||
const similarNotes = await vectorStore.findSimilarNotes(
|
|
||||||
embedding, providerId, modelId, limit, threshold
|
|
||||||
);
|
|
||||||
|
|
||||||
return res.send({
|
|
||||||
success: true,
|
|
||||||
similarNotes
|
|
||||||
});
|
|
||||||
} catch (error: any) {
|
|
||||||
return res.status(500).send({
|
|
||||||
success: false,
|
success: false,
|
||||||
message: error.message || "Unknown error"
|
message: `Embedding provider '${providerId}' not found`
|
||||||
});
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate embedding for the search text
|
||||||
|
const embedding = await provider.generateEmbeddings(text);
|
||||||
|
|
||||||
|
// Find similar notes
|
||||||
|
const similarNotes = await vectorStore.findSimilarNotes(
|
||||||
|
embedding, providerId, modelId, limit, threshold
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
similarNotes
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get embedding providers
|
* Get embedding providers
|
||||||
*/
|
*/
|
||||||
async function getProviders(req: Request, res: Response) {
|
async function getProviders(req: Request, res: Response) {
|
||||||
try {
|
const providerConfigs = await providerManager.getEmbeddingProviderConfigs();
|
||||||
const providerConfigs = await providerManager.getEmbeddingProviderConfigs();
|
|
||||||
return res.send({
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
providers: providerConfigs
|
providers: providerConfigs
|
||||||
});
|
};
|
||||||
} catch (error: any) {
|
|
||||||
return res.status(500).send({
|
|
||||||
success: false,
|
|
||||||
message: error.message || "Unknown error"
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -147,96 +127,68 @@ async function updateProvider(req: Request, res: Response) {
|
|||||||
const { providerId } = req.params;
|
const { providerId } = req.params;
|
||||||
const { isEnabled, priority, config } = req.body;
|
const { isEnabled, priority, config } = req.body;
|
||||||
|
|
||||||
try {
|
const success = await providerManager.updateEmbeddingProviderConfig(
|
||||||
const success = await providerManager.updateEmbeddingProviderConfig(
|
providerId, isEnabled, priority, config
|
||||||
providerId, isEnabled, priority, config
|
);
|
||||||
);
|
|
||||||
|
|
||||||
if (!success) {
|
if (!success) {
|
||||||
return res.status(404).send({
|
return [404, {
|
||||||
success: false,
|
|
||||||
message: "Provider not found"
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.send({
|
|
||||||
success: true
|
|
||||||
});
|
|
||||||
} catch (error: any) {
|
|
||||||
return res.status(500).send({
|
|
||||||
success: false,
|
success: false,
|
||||||
message: error.message || "Unknown error"
|
message: "Provider not found"
|
||||||
});
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
success: true
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Manually trigger a reprocessing of all notes
|
* Manually trigger a reprocessing of all notes
|
||||||
*/
|
*/
|
||||||
async function reprocessAllNotes(req: Request, res: Response) {
|
async function reprocessAllNotes(req: Request, res: Response) {
|
||||||
// Wrap in a try-catch to handle errors
|
// Start the reprocessing operation in the background
|
||||||
try {
|
setTimeout(async () => {
|
||||||
// Start the reprocessing operation in the background
|
try {
|
||||||
// and immediately respond to the client
|
await vectorStore.reprocessAllNotes();
|
||||||
res.send({
|
log.info("Embedding reprocessing completed successfully");
|
||||||
success: true,
|
} catch (error: any) {
|
||||||
message: "Embedding reprocessing started in the background"
|
log.error(`Error during background embedding reprocessing: ${error.message || "Unknown error"}`);
|
||||||
});
|
|
||||||
|
|
||||||
// Continue processing in the background after sending the response
|
|
||||||
setTimeout(async () => {
|
|
||||||
try {
|
|
||||||
await vectorStore.reprocessAllNotes();
|
|
||||||
log.info("Embedding reprocessing completed successfully");
|
|
||||||
} catch (error: any) {
|
|
||||||
log.error(`Error during background embedding reprocessing: ${error.message || "Unknown error"}`);
|
|
||||||
}
|
|
||||||
}, 0);
|
|
||||||
} catch (error: any) {
|
|
||||||
// Only catch errors that happen before we send the response
|
|
||||||
log.error(`Error initiating embedding reprocessing: ${error.message || "Unknown error"}`);
|
|
||||||
|
|
||||||
if (!res.headersSent) {
|
|
||||||
res.status(500).send({
|
|
||||||
success: false,
|
|
||||||
message: error.message || "Unknown error"
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}, 0);
|
||||||
|
|
||||||
|
// Return the response data
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
message: "Embedding reprocessing started in the background"
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get embedding queue status
|
* Get embedding queue status
|
||||||
*/
|
*/
|
||||||
async function getQueueStatus(req: Request, res: Response) {
|
async function getQueueStatus(req: Request, res: Response) {
|
||||||
try {
|
// Use the imported sql instead of requiring it
|
||||||
// Use the imported sql instead of requiring it
|
const queueCount = await sql.getValue(
|
||||||
const queueCount = await sql.getValue(
|
"SELECT COUNT(*) FROM embedding_queue"
|
||||||
"SELECT COUNT(*) FROM embedding_queue"
|
);
|
||||||
);
|
|
||||||
|
|
||||||
const failedCount = await sql.getValue(
|
const failedCount = await sql.getValue(
|
||||||
"SELECT COUNT(*) FROM embedding_queue WHERE attempts > 0"
|
"SELECT COUNT(*) FROM embedding_queue WHERE attempts > 0"
|
||||||
);
|
);
|
||||||
|
|
||||||
const totalEmbeddingsCount = await sql.getValue(
|
const totalEmbeddingsCount = await sql.getValue(
|
||||||
"SELECT COUNT(*) FROM note_embeddings"
|
"SELECT COUNT(*) FROM note_embeddings"
|
||||||
);
|
);
|
||||||
|
|
||||||
return res.send({
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
status: {
|
status: {
|
||||||
queueCount,
|
queueCount,
|
||||||
failedCount,
|
failedCount,
|
||||||
totalEmbeddingsCount
|
totalEmbeddingsCount
|
||||||
}
|
}
|
||||||
});
|
};
|
||||||
} catch (error: any) {
|
|
||||||
return res.status(500).send({
|
|
||||||
success: false,
|
|
||||||
message: error.message || "Unknown error"
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
|
@ -372,12 +372,12 @@ function register(app: express.Application) {
|
|||||||
etapiBackupRoute.register(router);
|
etapiBackupRoute.register(router);
|
||||||
|
|
||||||
// Embeddings API endpoints
|
// Embeddings API endpoints
|
||||||
route(GET, "/api/embeddings/similar/:noteId", [auth.checkApiAuth], embeddingsRoute.findSimilarNotes, apiResultHandler);
|
apiRoute(GET, "/api/embeddings/similar/:noteId", embeddingsRoute.findSimilarNotes);
|
||||||
route(PST, "/api/embeddings/search", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.searchByText, apiResultHandler);
|
apiRoute(PST, "/api/embeddings/search", embeddingsRoute.searchByText);
|
||||||
route(GET, "/api/embeddings/providers", [auth.checkApiAuth], embeddingsRoute.getProviders, apiResultHandler);
|
apiRoute(GET, "/api/embeddings/providers", embeddingsRoute.getProviders);
|
||||||
route(PATCH, "/api/embeddings/providers/:providerId", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.updateProvider, apiResultHandler);
|
apiRoute(PATCH, "/api/embeddings/providers/:providerId", embeddingsRoute.updateProvider);
|
||||||
route(PST, "/api/embeddings/reprocess", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.reprocessAllNotes, apiResultHandler);
|
apiRoute(PST, "/api/embeddings/reprocess", embeddingsRoute.reprocessAllNotes);
|
||||||
route(GET, "/api/embeddings/queue-status", [auth.checkApiAuth], embeddingsRoute.getQueueStatus, apiResultHandler);
|
apiRoute(GET, "/api/embeddings/queue-status", embeddingsRoute.getQueueStatus);
|
||||||
|
|
||||||
// Ollama API endpoints
|
// Ollama API endpoints
|
||||||
route(PST, "/api/ollama/list-models", [auth.checkApiAuth, csrfMiddleware], ollamaRoute.listModels, apiResultHandler);
|
route(PST, "/api/ollama/list-models", [auth.checkApiAuth, csrfMiddleware], ollamaRoute.listModels, apiResultHandler);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user