mirror of https://github.com/buster-so/buster.git
fix on tool use ids and repair calls
This commit is contained in:
parent
39eaac2fdb
commit
4338cfd5d3
|
@ -0,0 +1,337 @@
|
||||||
|
import type { ModelMessage } from 'ai';
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||||
|
import { getChatConversationHistory } from './chatConversationHistory';
|
||||||
|
|
||||||
|
// Mock the database connection and queries
|
||||||
|
vi.mock('../../connection', () => ({
|
||||||
|
db: {
|
||||||
|
select: vi.fn().mockReturnThis(),
|
||||||
|
from: vi.fn().mockReturnThis(),
|
||||||
|
where: vi.fn().mockReturnThis(),
|
||||||
|
limit: vi.fn().mockReturnThis(),
|
||||||
|
orderBy: vi.fn(),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe('getChatConversationHistory - Orphaned Tool Call Cleanup', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should remove orphaned tool calls (tool calls without matching results)', async () => {
|
||||||
|
// Mock database to return messages with orphaned tool calls
|
||||||
|
const mockMessages: ModelMessage[] = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'test question',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'orphaned-call-123',
|
||||||
|
toolName: 'sequentialThinking',
|
||||||
|
input: { thought: 'test thought', nextThoughtNeeded: false },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
// No tool result for orphaned-call-123
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'valid-call-456',
|
||||||
|
toolName: 'executeSql',
|
||||||
|
input: { statements: ['SELECT 1'] },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'tool',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-result',
|
||||||
|
toolCallId: 'valid-call-456',
|
||||||
|
toolName: 'executeSql',
|
||||||
|
output: { type: 'json', value: '{"results":[]}' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// Mock the database query functions
|
||||||
|
const { db } = await import('../../connection');
|
||||||
|
vi.mocked(db.orderBy).mockResolvedValue([
|
||||||
|
{
|
||||||
|
id: 'msg-1',
|
||||||
|
rawLlmMessages: mockMessages,
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
isCompleted: true,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
vi.mocked(db.limit).mockResolvedValue([
|
||||||
|
{
|
||||||
|
chatId: 'chat-123',
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await getChatConversationHistory({
|
||||||
|
messageId: 'test-message-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should have removed the orphaned tool call but kept the valid one
|
||||||
|
expect(result).toHaveLength(3); // user, assistant (with valid tool call), tool result
|
||||||
|
|
||||||
|
// Find the assistant message
|
||||||
|
const assistantMessages = result.filter((m) => m.role === 'assistant');
|
||||||
|
expect(assistantMessages).toHaveLength(1);
|
||||||
|
|
||||||
|
// The remaining assistant message should only have the valid tool call
|
||||||
|
const assistantContent = assistantMessages[0]?.content;
|
||||||
|
expect(Array.isArray(assistantContent)).toBe(true);
|
||||||
|
if (Array.isArray(assistantContent)) {
|
||||||
|
expect(assistantContent).toHaveLength(1);
|
||||||
|
expect(assistantContent[0]).toMatchObject({
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'valid-call-456',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should keep assistant messages with valid tool calls', async () => {
|
||||||
|
const mockMessages: ModelMessage[] = [
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'call-123',
|
||||||
|
toolName: 'testTool',
|
||||||
|
input: {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'tool',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-result',
|
||||||
|
toolCallId: 'call-123',
|
||||||
|
toolName: 'testTool',
|
||||||
|
output: { type: 'text', value: 'result' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const { db } = await import('../../connection');
|
||||||
|
vi.mocked(db.orderBy).mockResolvedValue([
|
||||||
|
{
|
||||||
|
id: 'msg-1',
|
||||||
|
rawLlmMessages: mockMessages,
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
isCompleted: true,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
vi.mocked(db.limit).mockResolvedValue([
|
||||||
|
{
|
||||||
|
chatId: 'chat-123',
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await getChatConversationHistory({
|
||||||
|
messageId: 'test-message-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should keep both messages (tool call and result)
|
||||||
|
expect(result).toHaveLength(2);
|
||||||
|
expect(result[0]?.role).toBe('assistant');
|
||||||
|
expect(result[1]?.role).toBe('tool');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should keep assistant messages that have at least one valid tool call (even if some are orphaned)', async () => {
|
||||||
|
const mockMessages: ModelMessage[] = [
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: 'Let me analyze this',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'orphaned-123',
|
||||||
|
toolName: 'orphanedTool',
|
||||||
|
input: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'valid-456',
|
||||||
|
toolName: 'validTool',
|
||||||
|
input: {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'tool',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-result',
|
||||||
|
toolCallId: 'valid-456',
|
||||||
|
toolName: 'validTool',
|
||||||
|
output: { type: 'text', value: 'success' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const { db } = await import('../../connection');
|
||||||
|
vi.mocked(db.orderBy).mockResolvedValue([
|
||||||
|
{
|
||||||
|
id: 'msg-1',
|
||||||
|
rawLlmMessages: mockMessages,
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
isCompleted: true,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
vi.mocked(db.limit).mockResolvedValue([
|
||||||
|
{
|
||||||
|
chatId: 'chat-123',
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await getChatConversationHistory({
|
||||||
|
messageId: 'test-message-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should keep the assistant message because it has at least one valid tool call
|
||||||
|
// Note: We keep the entire message including the orphaned tool call
|
||||||
|
expect(result).toHaveLength(2);
|
||||||
|
|
||||||
|
const assistantMessage = result.find((m) => m.role === 'assistant');
|
||||||
|
expect(assistantMessage).toBeDefined();
|
||||||
|
|
||||||
|
const content = assistantMessage?.content;
|
||||||
|
expect(Array.isArray(content)).toBe(true);
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
// Should have all content including the orphaned tool call (we don't modify the message)
|
||||||
|
expect(content).toHaveLength(3);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should remove assistant messages that only contain orphaned tool calls', async () => {
|
||||||
|
const mockMessages: ModelMessage[] = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'test',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: 'orphaned-only',
|
||||||
|
toolName: 'orphanedTool',
|
||||||
|
input: {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
// No tool result for orphaned-only
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'This is a text response',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const { db } = await import('../../connection');
|
||||||
|
vi.mocked(db.orderBy).mockResolvedValue([
|
||||||
|
{
|
||||||
|
id: 'msg-1',
|
||||||
|
rawLlmMessages: mockMessages,
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
isCompleted: true,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
vi.mocked(db.limit).mockResolvedValue([
|
||||||
|
{
|
||||||
|
chatId: 'chat-123',
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await getChatConversationHistory({
|
||||||
|
messageId: 'test-message-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should have removed the assistant message with only orphaned tool call
|
||||||
|
expect(result).toHaveLength(2); // user + assistant text response
|
||||||
|
expect(result[0]?.role).toBe('user');
|
||||||
|
expect(result[1]?.role).toBe('assistant');
|
||||||
|
expect(result[1]?.content).toBe('This is a text response');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle empty message arrays', async () => {
|
||||||
|
const { db } = await import('../../connection');
|
||||||
|
vi.mocked(db.orderBy).mockResolvedValue([
|
||||||
|
{
|
||||||
|
id: 'msg-1',
|
||||||
|
rawLlmMessages: [],
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
isCompleted: true,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
vi.mocked(db.limit).mockResolvedValue([
|
||||||
|
{
|
||||||
|
chatId: 'chat-123',
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await getChatConversationHistory({
|
||||||
|
messageId: 'test-message-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle messages with no tool calls', async () => {
|
||||||
|
const mockMessages: ModelMessage[] = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: 'Hello',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'Hi there!',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const { db } = await import('../../connection');
|
||||||
|
vi.mocked(db.orderBy).mockResolvedValue([
|
||||||
|
{
|
||||||
|
id: 'msg-1',
|
||||||
|
rawLlmMessages: mockMessages,
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
isCompleted: true,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
vi.mocked(db.limit).mockResolvedValue([
|
||||||
|
{
|
||||||
|
chatId: 'chat-123',
|
||||||
|
createdAt: '2025-01-01T00:00:00Z',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await getChatConversationHistory({
|
||||||
|
messageId: 'test-message-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toHaveLength(2);
|
||||||
|
expect(result).toEqual(mockMessages);
|
||||||
|
});
|
||||||
|
});
|
|
@ -324,6 +324,80 @@ export const ChatConversationHistoryOutputSchema = z.array(z.custom<ModelMessage
|
||||||
export type ChatConversationHistoryInput = z.infer<typeof ChatConversationHistoryInputSchema>;
|
export type ChatConversationHistoryInput = z.infer<typeof ChatConversationHistoryInputSchema>;
|
||||||
export type ChatConversationHistoryOutput = z.infer<typeof ChatConversationHistoryOutputSchema>;
|
export type ChatConversationHistoryOutput = z.infer<typeof ChatConversationHistoryOutputSchema>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes orphaned assistant messages with tool calls that have no matching tool results
|
||||||
|
* An orphaned message is an assistant message where ALL its tool calls lack corresponding tool results
|
||||||
|
*/
|
||||||
|
function removeOrphanedToolCalls(messages: ModelMessage[]): ModelMessage[] {
|
||||||
|
// Build a Set of all tool call IDs that have results (from tool role messages)
|
||||||
|
const toolCallIdsWithResults = new Set<string>();
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
if (message.role === 'tool' && Array.isArray(message.content)) {
|
||||||
|
for (const part of message.content) {
|
||||||
|
if (
|
||||||
|
typeof part === 'object' &&
|
||||||
|
part !== null &&
|
||||||
|
'type' in part &&
|
||||||
|
part.type === 'tool-result' &&
|
||||||
|
'toolCallId' in part &&
|
||||||
|
typeof part.toolCallId === 'string'
|
||||||
|
) {
|
||||||
|
toolCallIdsWithResults.add(part.toolCallId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out assistant messages where ALL tool calls are orphaned
|
||||||
|
const filteredMessages: ModelMessage[] = [];
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
// Only check assistant messages with array content
|
||||||
|
if (message.role === 'assistant' && Array.isArray(message.content)) {
|
||||||
|
// Extract tool call IDs from this message
|
||||||
|
const toolCallIds: string[] = [];
|
||||||
|
|
||||||
|
for (const part of message.content) {
|
||||||
|
if (
|
||||||
|
typeof part === 'object' &&
|
||||||
|
part !== null &&
|
||||||
|
'type' in part &&
|
||||||
|
part.type === 'tool-call' &&
|
||||||
|
'toolCallId' in part &&
|
||||||
|
typeof part.toolCallId === 'string'
|
||||||
|
) {
|
||||||
|
toolCallIds.push(part.toolCallId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this message has no tool calls, keep it
|
||||||
|
if (toolCallIds.length === 0) {
|
||||||
|
filteredMessages.push(message);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if ANY of the tool calls have results
|
||||||
|
const hasAnyResults = toolCallIds.some((id) => toolCallIdsWithResults.has(id));
|
||||||
|
|
||||||
|
if (hasAnyResults) {
|
||||||
|
// At least one tool call has a result, keep the message
|
||||||
|
filteredMessages.push(message);
|
||||||
|
} else {
|
||||||
|
// ALL tool calls are orphaned, skip this message entirely
|
||||||
|
console.warn('[chatConversationHistory] Removing orphaned assistant message:', {
|
||||||
|
orphanedToolCallIds: toolCallIds,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Keep all non-assistant messages (including tool role messages)
|
||||||
|
filteredMessages.push(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredMessages;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get conversation history for a chat up to and including a specific message
|
* Get conversation history for a chat up to and including a specific message
|
||||||
* Finds the chat from the given messageId, then merges and deduplicates all rawLlmMessages
|
* Finds the chat from the given messageId, then merges and deduplicates all rawLlmMessages
|
||||||
|
@ -389,9 +463,12 @@ export async function getChatConversationHistory(
|
||||||
// Since we're merging from multiple messages, we should preserve the order they appear
|
// Since we're merging from multiple messages, we should preserve the order they appear
|
||||||
const deduplicatedMessages = Array.from(uniqueMessagesMap.values());
|
const deduplicatedMessages = Array.from(uniqueMessagesMap.values());
|
||||||
|
|
||||||
|
// Remove orphaned tool calls (tool calls without matching tool results)
|
||||||
|
const cleanedMessages = removeOrphanedToolCalls(deduplicatedMessages);
|
||||||
|
|
||||||
// Validate output
|
// Validate output
|
||||||
try {
|
try {
|
||||||
return ChatConversationHistoryOutputSchema.parse(deduplicatedMessages);
|
return ChatConversationHistoryOutputSchema.parse(cleanedMessages);
|
||||||
} catch (validationError) {
|
} catch (validationError) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Output validation failed: ${validationError instanceof Error ? validationError.message : 'Invalid output format'}`
|
`Output validation failed: ${validationError instanceof Error ? validationError.message : 'Invalid output format'}`
|
||||||
|
|
Loading…
Reference in New Issue