From 2461dc0a772ce4e7161b76d461c9aeb89668d487 Mon Sep 17 00:00:00 2001 From: dal Date: Wed, 20 Aug 2025 11:26:37 -0600 Subject: [PATCH] bugfixes --- .../analyst-agent-task/analyst-agent-task.ts | 21 +--- .../strategies/re-ask-strategy.test.ts | 12 +- .../strategies/re-ask-strategy.ts | 22 ++-- .../structured-output-strategy.test.ts | 116 ++++++++++++++++-- .../strategies/structured-output-strategy.ts | 30 +++-- 5 files changed, 147 insertions(+), 54 deletions(-) diff --git a/apps/trigger/src/tasks/analyst-agent-task/analyst-agent-task.ts b/apps/trigger/src/tasks/analyst-agent-task/analyst-agent-task.ts index f0fa44dc5..ac738bfa8 100644 --- a/apps/trigger/src/tasks/analyst-agent-task/analyst-agent-task.ts +++ b/apps/trigger/src/tasks/analyst-agent-task/analyst-agent-task.ts @@ -7,7 +7,6 @@ import { AnalystAgentTaskInputSchema, type AnalystAgentTaskOutput } from './type import { getBraintrustMetadata, getChatConversationHistory, - getChatDashboardFiles, getMessageContext, getOrganizationDataSource, } from '@buster/database'; @@ -278,15 +277,11 @@ export const analystAgentTask: ReturnType< messageId: payload.message_id, }); - // Start loading data source and dashboard files as soon as we have the required IDs + // Start loading data source as soon as we have the required IDs const dataSourcePromise = messageContextPromise.then((context) => getOrganizationDataSource({ organizationId: context.organizationId }) ); - const dashboardFilesPromise = messageContextPromise.then((context) => - getChatDashboardFiles({ chatId: context.chatId }) - ); - // Fetch user's datasets as soon as we have the userId const datasetsPromise = messageContextPromise.then(async (context) => { try { @@ -316,14 +311,12 @@ export const analystAgentTask: ReturnType< messageContext, conversationHistory, dataSource, - dashboardFiles, datasets, braintrustMetadata, ] = await Promise.all([ messageContextPromise, conversationHistoryPromise, dataSourcePromise, - dashboardFilesPromise, datasetsPromise, braintrustMetadataPromise, ]); @@ -338,14 +331,6 @@ export const analystAgentTask: ReturnType< organizationId: messageContext.organizationId, dataSourceId: dataSource.dataSourceId, dataSourceSyntax: dataSource.dataSourceSyntax, - dashboardFilesCount: dashboardFiles.length, - dashboardFiles: dashboardFiles.map((d) => ({ - id: d.id, - name: d.name, - versionNumber: d.versionNumber, - metricIdsCount: d.metricIds.length, - metricIds: d.metricIds, - })), datasetsCount: datasets.length, datasets: datasets.map((d) => ({ id: d.id, @@ -358,7 +343,7 @@ export const analystAgentTask: ReturnType< // Log performance after data loading logPerformanceMetrics('post-data-load', payload.message_id, taskStartTime, resourceTracker); - // Task 4: Prepare workflow input with conversation history and dashboard files + // Task 4: Prepare workflow input with conversation history // Convert conversation history to messages format expected by the workflow const messages = conversationHistory.length > 0 @@ -384,8 +369,6 @@ export const analystAgentTask: ReturnType< logger.log('Workflow input prepared', { messageId: payload.message_id, messagesCount: workflowInput.messages.length, - hasDashboardFiles: dashboardFiles.length > 0, - dashboardFilesCount: dashboardFiles.length, totalPrepTimeMs: Date.now() - dataLoadStart, }); diff --git a/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts b/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts index 2fb34b882..cab217b1c 100644 --- a/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts +++ b/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts @@ -88,16 +88,16 @@ describe('ReAskStrategy', () => { toolCallType: 'function', toolCallId: 'call123', toolName: 'correctTool', - input: JSON.stringify(correctedToolCall.input), + input: correctedToolCall.input, }); - // Verify the tool input is properly formatted as JSON in the messages + // Verify the tool input is properly formatted as an object in the messages const calls = mockGenerateText.mock.calls[0]; const messages = calls?.[0]?.messages; const assistantMessage = messages?.find((m: any) => m.role === 'assistant'); const content = assistantMessage?.content?.[0]; if (content && typeof content === 'object' && 'input' in content) { - expect(content.input).toBe('{"param":"value"}'); + expect(content.input).toEqual({ param: 'value' }); } expect(mockGenerateText).toHaveBeenCalledWith( @@ -303,7 +303,7 @@ describe('ReAskStrategy', () => { const assistantMessage = messages?.find((m: any) => m.role === 'assistant'); const content = assistantMessage?.content?.[0]; if (content && typeof content === 'object' && 'input' in content) { - expect(content.input).toBe('{"value":"plain text input"}'); + expect(content.input).toEqual({ value: 'plain text input' }); } }); @@ -335,13 +335,13 @@ describe('ReAskStrategy', () => { await strategy.repair(context); - // Verify the valid JSON string was left as-is + // Verify the valid JSON string was parsed to an object const calls = mockGenerateText.mock.calls[0]; const messages = calls?.[0]?.messages; const assistantMessage = messages?.find((m: any) => m.role === 'assistant'); const content = assistantMessage?.content?.[0]; if (content && typeof content === 'object' && 'input' in content) { - expect(content.input).toBe('{"already":"valid"}'); + expect(content.input).toEqual({ already: 'valid' }); } }); }); diff --git a/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.ts b/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.ts index ce70da1f4..750fcfe28 100644 --- a/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.ts +++ b/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.ts @@ -17,22 +17,22 @@ export class ReAskStrategy implements RepairStrategy { const errorMessage = this.buildErrorMessage(context); // Create the tool-result message with the error - // Ensure input is properly formatted - let toolInput = context.toolCall.input; - if (typeof toolInput === 'string') { + // Ensure input is properly formatted as an object + let toolInput: unknown; + if (typeof context.toolCall.input === 'string') { try { // Try to parse it if it's a JSON string - JSON.parse(toolInput); + toolInput = JSON.parse(context.toolCall.input); } catch { // If it's not valid JSON, wrap it in an object - toolInput = JSON.stringify({ value: toolInput }); + toolInput = { value: context.toolCall.input }; } - } else if (toolInput && typeof toolInput === 'object') { - // If it's already an object, stringify it - toolInput = JSON.stringify(toolInput); + } else if (context.toolCall.input && typeof context.toolCall.input === 'object') { + // If it's already an object, use it as-is + toolInput = context.toolCall.input; } else { // Default to empty object - toolInput = '{}'; + toolInput = {}; } const healingMessages: ModelMessage[] = [ @@ -96,8 +96,8 @@ export class ReAskStrategy implements RepairStrategy { toolCallType: 'function' as const, toolCallId: context.toolCall.toolCallId, toolName: newToolCall.toolName, - input: JSON.stringify(newToolCall.input), - } as unknown as LanguageModelV2ToolCall; + input: newToolCall.input, + } as LanguageModelV2ToolCall; } console.warn('Re-ask strategy did not produce a valid tool call', { diff --git a/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts b/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts index 41d248555..6a430edf2 100644 --- a/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts +++ b/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts @@ -44,9 +44,9 @@ describe('StructuredOutputStrategy', () => { const { generateObject } = await import('ai'); const mockGenerateObject = vi.mocked(generateObject); - const repairedArgs = { field1: 'value1', field2: 123 }; + const repairedInput = { field1: 'value1', field2: 123 }; mockGenerateObject.mockResolvedValueOnce({ - object: repairedArgs, + object: repairedInput, warnings: [], usage: {}, } as any); @@ -56,7 +56,7 @@ describe('StructuredOutputStrategy', () => { toolCallType: 'function', toolCallId: 'call123', toolName: 'testTool', - args: { field1: 'invalid', field2: 'not-a-number' }, + input: { field1: 'invalid', field2: 'not-a-number' }, } as any, tools: { testTool: { @@ -84,7 +84,7 @@ describe('StructuredOutputStrategy', () => { toolCallType: 'function', toolCallId: 'call123', toolName: 'testTool', - args: repairedArgs, + input: repairedInput, }); const tool = context.tools.testTool as any; @@ -101,7 +101,7 @@ describe('StructuredOutputStrategy', () => { toolCallType: 'function', toolCallId: 'call123', toolName: 'nonExistentTool', - args: {}, + input: {}, } as any, tools: {} as any, error: new InvalidToolInputError({ @@ -123,7 +123,7 @@ describe('StructuredOutputStrategy', () => { toolCallType: 'function', toolCallId: 'call123', toolName: 'testTool', - args: {}, + input: {}, } as any, tools: { testTool: {}, @@ -152,7 +152,7 @@ describe('StructuredOutputStrategy', () => { toolCallType: 'function', toolCallId: 'call123', toolName: 'testTool', - args: {}, + input: {}, } as any, tools: { testTool: { @@ -172,5 +172,107 @@ describe('StructuredOutputStrategy', () => { 'Failed to repair tool call "testTool": Generation failed' ); }); + + it('should handle string input that is valid JSON', async () => { + const { generateObject } = await import('ai'); + const mockGenerateObject = vi.mocked(generateObject); + + const repairedInput = { field1: 'value1', field2: 123 }; + mockGenerateObject.mockResolvedValueOnce({ + object: repairedInput, + warnings: [], + usage: {}, + } as any); + + const context: RepairContext = { + toolCall: { + toolCallType: 'function', + toolCallId: 'call123', + toolName: 'testTool', + input: '{"field1": "invalid", "field2": "not-a-number"}', + } as any, + tools: { + testTool: { + inputSchema: { + type: 'object', + properties: { + field1: { type: 'string' }, + field2: { type: 'number' }, + }, + }, + }, + } as any, + error: new InvalidToolInputError({ + toolName: 'testTool', + toolInput: 'invalid', + cause: new Error('validation failed'), + }), + messages: [], + system: '', + }; + + const result = await strategy.repair(context); + + expect(result).toEqual({ + toolCallType: 'function', + toolCallId: 'call123', + toolName: 'testTool', + input: repairedInput, + }); + }); + + it('should handle string input that is not valid JSON', async () => { + const { generateObject } = await import('ai'); + const mockGenerateObject = vi.mocked(generateObject); + + const repairedInput = { value: 'parsed correctly' }; + mockGenerateObject.mockResolvedValueOnce({ + object: repairedInput, + warnings: [], + usage: {}, + } as any); + + const context: RepairContext = { + toolCall: { + toolCallType: 'function', + toolCallId: 'call123', + toolName: 'testTool', + input: 'plain text input', + } as any, + tools: { + testTool: { + inputSchema: { + type: 'object', + properties: { + value: { type: 'string' }, + }, + }, + }, + } as any, + error: new InvalidToolInputError({ + toolName: 'testTool', + toolInput: 'invalid', + cause: new Error('validation failed'), + }), + messages: [], + system: '', + }; + + const result = await strategy.repair(context); + + expect(result).toEqual({ + toolCallType: 'function', + toolCallId: 'call123', + toolName: 'testTool', + input: repairedInput, + }); + + // Verify the prompt contains the wrapped input + expect(mockGenerateObject).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: expect.stringContaining('{"value":"plain text input"}'), + }) + ); + }); }); }); diff --git a/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.ts b/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.ts index 4b4f6649e..8a4b1cb76 100644 --- a/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.ts +++ b/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.ts @@ -4,10 +4,6 @@ import { wrapTraced } from 'braintrust'; import { Sonnet4 } from '../../../llm'; import type { RepairContext, RepairStrategy } from '../types'; -interface ToolCallWithArgs extends LanguageModelV2ToolCall { - args?: unknown; -} - export class StructuredOutputStrategy implements RepairStrategy { canHandle(error: Error): boolean { return error instanceof InvalidToolInputError; @@ -28,17 +24,29 @@ export class StructuredOutputStrategy implements RepairStrategy { return null; } - // Type assertion to access args property - const toolCallWithArgs = context.toolCall as ToolCallWithArgs; + // Parse input if it's a string, otherwise use as-is + const toolCallInput = context.toolCall.input; + let parsedInput: unknown; + + if (typeof toolCallInput === 'string') { + try { + parsedInput = JSON.parse(toolCallInput); + } catch { + // If it's not valid JSON, wrap it in an object + parsedInput = { value: toolCallInput }; + } + } else { + parsedInput = toolCallInput || {}; + } try { - const { object: repairedArgs } = await generateObject({ + const { object: repairedInput } = await generateObject({ model: Sonnet4, schema: tool.inputSchema, prompt: [ `The model tried to call the tool "${context.toolCall.toolName}"`, `with the following arguments:`, - JSON.stringify(toolCallWithArgs.args), + JSON.stringify(parsedInput), `The tool accepts the following schema:`, JSON.stringify(tool.inputSchema), 'Please fix the arguments.', @@ -47,11 +55,11 @@ export class StructuredOutputStrategy implements RepairStrategy { console.info('Successfully repaired tool arguments', { toolName: context.toolCall.toolName, - originalArgs: toolCallWithArgs.args, - repairedArgs, + originalInput: parsedInput, + repairedInput, }); - return { ...context.toolCall, args: repairedArgs } as LanguageModelV2ToolCall; + return { ...context.toolCall, input: repairedInput } as LanguageModelV2ToolCall; } catch (error) { console.error('Failed to repair tool arguments with structured output:', error); console.error('Tool call that failed:', context.toolCall);