This commit is contained in:
dal 2025-08-20 11:26:37 -06:00
parent 550f8f2257
commit 2461dc0a77
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
5 changed files with 147 additions and 54 deletions

View File

@ -7,7 +7,6 @@ import { AnalystAgentTaskInputSchema, type AnalystAgentTaskOutput } from './type
import {
getBraintrustMetadata,
getChatConversationHistory,
getChatDashboardFiles,
getMessageContext,
getOrganizationDataSource,
} from '@buster/database';
@ -278,15 +277,11 @@ export const analystAgentTask: ReturnType<
messageId: payload.message_id,
});
// Start loading data source and dashboard files as soon as we have the required IDs
// Start loading data source as soon as we have the required IDs
const dataSourcePromise = messageContextPromise.then((context) =>
getOrganizationDataSource({ organizationId: context.organizationId })
);
const dashboardFilesPromise = messageContextPromise.then((context) =>
getChatDashboardFiles({ chatId: context.chatId })
);
// Fetch user's datasets as soon as we have the userId
const datasetsPromise = messageContextPromise.then(async (context) => {
try {
@ -316,14 +311,12 @@ export const analystAgentTask: ReturnType<
messageContext,
conversationHistory,
dataSource,
dashboardFiles,
datasets,
braintrustMetadata,
] = await Promise.all([
messageContextPromise,
conversationHistoryPromise,
dataSourcePromise,
dashboardFilesPromise,
datasetsPromise,
braintrustMetadataPromise,
]);
@ -338,14 +331,6 @@ export const analystAgentTask: ReturnType<
organizationId: messageContext.organizationId,
dataSourceId: dataSource.dataSourceId,
dataSourceSyntax: dataSource.dataSourceSyntax,
dashboardFilesCount: dashboardFiles.length,
dashboardFiles: dashboardFiles.map((d) => ({
id: d.id,
name: d.name,
versionNumber: d.versionNumber,
metricIdsCount: d.metricIds.length,
metricIds: d.metricIds,
})),
datasetsCount: datasets.length,
datasets: datasets.map((d) => ({
id: d.id,
@ -358,7 +343,7 @@ export const analystAgentTask: ReturnType<
// Log performance after data loading
logPerformanceMetrics('post-data-load', payload.message_id, taskStartTime, resourceTracker);
// Task 4: Prepare workflow input with conversation history and dashboard files
// Task 4: Prepare workflow input with conversation history
// Convert conversation history to messages format expected by the workflow
const messages =
conversationHistory.length > 0
@ -384,8 +369,6 @@ export const analystAgentTask: ReturnType<
logger.log('Workflow input prepared', {
messageId: payload.message_id,
messagesCount: workflowInput.messages.length,
hasDashboardFiles: dashboardFiles.length > 0,
dashboardFilesCount: dashboardFiles.length,
totalPrepTimeMs: Date.now() - dataLoadStart,
});

View File

@ -88,16 +88,16 @@ describe('ReAskStrategy', () => {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'correctTool',
input: JSON.stringify(correctedToolCall.input),
input: correctedToolCall.input,
});
// Verify the tool input is properly formatted as JSON in the messages
// Verify the tool input is properly formatted as an object in the messages
const calls = mockGenerateText.mock.calls[0];
const messages = calls?.[0]?.messages;
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
const content = assistantMessage?.content?.[0];
if (content && typeof content === 'object' && 'input' in content) {
expect(content.input).toBe('{"param":"value"}');
expect(content.input).toEqual({ param: 'value' });
}
expect(mockGenerateText).toHaveBeenCalledWith(
@ -303,7 +303,7 @@ describe('ReAskStrategy', () => {
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
const content = assistantMessage?.content?.[0];
if (content && typeof content === 'object' && 'input' in content) {
expect(content.input).toBe('{"value":"plain text input"}');
expect(content.input).toEqual({ value: 'plain text input' });
}
});
@ -335,13 +335,13 @@ describe('ReAskStrategy', () => {
await strategy.repair(context);
// Verify the valid JSON string was left as-is
// Verify the valid JSON string was parsed to an object
const calls = mockGenerateText.mock.calls[0];
const messages = calls?.[0]?.messages;
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
const content = assistantMessage?.content?.[0];
if (content && typeof content === 'object' && 'input' in content) {
expect(content.input).toBe('{"already":"valid"}');
expect(content.input).toEqual({ already: 'valid' });
}
});
});

View File

@ -17,22 +17,22 @@ export class ReAskStrategy implements RepairStrategy {
const errorMessage = this.buildErrorMessage(context);
// Create the tool-result message with the error
// Ensure input is properly formatted
let toolInput = context.toolCall.input;
if (typeof toolInput === 'string') {
// Ensure input is properly formatted as an object
let toolInput: unknown;
if (typeof context.toolCall.input === 'string') {
try {
// Try to parse it if it's a JSON string
JSON.parse(toolInput);
toolInput = JSON.parse(context.toolCall.input);
} catch {
// If it's not valid JSON, wrap it in an object
toolInput = JSON.stringify({ value: toolInput });
toolInput = { value: context.toolCall.input };
}
} else if (toolInput && typeof toolInput === 'object') {
// If it's already an object, stringify it
toolInput = JSON.stringify(toolInput);
} else if (context.toolCall.input && typeof context.toolCall.input === 'object') {
// If it's already an object, use it as-is
toolInput = context.toolCall.input;
} else {
// Default to empty object
toolInput = '{}';
toolInput = {};
}
const healingMessages: ModelMessage[] = [
@ -96,8 +96,8 @@ export class ReAskStrategy implements RepairStrategy {
toolCallType: 'function' as const,
toolCallId: context.toolCall.toolCallId,
toolName: newToolCall.toolName,
input: JSON.stringify(newToolCall.input),
} as unknown as LanguageModelV2ToolCall;
input: newToolCall.input,
} as LanguageModelV2ToolCall;
}
console.warn('Re-ask strategy did not produce a valid tool call', {

View File

@ -44,9 +44,9 @@ describe('StructuredOutputStrategy', () => {
const { generateObject } = await import('ai');
const mockGenerateObject = vi.mocked(generateObject);
const repairedArgs = { field1: 'value1', field2: 123 };
const repairedInput = { field1: 'value1', field2: 123 };
mockGenerateObject.mockResolvedValueOnce({
object: repairedArgs,
object: repairedInput,
warnings: [],
usage: {},
} as any);
@ -56,7 +56,7 @@ describe('StructuredOutputStrategy', () => {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
args: { field1: 'invalid', field2: 'not-a-number' },
input: { field1: 'invalid', field2: 'not-a-number' },
} as any,
tools: {
testTool: {
@ -84,7 +84,7 @@ describe('StructuredOutputStrategy', () => {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
args: repairedArgs,
input: repairedInput,
});
const tool = context.tools.testTool as any;
@ -101,7 +101,7 @@ describe('StructuredOutputStrategy', () => {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'nonExistentTool',
args: {},
input: {},
} as any,
tools: {} as any,
error: new InvalidToolInputError({
@ -123,7 +123,7 @@ describe('StructuredOutputStrategy', () => {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
args: {},
input: {},
} as any,
tools: {
testTool: {},
@ -152,7 +152,7 @@ describe('StructuredOutputStrategy', () => {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
args: {},
input: {},
} as any,
tools: {
testTool: {
@ -172,5 +172,107 @@ describe('StructuredOutputStrategy', () => {
'Failed to repair tool call "testTool": Generation failed'
);
});
it('should handle string input that is valid JSON', async () => {
const { generateObject } = await import('ai');
const mockGenerateObject = vi.mocked(generateObject);
const repairedInput = { field1: 'value1', field2: 123 };
mockGenerateObject.mockResolvedValueOnce({
object: repairedInput,
warnings: [],
usage: {},
} as any);
const context: RepairContext = {
toolCall: {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
input: '{"field1": "invalid", "field2": "not-a-number"}',
} as any,
tools: {
testTool: {
inputSchema: {
type: 'object',
properties: {
field1: { type: 'string' },
field2: { type: 'number' },
},
},
},
} as any,
error: new InvalidToolInputError({
toolName: 'testTool',
toolInput: 'invalid',
cause: new Error('validation failed'),
}),
messages: [],
system: '',
};
const result = await strategy.repair(context);
expect(result).toEqual({
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
input: repairedInput,
});
});
it('should handle string input that is not valid JSON', async () => {
const { generateObject } = await import('ai');
const mockGenerateObject = vi.mocked(generateObject);
const repairedInput = { value: 'parsed correctly' };
mockGenerateObject.mockResolvedValueOnce({
object: repairedInput,
warnings: [],
usage: {},
} as any);
const context: RepairContext = {
toolCall: {
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
input: 'plain text input',
} as any,
tools: {
testTool: {
inputSchema: {
type: 'object',
properties: {
value: { type: 'string' },
},
},
},
} as any,
error: new InvalidToolInputError({
toolName: 'testTool',
toolInput: 'invalid',
cause: new Error('validation failed'),
}),
messages: [],
system: '',
};
const result = await strategy.repair(context);
expect(result).toEqual({
toolCallType: 'function',
toolCallId: 'call123',
toolName: 'testTool',
input: repairedInput,
});
// Verify the prompt contains the wrapped input
expect(mockGenerateObject).toHaveBeenCalledWith(
expect.objectContaining({
prompt: expect.stringContaining('{"value":"plain text input"}'),
})
);
});
});
});

View File

@ -4,10 +4,6 @@ import { wrapTraced } from 'braintrust';
import { Sonnet4 } from '../../../llm';
import type { RepairContext, RepairStrategy } from '../types';
interface ToolCallWithArgs extends LanguageModelV2ToolCall {
args?: unknown;
}
export class StructuredOutputStrategy implements RepairStrategy {
canHandle(error: Error): boolean {
return error instanceof InvalidToolInputError;
@ -28,17 +24,29 @@ export class StructuredOutputStrategy implements RepairStrategy {
return null;
}
// Type assertion to access args property
const toolCallWithArgs = context.toolCall as ToolCallWithArgs;
// Parse input if it's a string, otherwise use as-is
const toolCallInput = context.toolCall.input;
let parsedInput: unknown;
if (typeof toolCallInput === 'string') {
try {
parsedInput = JSON.parse(toolCallInput);
} catch {
// If it's not valid JSON, wrap it in an object
parsedInput = { value: toolCallInput };
}
} else {
parsedInput = toolCallInput || {};
}
try {
const { object: repairedArgs } = await generateObject({
const { object: repairedInput } = await generateObject({
model: Sonnet4,
schema: tool.inputSchema,
prompt: [
`The model tried to call the tool "${context.toolCall.toolName}"`,
`with the following arguments:`,
JSON.stringify(toolCallWithArgs.args),
JSON.stringify(parsedInput),
`The tool accepts the following schema:`,
JSON.stringify(tool.inputSchema),
'Please fix the arguments.',
@ -47,11 +55,11 @@ export class StructuredOutputStrategy implements RepairStrategy {
console.info('Successfully repaired tool arguments', {
toolName: context.toolCall.toolName,
originalArgs: toolCallWithArgs.args,
repairedArgs,
originalInput: parsedInput,
repairedInput,
});
return { ...context.toolCall, args: repairedArgs } as LanguageModelV2ToolCall;
return { ...context.toolCall, input: repairedInput } as LanguageModelV2ToolCall;
} catch (error) {
console.error('Failed to repair tool arguments with structured output:', error);
console.error('Tool call that failed:', context.toolCall);