From a36bba3b40036b13ea561ade116d931b3c784ea2 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 11 Jul 2025 09:51:51 -0600 Subject: [PATCH] Add '@buster/ai' dependency and enhance chat cancellation logic - Added '@buster/ai' as a workspace dependency in pnpm-lock.yaml and package.json. - Updated database-migrations.yml to trigger on changes in the database package. - Refined the cancelChatHandler function to include detailed message cleanup and trigger cancellation logic. - Improved response handling in the chat cancellation endpoint to return a success message. - Enhanced updateMessageFields to support marking messages as completed. --- .github/workflows/database-migrations.yml | 4 + apps/server/package.json | 2 + apps/server/src/api/v2/chats/cancel-chat.ts | 323 ++++++++++++++++++-- apps/server/src/api/v2/chats/index.ts | 4 +- packages/database/src/helpers/messages.ts | 6 + pnpm-lock.yaml | 6 + 6 files changed, 317 insertions(+), 28 deletions(-) diff --git a/.github/workflows/database-migrations.yml b/.github/workflows/database-migrations.yml index 0d0a126be..2d5a29858 100644 --- a/.github/workflows/database-migrations.yml +++ b/.github/workflows/database-migrations.yml @@ -3,6 +3,10 @@ name: Database Migrations on: push: branches: [main, staging] + paths: + - 'packages/database/drizzle/**' + - 'packages/database/drizzle.config.ts' + - '.github/workflows/database-migrations.yml' jobs: migrate: diff --git a/apps/server/package.json b/apps/server/package.json index e70300f2b..3a373c5c4 100644 --- a/apps/server/package.json +++ b/apps/server/package.json @@ -20,6 +20,7 @@ }, "dependencies": { "@buster/access-controls": "workspace:*", + "@buster/ai": "workspace:*", "@buster/database": "workspace:*", "@buster/server-shared": "workspace:*", "@buster/slack": "workspace:*", @@ -30,6 +31,7 @@ "@supabase/supabase-js": "catalog:", "@trigger.dev/sdk": "catalog:", "drizzle-orm": "catalog:", + "ai": "catalog:", "hono": "catalog:", "hono-pino": "^0.9.1", "pino": "^9.7.0", diff --git a/apps/server/src/api/v2/chats/cancel-chat.ts b/apps/server/src/api/v2/chats/cancel-chat.ts index 6b1a9f3a5..566942744 100644 --- a/apps/server/src/api/v2/chats/cancel-chat.ts +++ b/apps/server/src/api/v2/chats/cancel-chat.ts @@ -1,16 +1,45 @@ -import type { ChatWithMessages } from '@buster/server-shared/chats'; +import { canUserAccessChatCached } from '@buster/access-controls'; +import { + type ToolCallContent, + type ToolResultContent, + isToolCallContent, + isToolResultContent, +} from '@buster/ai/utils/database/types'; import type { User } from '@buster/database'; -import { eq, and, isNull, isNotNull } from '@buster/database'; +import { and, eq, isNotNull, updateMessageFields } from '@buster/database'; import { db, messages } from '@buster/database'; -import { HTTPException } from 'hono/http-exception'; +import type { + ChatMessageReasoningMessage, + ChatMessageResponseMessage, +} from '@buster/server-shared/chats'; +import { runs } from '@trigger.dev/sdk'; +import type { CoreMessage } from 'ai'; +import { errorResponse } from '../../../utils/response'; -export async function cancelChatHandler( - chatId: string, - user: User -): Promise { - // Query for messages with the given chat_id where is_completed: false and trigger_run_id is not null - const incompleteTriggerMessages = await db - .select() +/** + * Cancel a chat and clean up any incomplete messages + * + * Strategy: + * 1. Cancel the trigger runs + * 2. Fetch fresh data and clean up messages + * 3. Mark messages as completed with proper cleanup + */ +export async function cancelChatHandler(chatId: string, user: User): Promise { + const userHasAccessToChat = await canUserAccessChatCached({ + userId: user.id, + chatId, + }); + + if (!userHasAccessToChat) { + throw errorResponse('You do not have access to this chat', 403); + } + + // First, query just for IDs and trigger run IDs + const messagesToCancel = await db + .select({ + id: messages.id, + triggerRunId: messages.triggerRunId, + }) .from(messages) .where( and( @@ -20,21 +49,263 @@ export async function cancelChatHandler( ) ); - // TODO: Implement trigger cancellation logic here - // For each message with a trigger_run_id, cancel the corresponding trigger run - // Example (to be implemented): - // for (const message of incompleteTriggerMessages) { - // if (message.triggerRunId) { - // await cancelTriggerRun(message.triggerRunId); - // } - // } + // Type narrow to ensure triggerRunId is not null + const incompleteTriggerMessages = messagesToCancel.filter( + (result): result is { id: string; triggerRunId: string } => result.triggerRunId !== null + ); - // After cancellation, return the chat object with messages - // This should match the format returned by get chat and post chat endpoints - // TODO: Fetch the full chat object with all messages in the same format as get_chat_handler - // For now, this is a stub that needs to be implemented - - throw new HTTPException(501, { - message: 'Cancel chat endpoint not fully implemented - needs to return ChatWithMessages format', + // Cancel all trigger runs first + const cancellationPromises = incompleteTriggerMessages.map(async (message) => { + try { + await runs.cancel(message.triggerRunId); + console.info(`Cancelled trigger run ${message.triggerRunId} for message ${message.id}`); + } catch (error) { + console.error(`Failed to cancel trigger run ${message.triggerRunId}:`, error); + // Continue with cleanup even if cancellation fails + } }); -} \ No newline at end of file + + // Wait for all cancellations to complete + await Promise.allSettled(cancellationPromises); + + await new Promise((resolve) => setTimeout(resolve, 500)); + + // Now fetch the latest message data and clean up each message + const cleanupPromises = incompleteTriggerMessages.map(async (message) => { + // Fetch the latest message data + const [latestMessageData] = await db + .select({ + rawLlmMessages: messages.rawLlmMessages, + reasoning: messages.reasoning, + responseMessages: messages.responseMessages, + }) + .from(messages) + .where(eq(messages.id, message.id)); + + if (latestMessageData) { + await cleanUpMessage( + message.id, + latestMessageData.rawLlmMessages, + latestMessageData.reasoning, + latestMessageData.responseMessages + ); + } + }); + + // Wait for all cleanups to complete + await Promise.allSettled(cleanupPromises); +} +/** + * Find tool calls without corresponding tool results + */ +function findIncompleteToolCalls(messages: CoreMessage[]): ToolCallContent[] { + const toolCalls = new Map(); + const toolResults = new Set(); + + // First pass: collect all tool calls and tool results + for (const message of messages) { + if (message.role === 'assistant' && Array.isArray(message.content)) { + for (const content of message.content) { + if (isToolCallContent(content)) { + toolCalls.set(content.toolCallId, content); + } + } + } else if (message.role === 'tool' && Array.isArray(message.content)) { + for (const content of message.content) { + if (isToolResultContent(content)) { + toolResults.add(content.toolCallId); + } + } + } + } + + // Second pass: find tool calls without results + const incompleteToolCalls: ToolCallContent[] = []; + for (const [toolCallId, toolCall] of toolCalls) { + if (!toolResults.has(toolCallId)) { + incompleteToolCalls.push(toolCall); + } + } + + return incompleteToolCalls; +} + +/** + * Create tool result messages for incomplete tool calls + */ +function createCancellationToolResults(incompleteToolCalls: ToolCallContent[]): CoreMessage[] { + if (incompleteToolCalls.length === 0) { + return []; + } + + const toolResultMessages: CoreMessage[] = []; + + for (const toolCall of incompleteToolCalls) { + const toolResult: ToolResultContent = { + type: 'tool-result', + toolCallId: toolCall.toolCallId, + toolName: toolCall.toolName, + result: { + error: true, + message: 'The user ended the chat', + }, + }; + + toolResultMessages.push({ + role: 'tool', + content: [toolResult], + }); + } + + return toolResultMessages; +} + +/** + * Clean up messages by adding tool results for incomplete tool calls + */ +function cleanUpRawLlmMessages(messages: CoreMessage[]): CoreMessage[] { + const incompleteToolCalls = findIncompleteToolCalls(messages); + + if (incompleteToolCalls.length === 0) { + return messages; + } + + // Create tool result messages for incomplete tool calls + const toolResultMessages = createCancellationToolResults(incompleteToolCalls); + + // Append tool results to the messages + return [...messages, ...toolResultMessages]; +} + +/** + * Ensure reasoning messages are marked as completed + */ +function ensureReasoningMessagesCompleted( + reasoning: ChatMessageReasoningMessage[] +): ChatMessageReasoningMessage[] { + console.info('Ensuring reasoning messages are completed:', { + totalMessages: reasoning.length, + loadingMessages: reasoning.filter( + (msg) => msg && typeof msg === 'object' && 'status' in msg && msg.status === 'loading' + ).length, + }); + + return reasoning.map((msg, index) => { + if (msg && typeof msg === 'object' && 'status' in msg && msg.status === 'loading') { + console.info(`Marking reasoning message ${index} as completed:`, { + id: 'id' in msg ? msg.id : 'unknown', + title: 'title' in msg ? msg.title : 'unknown', + previousStatus: msg.status, + }); + return { + ...msg, + status: 'completed' as const, + }; + } + return msg; + }); +} + +/** + * Clean up and finalize all message fields for a cancelled chat + */ +interface CleanedMessageFields { + rawLlmMessages: CoreMessage[]; + reasoning: ChatMessageReasoningMessage[]; + responseMessages: ChatMessageResponseMessage[]; +} + +function cleanUpMessageFields( + rawLlmMessages: CoreMessage[], + reasoning: ChatMessageReasoningMessage[], + responseMessages: ChatMessageResponseMessage[] +): CleanedMessageFields { + // Clean up raw LLM messages by adding tool results for incomplete tool calls + const cleanedRawMessages = cleanUpRawLlmMessages(rawLlmMessages); + + // Ensure all reasoning messages are marked as completed + const completedReasoning = ensureReasoningMessagesCompleted(reasoning); + + return { + rawLlmMessages: cleanedRawMessages, + reasoning: completedReasoning, + responseMessages: responseMessages, + }; +} + +async function cleanUpMessage( + messageId: string, + rawLlmMessages: unknown, + reasoning: unknown, + responseMessages: unknown +): Promise { + try { + // Parse and validate the message fields + const currentRawMessages = Array.isArray(rawLlmMessages) + ? (rawLlmMessages as CoreMessage[]) + : []; + const currentReasoning = Array.isArray(reasoning) + ? (reasoning as ChatMessageReasoningMessage[]) + : []; + + // Handle responseMessages which could be an array or object + let currentResponseMessages: ChatMessageResponseMessage[] = []; + if (Array.isArray(responseMessages)) { + currentResponseMessages = responseMessages as ChatMessageResponseMessage[]; + } else if (responseMessages && typeof responseMessages === 'object') { + // Convert object to array if it has values + const values = Object.values(responseMessages); + if (values.length > 0 && values.every((v) => v && typeof v === 'object')) { + currentResponseMessages = values as ChatMessageResponseMessage[]; + } + } + + console.info(`Cleaning up message ${messageId}:`, { + rawMessagesCount: currentRawMessages.length, + reasoningCount: currentReasoning.length, + responseMessagesCount: currentResponseMessages.length, + responseMessagesType: Array.isArray(responseMessages) ? 'array' : typeof responseMessages, + }); + + // Clean up all message fields + const cleanedFields = cleanUpMessageFields( + currentRawMessages, + currentReasoning, + currentResponseMessages + ); + + // Determine the final reasoning message based on whether we were in response phase + const hasResponseMessages = currentResponseMessages.length > 0; + const finalReasoningMessage = hasResponseMessages + ? 'Stopped during final response' + : 'Stopped reasoning'; + + // Log the cleaned reasoning to debug + console.info('Cleaned reasoning before save:', { + reasoningCount: cleanedFields.reasoning.length, + loadingCount: cleanedFields.reasoning.filter( + (r) => r && typeof r === 'object' && 'status' in r && r.status === 'loading' + ).length, + lastReasoningMessage: cleanedFields.reasoning[cleanedFields.reasoning.length - 1], + }); + + // Ensure the reasoning array is properly serializable + const serializableReasoning = JSON.parse(JSON.stringify(cleanedFields.reasoning)); + + // Update the message in the database + await updateMessageFields(messageId, { + rawLlmMessages: cleanedFields.rawLlmMessages, + reasoning: serializableReasoning, + responseMessages: cleanedFields.responseMessages, + finalReasoningMessage: finalReasoningMessage, + isCompleted: true, + }); + + console.info( + `Successfully cleaned up message ${messageId} with finalReasoningMessage: ${finalReasoningMessage}` + ); + } catch (error) { + console.error(`Failed to clean up message ${messageId}:`, error); + // Don't throw - we want to continue processing other messages + } +} diff --git a/apps/server/src/api/v2/chats/index.ts b/apps/server/src/api/v2/chats/index.ts index a69b2b804..3a9b40ef2 100644 --- a/apps/server/src/api/v2/chats/index.ts +++ b/apps/server/src/api/v2/chats/index.ts @@ -55,8 +55,8 @@ const app = new Hono() const params = c.req.valid('param'); const user = c.get('busterUser'); - const response = await cancelChatHandler(params.chat_id, user); - return c.json(response); + await cancelChatHandler(params.chat_id, user); + return c.json({ success: true, message: 'Chat cancelled successfully' }); }) .onError((e, c) => { if (e instanceof ChatError) { diff --git a/packages/database/src/helpers/messages.ts b/packages/database/src/helpers/messages.ts index 4cf9b73cc..5cdc0e396 100644 --- a/packages/database/src/helpers/messages.ts +++ b/packages/database/src/helpers/messages.ts @@ -110,6 +110,7 @@ export async function updateMessageFields( reasoning?: unknown; rawLlmMessages?: unknown; finalReasoningMessage?: string; + isCompleted?: boolean; } ): Promise<{ success: boolean }> { try { @@ -125,6 +126,7 @@ export async function updateMessageFields( reasoning?: unknown; rawLlmMessages?: unknown; finalReasoningMessage?: string; + isCompleted?: boolean; } = { updatedAt: new Date().toISOString(), }; @@ -143,6 +145,10 @@ export async function updateMessageFields( updateData.finalReasoningMessage = fields.finalReasoningMessage; } + if ('isCompleted' in fields) { + updateData.isCompleted = fields.isCompleted; + } + await db .update(messages) .set(updateData) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 03438ac52..033b63d91 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -105,6 +105,9 @@ importers: '@buster/access-controls': specifier: workspace:* version: link:../../packages/access-controls + '@buster/ai': + specifier: workspace:* + version: link:../../packages/ai '@buster/database': specifier: workspace:* version: link:../../packages/database @@ -132,6 +135,9 @@ importers: '@trigger.dev/sdk': specifier: 'catalog:' version: 4.0.0-v4-beta.22(ai@4.3.16(react@18.3.1)(zod@3.25.75))(zod@3.25.75) + ai: + specifier: 'catalog:' + version: 4.3.16(react@18.3.1)(zod@3.25.75) drizzle-orm: specifier: 'catalog:' version: 0.44.2(@opentelemetry/api@1.9.0)(@types/pg@8.15.4)(mysql2@3.14.1)(pg@8.16.3)(postgres@3.4.7)