feat(llm): add tests for streaming

perf3ct 2025-06-08 20:30:33 +00:00
parent c1bcb73337
commit c6f2124e9d
No known key found for this signature in database
GPG Key ID: 569C4EEC436F5232
7 changed files with 2586 additions and 9 deletions


@@ -0,0 +1,498 @@
import { test, expect, type Page } from '@playwright/test';
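// Shape of the 'llm-stream' WebSocket payloads that the tests below collect from the page.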
interface StreamMessage {
type: string;
chatNoteId?: string;
content?: string;
thinking?: string;
toolExecution?: any;
done?: boolean;
error?: string;
}
interface ChatSession {
id: string;
title: string;
messages: Array<{ role: string; content: string }>;
createdAt: string;
}
test.describe('LLM Streaming E2E Tests', () => {
let chatSessionId: string;
test.beforeEach(async ({ page }) => {
// Navigate to the application
await page.goto('/');
// Wait for the application to load
await page.waitForSelector('[data-testid="app-loaded"]', { timeout: 10000 });
// Create a new chat session for testing
const response = await page.request.post('/api/llm/chat', {
data: {
title: 'E2E Streaming Test Chat'
}
});
expect(response.ok()).toBeTruthy();
const chatData: ChatSession = await response.json();
chatSessionId = chatData.id;
});
test.afterEach(async ({ page }) => {
// Clean up the chat session
if (chatSessionId) {
await page.request.delete(`/api/llm/chat/${chatSessionId}`);
}
});
test('should establish WebSocket connection and receive streaming messages', async ({ page }) => {
// Monitor WebSocket traffic and collect llm-stream messages in the page context
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
// Mock WebSocket to capture messages
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
}
} catch (e) {
// Ignore invalid JSON
}
});
}
};
});
// Navigate to chat interface
await page.goto(`/chat/${chatSessionId}`);
// Wait for chat interface to load
await page.waitForSelector('[data-testid="chat-interface"]');
// Type a message
const messageInput = page.locator('[data-testid="message-input"]');
await messageInput.fill('Tell me a short story about a robot');
// Click send with streaming enabled
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for streaming to start
await page.waitForFunction(() => {
return (window as any).llmStreamMessages && (window as any).llmStreamMessages.length > 0;
}, { timeout: 10000 });
// Wait for streaming to complete (done: true message)
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.some((msg: StreamMessage) => msg.done === true);
}, { timeout: 30000 });
// Get all collected stream messages
const collectedMessages = await page.evaluate(() => (window as any).llmStreamMessages);
// Verify we received streaming messages
expect(collectedMessages.length).toBeGreaterThan(0);
// Verify message structure
const firstMessage = collectedMessages[0];
expect(firstMessage.type).toBe('llm-stream');
expect(firstMessage.chatNoteId).toBe(chatSessionId);
// Verify we received a completion message
const completionMessage = collectedMessages.find((msg: StreamMessage) => msg.done === true);
expect(completionMessage).toBeDefined();
// Verify content was streamed
const contentMessages = collectedMessages.filter((msg: StreamMessage) => msg.content);
expect(contentMessages.length).toBeGreaterThan(0);
});
test('should handle streaming with thinking states visible', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
// Enable thinking display
await page.locator('[data-testid="show-thinking-toggle"]').check();
// Send a complex message that would trigger thinking
await page.locator('[data-testid="message-input"]').fill('Explain quantum computing and then write a haiku about it');
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for thinking messages
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.some((msg: StreamMessage) => msg.thinking);
}, { timeout: 15000 });
const collectedMessages = await page.evaluate(() => (window as any).llmStreamMessages);
// Verify thinking messages were received
const thinkingMessages = collectedMessages.filter((msg: StreamMessage) => msg.thinking);
expect(thinkingMessages.length).toBeGreaterThan(0);
// Verify thinking content is displayed in UI
await expect(page.locator('[data-testid="thinking-display"]')).toBeVisible();
const thinkingText = await page.locator('[data-testid="thinking-display"]').textContent();
expect(thinkingText).toBeTruthy();
});
test('should handle tool execution during streaming', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
// Send a message that would trigger tool usage
await page.locator('[data-testid="message-input"]').fill('What is 15 * 37? Use a calculator tool.');
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for tool execution messages
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.some((msg: StreamMessage) => msg.toolExecution);
}, { timeout: 20000 });
const collectedMessages = await page.evaluate(() => (window as any).llmStreamMessages);
// Verify tool execution messages
const toolMessages = collectedMessages.filter((msg: StreamMessage) => msg.toolExecution);
expect(toolMessages.length).toBeGreaterThan(0);
const toolMessage = toolMessages[0];
expect(toolMessage.toolExecution.tool).toBeTruthy();
expect(toolMessage.toolExecution.args).toBeTruthy();
// Verify tool execution is displayed in UI
await expect(page.locator('[data-testid="tool-execution-display"]')).toBeVisible();
});
test('should handle streaming errors gracefully', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
// Trigger an error by making the LLM API return a failure
await page.locator('[data-testid="message-input"]').fill('This should trigger an error');
// Mock AI service to be unavailable
await page.route('/api/llm/**', route => {
route.fulfill({
status: 500,
body: JSON.stringify({ error: 'AI service unavailable' })
});
});
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for error message
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.some((msg: StreamMessage) => msg.error);
}, { timeout: 10000 });
const collectedMessages = await page.evaluate(() => (window as any).llmStreamMessages);
// Verify error message was received
const errorMessages = collectedMessages.filter((msg: StreamMessage) => msg.error);
expect(errorMessages.length).toBeGreaterThan(0);
const errorMessage = errorMessages[0];
expect(errorMessage.error).toBeTruthy();
expect(errorMessage.done).toBe(true);
// Verify error is displayed in UI
await expect(page.locator('[data-testid="error-display"]')).toBeVisible();
const errorText = await page.locator('[data-testid="error-display"]').textContent();
expect(errorText).toContain('error');
});
test('should handle rapid consecutive streaming requests', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
// Send multiple messages rapidly
for (let i = 0; i < 3; i++) {
await page.locator('[data-testid="message-input"]').fill(`Rapid message ${i + 1}`);
await page.locator('[data-testid="send-stream-button"]').click();
// Small delay between requests
await page.waitForTimeout(100);
}
// Wait for all responses to complete
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
const doneMessages = messages.filter((msg: StreamMessage) => msg.done === true);
return doneMessages.length >= 3;
}, { timeout: 30000 });
const collectedMessages = await page.evaluate(() => (window as any).llmStreamMessages);
// Verify all requests were processed
const uniqueChatIds = new Set(collectedMessages.map((msg: StreamMessage) => msg.chatNoteId));
expect(uniqueChatIds.size).toBe(1); // All from same chat
const doneMessages = collectedMessages.filter((msg: StreamMessage) => msg.done === true);
expect(doneMessages.length).toBeGreaterThanOrEqual(3);
});
test('should preserve message order during streaming', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
(window as any).messageOrder = [];
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
if (data.content) {
(window as any).messageOrder.push(data.content);
}
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
await page.locator('[data-testid="message-input"]').fill('Count from 1 to 10 with each number in a separate chunk');
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for streaming to complete
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.some((msg: StreamMessage) => msg.done === true);
}, { timeout: 20000 });
const messageOrder = await page.evaluate(() => (window as any).messageOrder);
// Verify messages arrived in order
expect(messageOrder.length).toBeGreaterThan(0);
// Verify content appears in UI in correct order
const chatContent = await page.locator('[data-testid="chat-messages"]').textContent();
expect(chatContent).toBeTruthy();
});
test('should handle WebSocket disconnection and reconnection', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
(window as any).connectionEvents = [];
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('open', () => {
(window as any).connectionEvents.push('open');
});
this.addEventListener('close', () => {
(window as any).connectionEvents.push('close');
});
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
// Start a streaming request
await page.locator('[data-testid="message-input"]').fill('Tell me a long story');
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for streaming to start
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.length > 0;
}, { timeout: 10000 });
// Simulate network disconnection by going offline
await page.context().setOffline(true);
await page.waitForTimeout(2000);
// Reconnect
await page.context().setOffline(false);
// Verify connection events
const connectionEvents = await page.evaluate(() => (window as any).connectionEvents);
expect(connectionEvents).toContain('open');
// UI should show reconnection status
await expect(page.locator('[data-testid="connection-status"]')).toBeVisible();
});
test('should display streaming progress indicators', async ({ page }) => {
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
await page.locator('[data-testid="message-input"]').fill('Generate a detailed response');
await page.locator('[data-testid="send-stream-button"]').click();
// Verify typing indicator appears
await expect(page.locator('[data-testid="typing-indicator"]')).toBeVisible();
// Verify progress indicators during streaming
await expect(page.locator('[data-testid="streaming-progress"]')).toBeVisible();
// Wait for streaming to complete
await expect(page.locator('[data-testid="streaming-complete"]')).toBeVisible({ timeout: 30000 });
// Verify indicators are hidden when done
await expect(page.locator('[data-testid="typing-indicator"]')).not.toBeVisible();
await expect(page.locator('[data-testid="streaming-progress"]')).not.toBeVisible();
});
test('should handle large streaming responses', async ({ page }) => {
await page.addInitScript(() => {
(window as any).llmStreamMessages = [];
(window as any).totalContentLength = 0;
const originalWebSocket = window.WebSocket;
window.WebSocket = class extends originalWebSocket {
constructor(url: string | URL, protocols?: string | string[]) {
super(url, protocols);
this.addEventListener('message', (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === 'llm-stream') {
(window as any).llmStreamMessages.push(data);
if (data.content) {
(window as any).totalContentLength += data.content.length;
}
}
} catch (e) {}
});
}
};
});
await page.goto(`/chat/${chatSessionId}`);
await page.waitForSelector('[data-testid="chat-interface"]');
// Request a large response
await page.locator('[data-testid="message-input"]').fill('Write a very detailed, long response about the history of computers, at least 2000 words');
await page.locator('[data-testid="send-stream-button"]').click();
// Wait for large response to complete
await page.waitForFunction(() => {
const messages = (window as any).llmStreamMessages || [];
return messages.some((msg: StreamMessage) => msg.done === true);
}, { timeout: 60000 });
const totalLength = await page.evaluate(() => (window as any).totalContentLength);
const messages = await page.evaluate(() => (window as any).llmStreamMessages);
// Verify large content was received
expect(totalLength).toBeGreaterThan(1000); // At least 1KB
expect(messages.length).toBeGreaterThan(10); // Multiple chunks
// Verify UI can handle large content
const chatMessages = await page.locator('[data-testid="chat-messages"]').textContent();
expect(chatMessages!.length).toBeGreaterThan(1000);
});
});


@@ -1,8 +1,9 @@
import { Application } from "express";
import { beforeAll, describe, expect, it, vi, beforeEach } from "vitest";
import { beforeAll, describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import supertest from "supertest";
import config from "../../services/config.js";
import { refreshAuth } from "../../services/auth.js";
import type { WebSocket } from 'ws';
// Mock the CSRF protection middleware to allow tests to pass
vi.mock("../csrf_protection.js", () => ({
@@ -10,6 +11,64 @@ vi.mock("../csrf_protection.js", () => ({
generateToken: () => "mock-csrf-token"
}));
// Mock WebSocket service
vi.mock("../../services/ws.js", () => ({
default: {
sendMessageToAllClients: vi.fn()
}
}));
// Mock log service
vi.mock("../../services/log.js", () => ({
default: {
info: vi.fn(),
error: vi.fn(),
warn: vi.fn()
}
}));
// Mock chat storage service
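// Individual tests configure createChat/getChat/updateChat return values on this stub as needed.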
const mockChatStorage = {
createChat: vi.fn(),
getChat: vi.fn(),
updateChat: vi.fn(),
getAllChats: vi.fn(),
deleteChat: vi.fn()
};
vi.mock("../../services/llm/storage/chat_storage_service.js", () => ({
default: mockChatStorage
}));
// Mock AI service manager
const mockAiServiceManager = {
getOrCreateAnyService: vi.fn()
};
vi.mock("../../services/llm/ai_service_manager.js", () => ({
default: mockAiServiceManager
}));
// Mock chat pipeline
const mockChatPipelineExecute = vi.fn();
const MockChatPipeline = vi.fn().mockImplementation(() => ({
execute: mockChatPipelineExecute
}));
vi.mock("../../services/llm/pipeline/chat_pipeline.js", () => ({
ChatPipeline: MockChatPipeline
}));
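// Tests drive streaming by giving mockChatPipelineExecute an implementation that invokes input.streamCallback with canned chunks.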
// Mock configuration helpers
const mockGetSelectedModelConfig = vi.fn();
vi.mock("../../services/llm/config/configuration_helpers.js", () => ({
getSelectedModelConfig: mockGetSelectedModelConfig
}));
// Mock options service
vi.mock("../../services/options.js", () => ({
default: {
getOptionBool: vi.fn()
}
}));
// Session-based login that properly establishes req.session.loggedIn
async function loginWithSession(app: Application) {
const response = await supertest(app)
@@ -257,7 +316,30 @@ describe("LLM API Tests", () => {
let testChatId: string;
beforeEach(async () => {
// Reset all mocks
vi.clearAllMocks();
// Import options service to access mock
const options = (await import("../../services/options.js")).default;
// Setup default mock behaviors
options.getOptionBool.mockReturnValue(true); // AI enabled
mockAiServiceManager.getOrCreateAnyService.mockResolvedValue({});
mockGetSelectedModelConfig.mockResolvedValue({
model: 'test-model',
provider: 'test-provider'
});
// Create a fresh chat for each test
const mockChat = {
id: 'streaming-test-chat',
title: 'Streaming Test Chat',
messages: [],
createdAt: new Date().toISOString()
};
mockChatStorage.createChat.mockResolvedValue(mockChat);
mockChatStorage.getChat.mockResolvedValue(mockChat);
const createResponse = await supertest(app)
.post("/api/llm/chat")
.set("Cookie", sessionCookie)
@@ -269,7 +351,19 @@ describe("LLM API Tests", () => {
testChatId = createResponse.body.id;
});
afterEach(() => {
vi.clearAllMocks();
});
it("should initiate streaming for a chat message", async () => {
// Setup streaming simulation
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate streaming chunks
await callback('Hello', false, {});
await callback(' world!', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
@@ -286,6 +380,31 @@ describe("LLM API Tests", () => {
success: true,
message: "Streaming initiated successfully"
});
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify WebSocket messages were sent
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: undefined
});
// Verify streaming chunks were sent
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
content: 'Hello',
done: false
});
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
content: ' world!',
done: true
});
});
it("should handle empty content for streaming", async () => {
@@ -338,6 +457,29 @@ describe("LLM API Tests", () => {
});
it("should handle streaming with note mentions", async () => {
// Mock becca for note content retrieval
const mockBecca = {
getNote: vi.fn().mockReturnValue({
noteId: 'root',
title: 'Root Note',
getBlob: () => ({
getContent: () => 'Root note content for testing'
})
})
};
vi.mocked(await import('../../becca/becca.js')).default = mockBecca;
// Setup streaming with mention context
mockChatPipelineExecute.mockImplementation(async (input) => {
// Verify mention content is included
expect(input.query).toContain('Tell me about this note');
expect(input.query).toContain('Root note content for testing');
const callback = input.streamCallback;
await callback('The root note contains', false, {});
await callback(' important information.', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
@@ -358,6 +500,250 @@ describe("LLM API Tests", () => {
success: true,
message: "Streaming initiated successfully"
});
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify thinking message was sent
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: 'Initializing streaming LLM response...'
});
});
it("should handle streaming with thinking states", async () => {
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate thinking states
await callback('', false, { thinking: 'Analyzing the question...' });
await callback('', false, { thinking: 'Formulating response...' });
await callback('The answer is', false, {});
await callback(' 42.', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "What is the meaning of life?",
useAdvancedContext: false,
showThinking: true
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify thinking messages
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: 'Analyzing the question...',
done: false
});
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
thinking: 'Formulating response...',
done: false
});
});
it("should handle streaming with tool executions", async () => {
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate tool execution
await callback('Let me calculate that', false, {});
await callback('', false, {
toolExecution: {
tool: 'calculator',
arguments: { expression: '2 + 2' },
result: '4',
toolCallId: 'call_123',
action: 'execute'
}
});
await callback('The result is 4', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "What is 2 + 2?",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify tool execution message
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
toolExecution: {
tool: 'calculator',
args: { expression: '2 + 2' },
result: '4',
toolCallId: 'call_123',
action: 'execute',
error: undefined
},
done: false
});
});
it("should handle streaming errors gracefully", async () => {
mockChatPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "This will fail",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200); // Still returns 200
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify error message was sent via WebSocket
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
error: 'Error during streaming: Pipeline error',
done: true
});
});
it("should handle AI disabled state", async () => {
// Import options service to access mock
const options = (await import("../../services/options.js")).default;
options.getOptionBool.mockReturnValue(false); // AI disabled
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "Hello AI",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify error message about AI being disabled
expect(ws.sendMessageToAllClients).toHaveBeenCalledWith({
type: 'llm-stream',
chatNoteId: testChatId,
error: 'Error during streaming: AI features are disabled. Please enable them in the settings.',
done: true
});
});
it("should save chat messages after streaming completion", async () => {
const completeResponse = 'This is the complete response';
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
await callback(completeResponse, true, {});
});
await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "Save this response",
useAdvancedContext: false,
showThinking: false
});
// Wait for async operations
await new Promise(resolve => setTimeout(resolve, 100));
// Verify chat was updated with the complete response
expect(mockChatStorage.updateChat).toHaveBeenCalledWith(
testChatId,
expect.arrayContaining([
{ role: 'assistant', content: completeResponse }
]),
'Streaming Test Chat'
);
});
it("should handle rapid consecutive streaming requests", async () => {
let callCount = 0;
mockChatPipelineExecute.mockImplementation(async (input) => {
callCount++;
const callback = input.streamCallback;
await callback(`Response ${callCount}`, true, {});
});
// Send multiple requests rapidly
const promises = Array.from({ length: 3 }, (_, i) =>
supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: `Request ${i + 1}`,
useAdvancedContext: false,
showThinking: false
})
);
const responses = await Promise.all(promises);
// All should succeed
responses.forEach(response => {
expect(response.status).toBe(200);
expect(response.body.success).toBe(true);
});
// Verify all were processed
expect(mockChatPipelineExecute).toHaveBeenCalledTimes(3);
});
it("should handle large streaming responses", async () => {
const largeContent = 'x'.repeat(10000); // 10KB of content
mockChatPipelineExecute.mockImplementation(async (input) => {
const callback = input.streamCallback;
// Simulate chunked delivery of large content
for (let i = 0; i < 10; i++) {
await callback(largeContent.slice(i * 1000, (i + 1) * 1000), false, {});
}
await callback('', true, {});
});
const response = await supertest(app)
.post(`/api/llm/chat/${testChatId}/messages/stream`)
.set("Cookie", sessionCookie)
.send({
content: "Generate large response",
useAdvancedContext: false,
showThinking: false
});
expect(response.status).toBe(200);
// Import ws service to access mock
const ws = (await import("../../services/ws.js")).default;
// Verify multiple chunks were sent
const streamCalls = ws.sendMessageToAllClients.mock.calls.filter(
call => call[0].type === 'llm-stream' && call[0].content
);
expect(streamCalls.length).toBeGreaterThan(5);
});
});


@@ -537,11 +537,11 @@ async function handleStreamingProcess(
}
// Get AI service
const aiServiceManager = await import('../ai_service_manager.js');
const aiServiceManager = await import('../../services/llm/ai_service_manager.js');
await aiServiceManager.default.getOrCreateAnyService();
// Use the chat pipeline directly for streaming
const { ChatPipeline } = await import('../pipeline/chat_pipeline.js');
const { ChatPipeline } = await import('../../services/llm/pipeline/chat_pipeline.js');
const pipeline = new ChatPipeline({
enableStreaming: true,
enableMetrics: true,
@@ -549,7 +549,7 @@ });
});
// Get selected model
const { getSelectedModelConfig } = await import('../config/configuration_helpers.js');
const { getSelectedModelConfig } = await import('../../services/llm/config/configuration_helpers.js');
const modelConfig = await getSelectedModelConfig();
if (!modelConfig) {


@@ -0,0 +1,573 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { processProviderStream, StreamProcessor } from '../stream_handler.js';
import type { ProviderStreamOptions } from '../stream_handler.js';
// Mock log service
vi.mock('../../log.js', () => ({
default: {
info: vi.fn(),
error: vi.fn(),
warn: vi.fn()
}
}));
describe('Provider Streaming Integration Tests', () => {
let mockProviderOptions: ProviderStreamOptions;
beforeEach(() => {
vi.clearAllMocks();
mockProviderOptions = {
providerName: 'TestProvider',
modelName: 'test-model-v1'
};
});
describe('OpenAI-like Provider Integration', () => {
it('should handle OpenAI streaming format', async () => {
// Simulate OpenAI streaming chunks
const openAIChunks = [
{
choices: [{ delta: { content: 'Hello' } }],
model: 'gpt-3.5-turbo'
},
{
choices: [{ delta: { content: ' world' } }],
model: 'gpt-3.5-turbo'
},
{
choices: [{ delta: { content: '!' } }],
model: 'gpt-3.5-turbo'
},
{
choices: [{ finish_reason: 'stop' }],
model: 'gpt-3.5-turbo',
usage: {
prompt_tokens: 10,
completion_tokens: 3,
total_tokens: 13
},
done: true
}
];
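// Wrap the canned chunks in a minimal async iterable, standing in for a real provider stream.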
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of openAIChunks) {
yield chunk;
}
}
};
const receivedChunks: any[] = [];
const result = await processProviderStream(
mockIterator,
{ ...mockProviderOptions, providerName: 'OpenAI' },
(text, done, chunk) => {
receivedChunks.push({ text, done, chunk });
}
);
expect(result.completeText).toBe('Hello world!');
expect(result.chunkCount).toBe(4);
expect(receivedChunks.length).toBeGreaterThan(0);
// Verify callback received content chunks
const contentChunks = receivedChunks.filter(c => c.text);
expect(contentChunks.length).toBe(3);
});
it('should handle OpenAI tool calls', async () => {
const openAIWithTools = [
{
choices: [{ delta: { content: 'Let me calculate that' } }]
},
{
choices: [{
delta: {
tool_calls: [{
id: 'call_123',
type: 'function',
function: {
name: 'calculator',
arguments: '{"expression": "2+2"}'
}
}]
}
}]
},
{
choices: [{ delta: { content: 'The answer is 4' } }]
},
{ done: true }
];
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of openAIWithTools) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockIterator,
{ ...mockProviderOptions, providerName: 'OpenAI' }
);
expect(result.completeText).toBe('Let me calculate thatThe answer is 4');
expect(result.toolCalls.length).toBe(1);
expect(result.toolCalls[0].function.name).toBe('calculator');
});
});
describe('Ollama Provider Integration', () => {
it('should handle Ollama streaming format', async () => {
const ollamaChunks = [
{
model: 'llama2',
message: { content: 'The weather' },
done: false
},
{
model: 'llama2',
message: { content: ' today is' },
done: false
},
{
model: 'llama2',
message: { content: ' sunny.' },
done: false
},
{
model: 'llama2',
message: { content: '' },
done: true,
prompt_eval_count: 15,
eval_count: 8,
total_duration: 12345678
}
];
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of ollamaChunks) {
yield chunk;
}
}
};
const receivedChunks: any[] = [];
const result = await processProviderStream(
mockIterator,
{ ...mockProviderOptions, providerName: 'Ollama' },
(text, done, chunk) => {
receivedChunks.push({ text, done, chunk });
}
);
expect(result.completeText).toBe('The weather today is sunny.');
expect(result.chunkCount).toBe(4);
// Verify final chunk has usage stats
expect(result.finalChunk.prompt_eval_count).toBe(15);
expect(result.finalChunk.eval_count).toBe(8);
});
it('should handle Ollama empty responses', async () => {
const ollamaEmpty = [
{
model: 'llama2',
message: { content: '' },
done: true,
prompt_eval_count: 5,
eval_count: 0
}
];
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of ollamaEmpty) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockIterator,
{ ...mockProviderOptions, providerName: 'Ollama' }
);
expect(result.completeText).toBe('');
expect(result.chunkCount).toBe(1);
});
});
describe('Anthropic Provider Integration', () => {
it('should handle Anthropic streaming format', async () => {
const anthropicChunks = [
{
type: 'message_start',
message: {
id: 'msg_123',
type: 'message',
role: 'assistant',
content: []
}
},
{
type: 'content_block_start',
index: 0,
content_block: { type: 'text', text: '' }
},
{
type: 'content_block_delta',
index: 0,
delta: { type: 'text_delta', text: 'Hello' }
},
{
type: 'content_block_delta',
index: 0,
delta: { type: 'text_delta', text: ' from' }
},
{
type: 'content_block_delta',
index: 0,
delta: { type: 'text_delta', text: ' Claude!' }
},
{
type: 'content_block_stop',
index: 0
},
{
type: 'message_delta',
delta: { stop_reason: 'end_turn' },
usage: { output_tokens: 3 }
},
{
type: 'message_stop',
done: true
}
];
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of anthropicChunks) {
// Anthropic format needs conversion to our standard format
if (chunk.type === 'content_block_delta') {
yield {
message: { content: chunk.delta.text },
done: false
};
} else if (chunk.type === 'message_stop') {
yield { done: true };
}
}
}
};
const result = await processProviderStream(
mockIterator,
{ ...mockProviderOptions, providerName: 'Anthropic' }
);
expect(result.completeText).toBe('Hello from Claude!');
expect(result.chunkCount).toBe(4); // 3 content chunks + 1 done
});
it('should handle Anthropic thinking blocks', async () => {
const anthropicWithThinking = [
{
message: { content: '', thinking: 'Let me think about this...' },
done: false
},
{
message: { content: '', thinking: 'I need to consider multiple factors' },
done: false
},
{
message: { content: 'Based on my analysis' },
done: false
},
{
message: { content: ', the answer is 42.' },
done: true
}
];
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of anthropicWithThinking) {
yield chunk;
}
}
};
const receivedChunks: any[] = [];
const result = await processProviderStream(
mockIterator,
{ ...mockProviderOptions, providerName: 'Anthropic' },
(text, done, chunk) => {
receivedChunks.push({ text, done, chunk });
}
);
expect(result.completeText).toBe('Based on my analysis, the answer is 42.');
// Verify thinking states were captured
const thinkingChunks = receivedChunks.filter(c => c.chunk?.thinking);
expect(thinkingChunks.length).toBe(2);
});
});
describe('Error Scenarios Integration', () => {
it('should handle provider connection timeouts', async () => {
const timeoutIterator = {
async *[Symbol.asyncIterator]() {
yield { message: { content: 'Starting...' } };
// Simulate timeout
await new Promise((_, reject) =>
setTimeout(() => reject(new Error('Request timeout')), 100)
);
}
};
await expect(processProviderStream(
timeoutIterator,
mockProviderOptions
)).rejects.toThrow('Request timeout');
});
it('should handle malformed provider responses', async () => {
const malformedIterator = {
async *[Symbol.asyncIterator]() {
yield null; // Invalid chunk
yield undefined; // Invalid chunk
yield { invalidFormat: true }; // No standard fields
yield { done: true };
}
};
const result = await processProviderStream(
malformedIterator,
mockProviderOptions
);
expect(result.completeText).toBe('');
expect(result.chunkCount).toBe(4);
});
it('should handle provider rate limiting', async () => {
const rateLimitIterator = {
async *[Symbol.asyncIterator]() {
yield { message: { content: 'Starting request' } };
throw new Error('Rate limit exceeded. Please try again later.');
}
};
await expect(processProviderStream(
rateLimitIterator,
mockProviderOptions
)).rejects.toThrow('Rate limit exceeded');
});
it('should handle network interruptions', async () => {
const networkErrorIterator = {
async *[Symbol.asyncIterator]() {
yield { message: { content: 'Partial' } };
yield { message: { content: ' response' } };
throw new Error('Network error: Connection reset');
}
};
await expect(processProviderStream(
networkErrorIterator,
mockProviderOptions
)).rejects.toThrow('Network error');
});
});
describe('Performance and Scalability', () => {
it('should handle high-frequency chunk delivery', async () => {
const highFrequencyChunks = Array.from({ length: 1000 }, (_, i) => ({
message: { content: `chunk${i}` },
done: i === 999
}));
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of highFrequencyChunks) {
yield chunk;
// No delay - rapid fire
}
}
};
const startTime = Date.now();
const result = await processProviderStream(
mockIterator,
mockProviderOptions
);
const endTime = Date.now();
expect(result.chunkCount).toBe(1000);
expect(result.completeText).toContain('chunk999');
expect(endTime - startTime).toBeLessThan(5000); // Should complete in under 5s
});
it('should handle large individual chunks', async () => {
const largeContent = 'x'.repeat(100000); // 100KB chunk
const largeChunks = [
{ message: { content: largeContent }, done: false },
{ message: { content: ' end' }, done: true }
];
const mockIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of largeChunks) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockIterator,
mockProviderOptions
);
expect(result.completeText.length).toBe(100004); // 100KB + ' end'
expect(result.chunkCount).toBe(2);
});
it('should handle concurrent streaming sessions', async () => {
const createMockIterator = (sessionId: number) => ({
async *[Symbol.asyncIterator]() {
for (let i = 0; i < 10; i++) {
yield {
message: { content: `Session${sessionId}-Chunk${i}` },
done: i === 9
};
await new Promise(resolve => setTimeout(resolve, 10));
}
}
});
// Start 5 concurrent streaming sessions
const promises = Array.from({ length: 5 }, (_, i) =>
processProviderStream(
createMockIterator(i),
{ ...mockProviderOptions, providerName: `Provider${i}` }
)
);
const results = await Promise.all(promises);
// Verify all sessions completed successfully
results.forEach((result, i) => {
expect(result.chunkCount).toBe(10);
expect(result.completeText).toContain(`Session${i}`);
});
});
});
describe('Memory Management', () => {
it('should not leak memory during long streaming sessions', async () => {
const longSessionIterator = {
async *[Symbol.asyncIterator]() {
for (let i = 0; i < 10000; i++) {
yield {
message: { content: `Chunk ${i} with some additional content to increase memory usage` },
done: i === 9999
};
// Periodic yield to event loop
if (i % 100 === 0) {
await new Promise(resolve => setImmediate(resolve));
}
}
}
};
const initialMemory = process.memoryUsage();
const result = await processProviderStream(
longSessionIterator,
mockProviderOptions
);
const finalMemory = process.memoryUsage();
expect(result.chunkCount).toBe(10000);
// Memory increase should be reasonable (less than 50MB)
const memoryIncrease = finalMemory.heapUsed - initialMemory.heapUsed;
expect(memoryIncrease).toBeLessThan(50 * 1024 * 1024);
});
it('should clean up resources on stream completion', async () => {
const resourceTracker = {
resources: new Set<string>(),
allocate(id: string) { this.resources.add(id); },
cleanup(id: string) { this.resources.delete(id); }
};
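// The generator's finally block below releases the resource, so the tracker should be empty once the stream completes.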
const mockIterator = {
async *[Symbol.asyncIterator]() {
resourceTracker.allocate('stream-1');
try {
yield { message: { content: 'Hello' } };
yield { message: { content: 'World' } };
yield { done: true };
} finally {
resourceTracker.cleanup('stream-1');
}
}
};
await processProviderStream(
mockIterator,
mockProviderOptions
);
expect(resourceTracker.resources.size).toBe(0);
});
});
describe('Provider-Specific Configurations', () => {
it('should handle provider-specific options', async () => {
const configuredOptions: ProviderStreamOptions = {
providerName: 'CustomProvider',
modelName: 'custom-model',
apiConfig: {
temperature: 0.7,
maxTokens: 1000,
customParameter: 'test-value'
}
};
const mockIterator = {
async *[Symbol.asyncIterator]() {
yield { message: { content: 'Configured response' }, done: true };
}
};
const result = await processProviderStream(
mockIterator,
configuredOptions
);
expect(result.completeText).toBe('Configured response');
});
it('should validate provider compatibility', async () => {
const unsupportedIterator = {
// Missing Symbol.asyncIterator
next() { return { value: null, done: true }; }
};
await expect(processProviderStream(
unsupportedIterator as any,
mockProviderOptions
)).rejects.toThrow('Invalid stream iterator');
});
});
});


@@ -0,0 +1,602 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { StreamProcessor, createStreamHandler, processProviderStream, extractStreamStats, performProviderHealthCheck } from './stream_handler.js';
import type { StreamProcessingOptions, StreamChunk, ProviderStreamOptions } from './stream_handler.js';
// Mock the log module
vi.mock('../../log.js', () => ({
default: {
info: vi.fn(),
error: vi.fn(),
warn: vi.fn()
}
}));
describe('StreamProcessor', () => {
let mockCallback: ReturnType<typeof vi.fn>;
let mockOptions: StreamProcessingOptions;
beforeEach(() => {
mockCallback = vi.fn();
mockOptions = {
streamCallback: mockCallback,
providerName: 'TestProvider',
modelName: 'test-model'
};
vi.clearAllMocks();
});
describe('processChunk', () => {
it('should process a chunk with content', async () => {
const chunk = {
message: { content: 'Hello' },
done: false
};
const result = await StreamProcessor.processChunk(chunk, '', 1, mockOptions);
expect(result.completeText).toBe('Hello');
expect(result.logged).toBe(true);
});
it('should handle chunks without content', async () => {
const chunk = { done: false };
const result = await StreamProcessor.processChunk(chunk, 'existing', 2, mockOptions);
expect(result.completeText).toBe('existing');
expect(result.logged).toBe(false);
});
it('should log every 10th chunk', async () => {
const chunk = { message: { content: 'test' } };
const result = await StreamProcessor.processChunk(chunk, '', 10, mockOptions);
expect(result.logged).toBe(true);
});
it('should log final chunks with done flag', async () => {
const chunk = { done: true };
const result = await StreamProcessor.processChunk(chunk, 'complete', 5, mockOptions);
expect(result.logged).toBe(true);
});
it('should accumulate text correctly', async () => {
const chunk1 = { message: { content: 'Hello ' } };
const chunk2 = { message: { content: 'World' } };
const result1 = await StreamProcessor.processChunk(chunk1, '', 1, mockOptions);
const result2 = await StreamProcessor.processChunk(chunk2, result1.completeText, 2, mockOptions);
expect(result2.completeText).toBe('Hello World');
});
});
describe('sendChunkToCallback', () => {
it('should call callback with content', async () => {
await StreamProcessor.sendChunkToCallback(mockCallback, 'test content', false, {}, 1);
expect(mockCallback).toHaveBeenCalledWith('test content', false, {});
});
it('should handle async callbacks', async () => {
const asyncCallback = vi.fn().mockResolvedValue(undefined);
await StreamProcessor.sendChunkToCallback(asyncCallback, 'async test', true, { done: true }, 5);
expect(asyncCallback).toHaveBeenCalledWith('async test', true, { done: true });
});
it('should handle callback errors gracefully', async () => {
const errorCallback = vi.fn().mockRejectedValue(new Error('Callback error'));
// Should not throw
await expect(StreamProcessor.sendChunkToCallback(errorCallback, 'test', false, {}, 1))
.resolves.toBeUndefined();
});
it('should handle empty content', async () => {
await StreamProcessor.sendChunkToCallback(mockCallback, '', true, { done: true }, 10);
expect(mockCallback).toHaveBeenCalledWith('', true, { done: true });
});
it('should handle null content by converting to empty string', async () => {
await StreamProcessor.sendChunkToCallback(mockCallback, null as any, false, {}, 1);
expect(mockCallback).toHaveBeenCalledWith('', false, {});
});
});
describe('sendFinalCallback', () => {
it('should send final callback with complete text', async () => {
await StreamProcessor.sendFinalCallback(mockCallback, 'Complete text');
expect(mockCallback).toHaveBeenCalledWith('Complete text', true, { done: true, complete: true });
});
it('should handle empty complete text', async () => {
await StreamProcessor.sendFinalCallback(mockCallback, '');
expect(mockCallback).toHaveBeenCalledWith('', true, { done: true, complete: true });
});
it('should handle async final callbacks', async () => {
const asyncCallback = vi.fn().mockResolvedValue(undefined);
await StreamProcessor.sendFinalCallback(asyncCallback, 'Final');
expect(asyncCallback).toHaveBeenCalledWith('Final', true, { done: true, complete: true });
});
it('should handle final callback errors gracefully', async () => {
const errorCallback = vi.fn().mockRejectedValue(new Error('Final callback error'));
await expect(StreamProcessor.sendFinalCallback(errorCallback, 'test'))
.resolves.toBeUndefined();
});
});
describe('extractToolCalls', () => {
it('should extract tool calls from chunk', () => {
const chunk = {
message: {
tool_calls: [
{ id: '1', function: { name: 'test_tool', arguments: '{}' } }
]
}
};
const toolCalls = StreamProcessor.extractToolCalls(chunk);
expect(toolCalls).toHaveLength(1);
expect(toolCalls[0].function.name).toBe('test_tool');
});
it('should return empty array when no tool calls', () => {
const chunk = { message: { content: 'Just text' } };
const toolCalls = StreamProcessor.extractToolCalls(chunk);
expect(toolCalls).toEqual([]);
});
it('should handle missing message property', () => {
const chunk = {};
const toolCalls = StreamProcessor.extractToolCalls(chunk);
expect(toolCalls).toEqual([]);
});
it('should handle non-array tool_calls', () => {
const chunk = { message: { tool_calls: 'not-an-array' } };
const toolCalls = StreamProcessor.extractToolCalls(chunk);
expect(toolCalls).toEqual([]);
});
});
describe('createFinalResponse', () => {
it('should create a complete response object', () => {
const response = StreamProcessor.createFinalResponse(
'Complete text',
'test-model',
'TestProvider',
[{ id: '1', function: { name: 'tool1' } }],
{ promptTokens: 10, completionTokens: 20 }
);
expect(response).toEqual({
text: 'Complete text',
model: 'test-model',
provider: 'TestProvider',
tool_calls: [{ id: '1', function: { name: 'tool1' } }],
usage: { promptTokens: 10, completionTokens: 20 }
});
});
it('should handle empty parameters', () => {
const response = StreamProcessor.createFinalResponse('', '', '', []);
expect(response).toEqual({
text: '',
model: '',
provider: '',
tool_calls: [],
usage: {}
});
});
});
});
describe('createStreamHandler', () => {
it('should create a working stream handler', async () => {
const mockProcessFn = vi.fn().mockImplementation(async (callback) => {
await callback({ text: 'chunk1', done: false });
await callback({ text: 'chunk2', done: true });
return 'complete';
});
const handler = createStreamHandler(
{ providerName: 'test', modelName: 'model' },
mockProcessFn
);
const chunks: StreamChunk[] = [];
const result = await handler(async (chunk) => {
chunks.push(chunk);
});
expect(result).toBe('complete');
expect(chunks).toHaveLength(3); // 2 from processFn + 1 final
expect(chunks[2]).toEqual({ text: '', done: true });
});
it('should handle errors in processor function', async () => {
const mockProcessFn = vi.fn().mockRejectedValue(new Error('Process error'));
const handler = createStreamHandler(
{ providerName: 'test', modelName: 'model' },
mockProcessFn
);
await expect(handler(vi.fn())).rejects.toThrow('Process error');
});
it('should ensure final chunk even on error after some chunks', async () => {
const chunks: StreamChunk[] = [];
const mockProcessFn = vi.fn().mockImplementation(async (callback) => {
await callback({ text: 'chunk1', done: false });
throw new Error('Mid-stream error');
});
const handler = createStreamHandler(
{ providerName: 'test', modelName: 'model' },
mockProcessFn
);
try {
await handler(async (chunk) => {
chunks.push(chunk);
});
} catch (e) {
// Expected error
}
// Should have received the chunk before error and final done chunk
expect(chunks.length).toBeGreaterThanOrEqual(2);
expect(chunks[chunks.length - 1]).toEqual({ text: '', done: true });
});
});
describe('processProviderStream', () => {
let mockStreamIterator: AsyncIterable<any>;
let mockCallback: ReturnType<typeof vi.fn>;
beforeEach(() => {
mockCallback = vi.fn();
});
it('should process a complete stream', async () => {
const chunks = [
{ message: { content: 'Hello ' } },
{ message: { content: 'World' } },
{ done: true }
];
mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' },
mockCallback
);
expect(result.completeText).toBe('Hello World');
expect(result.chunkCount).toBe(3);
expect(mockCallback).toHaveBeenCalledTimes(3);
});
it('should handle tool calls in stream', async () => {
const chunks = [
{ message: { content: 'Using tool...' } },
{
message: {
tool_calls: [
{ id: 'call_1', function: { name: 'calculator', arguments: '{"x": 5}' } }
]
}
},
{ done: true }
];
mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' }
);
expect(result.toolCalls).toHaveLength(1);
expect(result.toolCalls[0].function.name).toBe('calculator');
});
it('should handle empty stream', async () => {
mockStreamIterator = {
async *[Symbol.asyncIterator]() {
// Empty stream
}
};
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' },
mockCallback
);
expect(result.completeText).toBe('');
expect(result.chunkCount).toBe(0);
// Should still send final callback
expect(mockCallback).toHaveBeenCalledWith('', true, expect.any(Object));
});
it('should handle stream errors', async () => {
mockStreamIterator = {
async *[Symbol.asyncIterator]() {
yield { message: { content: 'Start' } };
throw new Error('Stream error');
}
};
await expect(processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' }
)).rejects.toThrow('Stream error');
});
it('should handle invalid stream iterator', async () => {
const invalidIterator = {} as any;
await expect(processProviderStream(
invalidIterator,
{ providerName: 'Test', modelName: 'test-model' }
)).rejects.toThrow('Invalid stream iterator');
});
it('should handle different chunk content formats', async () => {
const chunks = [
{ content: 'Direct content' },
{ choices: [{ delta: { content: 'OpenAI format' } }] },
{ message: { content: 'Standard format' } }
];
mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' },
mockCallback
);
expect(mockCallback).toHaveBeenCalledTimes(4); // 3 chunks + final
});
it('should send final callback when last chunk has no done flag', async () => {
const chunks = [
{ message: { content: 'Hello' } },
{ message: { content: 'World' } }
// No done flag
];
mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' },
mockCallback
);
// Should have explicit final callback
const lastCall = mockCallback.mock.calls[mockCallback.mock.calls.length - 1];
expect(lastCall[1]).toBe(true); // done flag
});
});
describe('extractStreamStats', () => {
it('should extract Ollama format stats', () => {
const chunk = {
prompt_eval_count: 10,
eval_count: 20
};
const stats = extractStreamStats(chunk, 'Ollama');
expect(stats).toEqual({
promptTokens: 10,
completionTokens: 20,
totalTokens: 30
});
});
it('should extract OpenAI format stats', () => {
const chunk = {
usage: {
prompt_tokens: 15,
completion_tokens: 25,
total_tokens: 40
}
};
const stats = extractStreamStats(chunk, 'OpenAI');
expect(stats).toEqual({
promptTokens: 15,
completionTokens: 25,
totalTokens: 40
});
});
it('should handle missing stats', () => {
const chunk = { message: { content: 'No stats here' } };
const stats = extractStreamStats(chunk, 'Test');
expect(stats).toEqual({
promptTokens: 0,
completionTokens: 0,
totalTokens: 0
});
});
it('should handle null chunk', () => {
const stats = extractStreamStats(null, 'Test');
expect(stats).toEqual({
promptTokens: 0,
completionTokens: 0,
totalTokens: 0
});
});
it('should handle partial Ollama stats', () => {
const chunk = {
prompt_eval_count: 10
// Missing eval_count
};
const stats = extractStreamStats(chunk, 'Ollama');
expect(stats).toEqual({
promptTokens: 10,
completionTokens: 0,
totalTokens: 10
});
});
});
describe('performProviderHealthCheck', () => {
it('should return true on successful health check', async () => {
const mockCheckFn = vi.fn().mockResolvedValue({ status: 'ok' });
const result = await performProviderHealthCheck(mockCheckFn, 'TestProvider');
expect(result).toBe(true);
expect(mockCheckFn).toHaveBeenCalled();
});
it('should throw error on failed health check', async () => {
const mockCheckFn = vi.fn().mockRejectedValue(new Error('Connection refused'));
await expect(performProviderHealthCheck(mockCheckFn, 'TestProvider'))
.rejects.toThrow('Unable to connect to TestProvider server: Connection refused');
});
it('should handle non-Error rejections', async () => {
const mockCheckFn = vi.fn().mockRejectedValue('String error');
await expect(performProviderHealthCheck(mockCheckFn, 'TestProvider'))
.rejects.toThrow('Unable to connect to TestProvider server: String error');
});
});
describe('Streaming edge cases and concurrency', () => {
it('should handle rapid chunk delivery', async () => {
const chunks = Array.from({ length: 100 }, (_, i) => ({
message: { content: `chunk${i}` }
}));
const mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
const receivedChunks: any[] = [];
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' },
async (text, done, chunk) => {
receivedChunks.push({ text, done });
// Simulate some processing delay
await new Promise(resolve => setTimeout(resolve, 1));
}
);
expect(result.chunkCount).toBe(100);
expect(result.completeText).toContain('chunk99');
});
it('should handle callback throwing errors', async () => {
const chunks = [
{ message: { content: 'chunk1' } },
{ message: { content: 'chunk2' } },
{ done: true }
];
const mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
let callCount = 0;
const errorCallback = vi.fn().mockImplementation(() => {
callCount++;
if (callCount === 2) {
throw new Error('Callback error');
}
});
// Should not throw, errors in callbacks are caught
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' },
errorCallback
);
expect(result.completeText).toBe('chunk1chunk2');
});
it('should handle mixed content and tool calls', async () => {
const chunks = [
{ message: { content: 'Let me calculate that...' } },
{
message: {
content: '',
tool_calls: [{ id: '1', function: { name: 'calc' } }]
}
},
{ message: { content: 'The answer is 42.' } },
{ done: true }
];
const mockStreamIterator = {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
}
};
const result = await processProviderStream(
mockStreamIterator,
{ providerName: 'Test', modelName: 'test-model' }
);
expect(result.completeText).toBe('Let me calculate that...The answer is 42.');
expect(result.toolCalls).toHaveLength(1);
});
});


@@ -383,12 +383,14 @@ export function extractStreamStats(finalChunk: any | null, providerName: string)
return { promptTokens: 0, completionTokens: 0, totalTokens: 0 };
}
// Ollama format
if (finalChunk.prompt_eval_count !== undefined && finalChunk.eval_count !== undefined) {
// Ollama format - handle partial stats where some fields might be missing
if (finalChunk.prompt_eval_count !== undefined || finalChunk.eval_count !== undefined) {
const promptTokens = finalChunk.prompt_eval_count || 0;
const completionTokens = finalChunk.eval_count || 0;
return {
promptTokens: finalChunk.prompt_eval_count || 0,
completionTokens: finalChunk.eval_count || 0,
totalTokens: (finalChunk.prompt_eval_count || 0) + (finalChunk.eval_count || 0)
promptTokens,
completionTokens,
totalTokens: promptTokens + completionTokens
};
}

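The relaxed condition above is what the new partial-stats test exercises: a final Ollama chunk that carries only prompt_eval_count now produces non-zero usage instead of falling through to zeros. A minimal illustration (the import path is an assumption; adjust it to wherever extractStreamStats actually lives):

// Assumed import - the real path to the stream handler module may differ
// import { extractStreamStats } from './stream_handler.js';

const stats = extractStreamStats({ prompt_eval_count: 10 }, 'Ollama');
// Old behaviour (&& check): { promptTokens: 0, completionTokens: 0, totalTokens: 0 }
// New behaviour (|| check): { promptTokens: 10, completionTokens: 0, totalTokens: 10 }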
View File

@ -0,0 +1,516 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { Server as WebSocketServer } from 'ws';
import type { WebSocket } from 'ws';
// Mock dependencies
vi.mock('./log.js', () => ({
default: {
info: vi.fn(),
error: vi.fn(),
warn: vi.fn()
}
}));
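// Bypass the sync mutex so exclusive sections run inline during tests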
vi.mock('./sync_mutex.js', () => ({
default: {
doExclusively: vi.fn().mockImplementation((fn) => fn())
}
}));
vi.mock('./sql.js', () => ({
getManyRows: vi.fn(),
getValue: vi.fn(),
getRow: vi.fn()
}));
vi.mock('../becca/becca.js', () => ({
default: {
getAttribute: vi.fn(),
getBranch: vi.fn(),
getNote: vi.fn(),
getOption: vi.fn()
}
}));
vi.mock('./protected_session.js', () => ({
default: {
decryptString: vi.fn((str) => str)
}
}));
vi.mock('./cls.js', () => ({
getAndClearEntityChangeIds: vi.fn().mockReturnValue([])
}));
// Mock WebSocket server
const mockWebSocketServer = {
clients: new Set<WebSocket>(),
on: vi.fn(),
close: vi.fn()
};
vi.mock('ws', () => ({
Server: vi.fn().mockImplementation(() => mockWebSocketServer),
WebSocket: {
OPEN: 1,
CLOSED: 3,
CONNECTING: 0,
CLOSING: 2
}
}));
describe('WebSocket Service', () => {
let wsService: any;
let mockWebSocket: Partial<WebSocket>;
let log: any;
beforeEach(async () => {
vi.clearAllMocks();
// Create mock WebSocket
mockWebSocket = {
readyState: 1, // WebSocket.OPEN
send: vi.fn(),
close: vi.fn(),
on: vi.fn(),
ping: vi.fn()
};
// Clear clients set
mockWebSocketServer.clients.clear();
mockWebSocketServer.clients.add(mockWebSocket as WebSocket);
// Get mocked log
log = (await import('./log.js')).default;
// Import service after mocks are set up
wsService = (await import('./ws.js')).default;
// Initialize the WebSocket server in the service
// This simulates calling the init function with a mock HTTP server and session parser
const mockHttpServer = {} as any;
const mockSessionParser = vi.fn((req, params, cb) => cb());
wsService.init(mockHttpServer, mockSessionParser);
});
afterEach(() => {
vi.clearAllMocks();
mockWebSocketServer.clients.clear();
});
describe('LLM Stream Message Broadcasting', () => {
it('should send basic LLM stream message to all clients', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-chat-123',
content: 'Hello world',
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).toHaveBeenCalledWith(
expect.stringContaining('Sending LLM stream message: chatNoteId=test-chat-123')
);
});
it('should send LLM stream message with thinking state', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-chat-456',
thinking: 'Processing your request...',
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(/thinking=true/)
);
});
it('should send LLM stream message with tool execution', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-chat-789',
toolExecution: {
tool: 'calculator',
args: { expression: '2+2' },
result: '4',
toolCallId: 'call_123'
},
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(/toolExecution=true/)
);
});
it('should send final LLM stream message with done flag', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-chat-final',
content: 'Final response',
done: true
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(/done=true/)
);
});
it('should handle error in LLM stream message', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-chat-error',
error: 'AI service not available',
done: true
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
});
it('should log client count for LLM stream messages', () => {
// Add multiple mock clients
const mockClient2 = { readyState: 1, send: vi.fn() };
const mockClient3 = { readyState: 1, send: vi.fn() };
mockWebSocketServer.clients.add(mockClient2 as WebSocket);
mockWebSocketServer.clients.add(mockClient3 as WebSocket);
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-multi-client',
content: 'Message to all',
done: false
};
wsService.sendMessageToAllClients(message);
expect(log.info).toHaveBeenCalledWith(
expect.stringContaining('Sent LLM stream message to 3 clients')
);
});
it('should handle closed WebSocket connections gracefully', () => {
// Set WebSocket to closed state
mockWebSocket.readyState = 3; // WebSocket.CLOSED
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-closed-connection',
content: 'This should not be sent',
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).not.toHaveBeenCalled();
expect(log.info).toHaveBeenCalledWith(
expect.stringContaining('Sent LLM stream message to 0 clients')
);
});
it('should handle mixed open and closed connections', () => {
// Add a closed connection
const closedSocket = { readyState: 3, send: vi.fn() };
mockWebSocketServer.clients.add(closedSocket as WebSocket);
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-mixed-connections',
content: 'Mixed connection test',
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(closedSocket.send).not.toHaveBeenCalled();
expect(log.info).toHaveBeenCalledWith(
expect.stringContaining('Sent LLM stream message to 1 clients')
);
});
});
describe('LLM Stream Message Content Verification', () => {
it('should handle empty content in stream message', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-empty-content',
content: '',
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(/content=false/)
);
});
it('should handle large content in stream message', () => {
const largeContent = 'x'.repeat(10000);
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-large-content',
content: largeContent,
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(/content=true/)
);
});
it('should handle unicode content in stream message', () => {
const unicodeContent = '你好 🌍 こんにちは مرحبا';
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-unicode-content',
content: unicodeContent,
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
const sentData = JSON.parse((mockWebSocket.send as any).mock.calls[0][0]);
expect(sentData.content).toBe(unicodeContent);
});
it('should handle complex tool execution data', () => {
const complexToolExecution = {
tool: 'data_analyzer',
args: {
dataset: {
rows: 1000,
columns: ['name', 'age', 'email'],
filters: { active: true }
},
operations: ['filter', 'group', 'aggregate']
},
result: {
summary: 'Analysis complete',
data: { filtered: 850, grouped: 10 }
},
toolCallId: 'call_complex_analysis'
};
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-complex-tool',
toolExecution: complexToolExecution,
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
const sentData = JSON.parse((mockWebSocket.send as any).mock.calls[0][0]);
expect(sentData.toolExecution).toEqual(complexToolExecution);
});
});
describe('Non-LLM Message Handling', () => {
it('should send regular messages without special LLM logging', () => {
const message = {
type: 'frontend-update' as const,
data: { test: 'data' }
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
expect(log.info).not.toHaveBeenCalledWith(
expect.stringContaining('LLM stream message')
);
});
it('should handle sync-failed messages quietly', () => {
const message = {
type: 'sync-failed' as const,
lastSyncedPush: 123
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
// sync-failed messages should not generate regular logs
});
it('should handle api-log-messages quietly', () => {
const message = {
type: 'api-log-messages' as const,
logs: ['log1', 'log2']
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalledWith(JSON.stringify(message));
// api-log-messages should not generate regular logs
});
});
describe('WebSocket Connection Management', () => {
it('should handle WebSocket send errors gracefully', () => {
// Mock send to throw an error
(mockWebSocket.send as any).mockImplementation(() => {
throw new Error('Connection closed');
});
const message = {
type: 'llm-stream' as const,
chatNoteId: 'test-send-error',
content: 'This will fail to send',
done: false
};
// Should not throw
expect(() => wsService.sendMessageToAllClients(message)).not.toThrow();
});
it('should handle multiple concurrent stream messages', async () => {
const promises = Array.from({ length: 10 }, (_, i) => {
const message = {
type: 'llm-stream' as const,
chatNoteId: `concurrent-test-${i}`,
content: `Message ${i}`,
done: false
};
return Promise.resolve(wsService.sendMessageToAllClients(message));
});
await Promise.all(promises);
expect(mockWebSocket.send).toHaveBeenCalledTimes(10);
});
it('should handle rapid message bursts', () => {
for (let i = 0; i < 100; i++) {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'burst-test',
content: `Burst ${i}`,
done: i === 99
};
wsService.sendMessageToAllClients(message);
}
expect(mockWebSocket.send).toHaveBeenCalledTimes(100);
});
});
describe('Message Serialization', () => {
it('should handle circular reference objects', () => {
const circularObj: any = { name: 'test' };
circularObj.self = circularObj;
const message = {
type: 'llm-stream' as const,
chatNoteId: 'circular-test',
toolExecution: {
tool: 'test',
args: circularObj,
result: 'success'
},
done: false
};
// Should handle serialization error gracefully
expect(() => wsService.sendMessageToAllClients(message)).not.toThrow();
});
it('should handle undefined and null values in messages', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'null-undefined-test',
content: undefined,
thinking: null,
toolExecution: undefined,
done: false
};
wsService.sendMessageToAllClients(message);
expect(mockWebSocket.send).toHaveBeenCalled();
const sentData = JSON.parse((mockWebSocket.send as any).mock.calls[0][0]);
expect(sentData.thinking).toBeNull();
expect(sentData.content).toBeUndefined();
});
it('should preserve message structure integrity', () => {
const originalMessage = {
type: 'llm-stream' as const,
chatNoteId: 'integrity-test',
content: 'Test content',
thinking: 'Test thinking',
toolExecution: {
tool: 'test_tool',
args: { param1: 'value1' },
result: 'success'
},
done: true
};
wsService.sendMessageToAllClients(originalMessage);
const sentData = JSON.parse((mockWebSocket.send as any).mock.calls[0][0]);
expect(sentData).toEqual(originalMessage);
});
});
describe('Logging Verification', () => {
it('should log message details correctly', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'log-verification-test',
content: 'Test content',
thinking: 'Test thinking',
toolExecution: { tool: 'test' },
done: true
};
wsService.sendMessageToAllClients(message);
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(
/chatNoteId=log-verification-test.*content=true.*thinking=true.*toolExecution=true.*done=true/
)
);
});
it('should log boolean flags correctly for empty values', () => {
const message = {
type: 'llm-stream' as const,
chatNoteId: 'empty-values-test',
content: '',
thinking: undefined,
toolExecution: null,
done: false
};
wsService.sendMessageToAllClients(message);
expect(log.info).toHaveBeenCalledWith(
expect.stringMatching(
/content=false.*thinking=false.*toolExecution=false.*done=false/
)
);
});
});
});
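Taken together, the WebSocket tests above assume a broadcast path roughly like the sketch below: serialize once, skip sockets that are not open, swallow per-client send failures, and emit the LLM-specific log lines the assertions match against. This is reconstructed from the test expectations, not the actual ws.ts internals; webSocketServer, WebSocket.OPEN and log stand in for whatever the service really uses.

// Sketch only - names are illustrative, derived from the assertions above
function sendMessageToAllClientsSketch(message: Record<string, any>): void {
    let jsonStr: string;
    try {
        jsonStr = JSON.stringify(message);
    } catch (err) {
        // Circular references and other serialization failures must not propagate
        return;
    }

    const isLlmStream = message.type === 'llm-stream';
    if (isLlmStream) {
        log.info(
            `Sending LLM stream message: chatNoteId=${message.chatNoteId}, ` +
            `content=${!!message.content}, thinking=${!!message.thinking}, ` +
            `toolExecution=${!!message.toolExecution}, done=${message.done}`
        );
    }

    let sentCount = 0;
    for (const client of webSocketServer.clients) {
        if (client.readyState !== WebSocket.OPEN) {
            continue; // closed or closing sockets are skipped entirely
        }
        try {
            client.send(jsonStr);
            sentCount++;
        } catch (err) {
            // A failing send on one client must not break the broadcast
        }
    }

    if (isLlmStream) {
        log.info(`Sent LLM stream message to ${sentCount} clients`);
    }
}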