Add '@buster/ai' dependency and enhance chat cancellation logic

- Added '@buster/ai' as a workspace dependency in pnpm-lock.yaml and package.json.
- Updated database-migrations.yml to trigger on changes in the database package.
- Refined the cancelChatHandler function to include detailed message cleanup and trigger cancellation logic.
- Improved response handling in the chat cancellation endpoint to return a success message.
- Enhanced updateMessageFields to support marking messages as completed.
This commit is contained in:
dal 2025-07-11 09:51:51 -06:00
parent 87f3853ce8
commit a36bba3b40
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
6 changed files with 317 additions and 28 deletions

View File

@ -3,6 +3,10 @@ name: Database Migrations
on:
push:
branches: [main, staging]
paths:
- 'packages/database/drizzle/**'
- 'packages/database/drizzle.config.ts'
- '.github/workflows/database-migrations.yml'
jobs:
migrate:

View File

@ -20,6 +20,7 @@
},
"dependencies": {
"@buster/access-controls": "workspace:*",
"@buster/ai": "workspace:*",
"@buster/database": "workspace:*",
"@buster/server-shared": "workspace:*",
"@buster/slack": "workspace:*",
@ -30,6 +31,7 @@
"@supabase/supabase-js": "catalog:",
"@trigger.dev/sdk": "catalog:",
"drizzle-orm": "catalog:",
"ai": "catalog:",
"hono": "catalog:",
"hono-pino": "^0.9.1",
"pino": "^9.7.0",

View File

@ -1,16 +1,45 @@
import type { ChatWithMessages } from '@buster/server-shared/chats';
import { canUserAccessChatCached } from '@buster/access-controls';
import {
type ToolCallContent,
type ToolResultContent,
isToolCallContent,
isToolResultContent,
} from '@buster/ai/utils/database/types';
import type { User } from '@buster/database';
import { eq, and, isNull, isNotNull } from '@buster/database';
import { and, eq, isNotNull, updateMessageFields } from '@buster/database';
import { db, messages } from '@buster/database';
import { HTTPException } from 'hono/http-exception';
import type {
ChatMessageReasoningMessage,
ChatMessageResponseMessage,
} from '@buster/server-shared/chats';
import { runs } from '@trigger.dev/sdk';
import type { CoreMessage } from 'ai';
import { errorResponse } from '../../../utils/response';
export async function cancelChatHandler(
chatId: string,
user: User
): Promise<ChatWithMessages> {
// Query for messages with the given chat_id where is_completed: false and trigger_run_id is not null
const incompleteTriggerMessages = await db
.select()
/**
* Cancel a chat and clean up any incomplete messages
*
* Strategy:
* 1. Cancel the trigger runs
* 2. Fetch fresh data and clean up messages
* 3. Mark messages as completed with proper cleanup
*/
export async function cancelChatHandler(chatId: string, user: User): Promise<void> {
const userHasAccessToChat = await canUserAccessChatCached({
userId: user.id,
chatId,
});
if (!userHasAccessToChat) {
throw errorResponse('You do not have access to this chat', 403);
}
// First, query just for IDs and trigger run IDs
const messagesToCancel = await db
.select({
id: messages.id,
triggerRunId: messages.triggerRunId,
})
.from(messages)
.where(
and(
@ -20,21 +49,263 @@ export async function cancelChatHandler(
)
);
// TODO: Implement trigger cancellation logic here
// For each message with a trigger_run_id, cancel the corresponding trigger run
// Example (to be implemented):
// for (const message of incompleteTriggerMessages) {
// if (message.triggerRunId) {
// await cancelTriggerRun(message.triggerRunId);
// }
// }
// Type narrow to ensure triggerRunId is not null
const incompleteTriggerMessages = messagesToCancel.filter(
(result): result is { id: string; triggerRunId: string } => result.triggerRunId !== null
);
// After cancellation, return the chat object with messages
// This should match the format returned by get chat and post chat endpoints
// TODO: Fetch the full chat object with all messages in the same format as get_chat_handler
// For now, this is a stub that needs to be implemented
throw new HTTPException(501, {
message: 'Cancel chat endpoint not fully implemented - needs to return ChatWithMessages format',
// Cancel all trigger runs first
const cancellationPromises = incompleteTriggerMessages.map(async (message) => {
try {
await runs.cancel(message.triggerRunId);
console.info(`Cancelled trigger run ${message.triggerRunId} for message ${message.id}`);
} catch (error) {
console.error(`Failed to cancel trigger run ${message.triggerRunId}:`, error);
// Continue with cleanup even if cancellation fails
}
});
}
// Wait for all cancellations to complete
await Promise.allSettled(cancellationPromises);
await new Promise((resolve) => setTimeout(resolve, 500));
// Now fetch the latest message data and clean up each message
const cleanupPromises = incompleteTriggerMessages.map(async (message) => {
// Fetch the latest message data
const [latestMessageData] = await db
.select({
rawLlmMessages: messages.rawLlmMessages,
reasoning: messages.reasoning,
responseMessages: messages.responseMessages,
})
.from(messages)
.where(eq(messages.id, message.id));
if (latestMessageData) {
await cleanUpMessage(
message.id,
latestMessageData.rawLlmMessages,
latestMessageData.reasoning,
latestMessageData.responseMessages
);
}
});
// Wait for all cleanups to complete
await Promise.allSettled(cleanupPromises);
}
/**
* Find tool calls without corresponding tool results
*/
function findIncompleteToolCalls(messages: CoreMessage[]): ToolCallContent[] {
const toolCalls = new Map<string, ToolCallContent>();
const toolResults = new Set<string>();
// First pass: collect all tool calls and tool results
for (const message of messages) {
if (message.role === 'assistant' && Array.isArray(message.content)) {
for (const content of message.content) {
if (isToolCallContent(content)) {
toolCalls.set(content.toolCallId, content);
}
}
} else if (message.role === 'tool' && Array.isArray(message.content)) {
for (const content of message.content) {
if (isToolResultContent(content)) {
toolResults.add(content.toolCallId);
}
}
}
}
// Second pass: find tool calls without results
const incompleteToolCalls: ToolCallContent[] = [];
for (const [toolCallId, toolCall] of toolCalls) {
if (!toolResults.has(toolCallId)) {
incompleteToolCalls.push(toolCall);
}
}
return incompleteToolCalls;
}
/**
* Create tool result messages for incomplete tool calls
*/
function createCancellationToolResults(incompleteToolCalls: ToolCallContent[]): CoreMessage[] {
if (incompleteToolCalls.length === 0) {
return [];
}
const toolResultMessages: CoreMessage[] = [];
for (const toolCall of incompleteToolCalls) {
const toolResult: ToolResultContent = {
type: 'tool-result',
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
result: {
error: true,
message: 'The user ended the chat',
},
};
toolResultMessages.push({
role: 'tool',
content: [toolResult],
});
}
return toolResultMessages;
}
/**
* Clean up messages by adding tool results for incomplete tool calls
*/
function cleanUpRawLlmMessages(messages: CoreMessage[]): CoreMessage[] {
const incompleteToolCalls = findIncompleteToolCalls(messages);
if (incompleteToolCalls.length === 0) {
return messages;
}
// Create tool result messages for incomplete tool calls
const toolResultMessages = createCancellationToolResults(incompleteToolCalls);
// Append tool results to the messages
return [...messages, ...toolResultMessages];
}
/**
* Ensure reasoning messages are marked as completed
*/
function ensureReasoningMessagesCompleted(
reasoning: ChatMessageReasoningMessage[]
): ChatMessageReasoningMessage[] {
console.info('Ensuring reasoning messages are completed:', {
totalMessages: reasoning.length,
loadingMessages: reasoning.filter(
(msg) => msg && typeof msg === 'object' && 'status' in msg && msg.status === 'loading'
).length,
});
return reasoning.map((msg, index) => {
if (msg && typeof msg === 'object' && 'status' in msg && msg.status === 'loading') {
console.info(`Marking reasoning message ${index} as completed:`, {
id: 'id' in msg ? msg.id : 'unknown',
title: 'title' in msg ? msg.title : 'unknown',
previousStatus: msg.status,
});
return {
...msg,
status: 'completed' as const,
};
}
return msg;
});
}
/**
* Clean up and finalize all message fields for a cancelled chat
*/
interface CleanedMessageFields {
rawLlmMessages: CoreMessage[];
reasoning: ChatMessageReasoningMessage[];
responseMessages: ChatMessageResponseMessage[];
}
function cleanUpMessageFields(
rawLlmMessages: CoreMessage[],
reasoning: ChatMessageReasoningMessage[],
responseMessages: ChatMessageResponseMessage[]
): CleanedMessageFields {
// Clean up raw LLM messages by adding tool results for incomplete tool calls
const cleanedRawMessages = cleanUpRawLlmMessages(rawLlmMessages);
// Ensure all reasoning messages are marked as completed
const completedReasoning = ensureReasoningMessagesCompleted(reasoning);
return {
rawLlmMessages: cleanedRawMessages,
reasoning: completedReasoning,
responseMessages: responseMessages,
};
}
async function cleanUpMessage(
messageId: string,
rawLlmMessages: unknown,
reasoning: unknown,
responseMessages: unknown
): Promise<void> {
try {
// Parse and validate the message fields
const currentRawMessages = Array.isArray(rawLlmMessages)
? (rawLlmMessages as CoreMessage[])
: [];
const currentReasoning = Array.isArray(reasoning)
? (reasoning as ChatMessageReasoningMessage[])
: [];
// Handle responseMessages which could be an array or object
let currentResponseMessages: ChatMessageResponseMessage[] = [];
if (Array.isArray(responseMessages)) {
currentResponseMessages = responseMessages as ChatMessageResponseMessage[];
} else if (responseMessages && typeof responseMessages === 'object') {
// Convert object to array if it has values
const values = Object.values(responseMessages);
if (values.length > 0 && values.every((v) => v && typeof v === 'object')) {
currentResponseMessages = values as ChatMessageResponseMessage[];
}
}
console.info(`Cleaning up message ${messageId}:`, {
rawMessagesCount: currentRawMessages.length,
reasoningCount: currentReasoning.length,
responseMessagesCount: currentResponseMessages.length,
responseMessagesType: Array.isArray(responseMessages) ? 'array' : typeof responseMessages,
});
// Clean up all message fields
const cleanedFields = cleanUpMessageFields(
currentRawMessages,
currentReasoning,
currentResponseMessages
);
// Determine the final reasoning message based on whether we were in response phase
const hasResponseMessages = currentResponseMessages.length > 0;
const finalReasoningMessage = hasResponseMessages
? 'Stopped during final response'
: 'Stopped reasoning';
// Log the cleaned reasoning to debug
console.info('Cleaned reasoning before save:', {
reasoningCount: cleanedFields.reasoning.length,
loadingCount: cleanedFields.reasoning.filter(
(r) => r && typeof r === 'object' && 'status' in r && r.status === 'loading'
).length,
lastReasoningMessage: cleanedFields.reasoning[cleanedFields.reasoning.length - 1],
});
// Ensure the reasoning array is properly serializable
const serializableReasoning = JSON.parse(JSON.stringify(cleanedFields.reasoning));
// Update the message in the database
await updateMessageFields(messageId, {
rawLlmMessages: cleanedFields.rawLlmMessages,
reasoning: serializableReasoning,
responseMessages: cleanedFields.responseMessages,
finalReasoningMessage: finalReasoningMessage,
isCompleted: true,
});
console.info(
`Successfully cleaned up message ${messageId} with finalReasoningMessage: ${finalReasoningMessage}`
);
} catch (error) {
console.error(`Failed to clean up message ${messageId}:`, error);
// Don't throw - we want to continue processing other messages
}
}

View File

@ -55,8 +55,8 @@ const app = new Hono()
const params = c.req.valid('param');
const user = c.get('busterUser');
const response = await cancelChatHandler(params.chat_id, user);
return c.json(response);
await cancelChatHandler(params.chat_id, user);
return c.json({ success: true, message: 'Chat cancelled successfully' });
})
.onError((e, c) => {
if (e instanceof ChatError) {

View File

@ -110,6 +110,7 @@ export async function updateMessageFields(
reasoning?: unknown;
rawLlmMessages?: unknown;
finalReasoningMessage?: string;
isCompleted?: boolean;
}
): Promise<{ success: boolean }> {
try {
@ -125,6 +126,7 @@ export async function updateMessageFields(
reasoning?: unknown;
rawLlmMessages?: unknown;
finalReasoningMessage?: string;
isCompleted?: boolean;
} = {
updatedAt: new Date().toISOString(),
};
@ -143,6 +145,10 @@ export async function updateMessageFields(
updateData.finalReasoningMessage = fields.finalReasoningMessage;
}
if ('isCompleted' in fields) {
updateData.isCompleted = fields.isCompleted;
}
await db
.update(messages)
.set(updateData)

View File

@ -105,6 +105,9 @@ importers:
'@buster/access-controls':
specifier: workspace:*
version: link:../../packages/access-controls
'@buster/ai':
specifier: workspace:*
version: link:../../packages/ai
'@buster/database':
specifier: workspace:*
version: link:../../packages/database
@ -132,6 +135,9 @@ importers:
'@trigger.dev/sdk':
specifier: 'catalog:'
version: 4.0.0-v4-beta.22(ai@4.3.16(react@18.3.1)(zod@3.25.75))(zod@3.25.75)
ai:
specifier: 'catalog:'
version: 4.3.16(react@18.3.1)(zod@3.25.75)
drizzle-orm:
specifier: 'catalog:'
version: 0.44.2(@opentelemetry/api@1.9.0)(@types/pg@8.15.4)(mysql2@3.14.1)(pg@8.16.3)(postgres@3.4.7)