mirror of https://github.com/buster-so/buster.git
bugfixes
This commit is contained in:
parent
550f8f2257
commit
2461dc0a77
|
@ -7,7 +7,6 @@ import { AnalystAgentTaskInputSchema, type AnalystAgentTaskOutput } from './type
|
||||||
import {
|
import {
|
||||||
getBraintrustMetadata,
|
getBraintrustMetadata,
|
||||||
getChatConversationHistory,
|
getChatConversationHistory,
|
||||||
getChatDashboardFiles,
|
|
||||||
getMessageContext,
|
getMessageContext,
|
||||||
getOrganizationDataSource,
|
getOrganizationDataSource,
|
||||||
} from '@buster/database';
|
} from '@buster/database';
|
||||||
|
@ -278,15 +277,11 @@ export const analystAgentTask: ReturnType<
|
||||||
messageId: payload.message_id,
|
messageId: payload.message_id,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Start loading data source and dashboard files as soon as we have the required IDs
|
// Start loading data source as soon as we have the required IDs
|
||||||
const dataSourcePromise = messageContextPromise.then((context) =>
|
const dataSourcePromise = messageContextPromise.then((context) =>
|
||||||
getOrganizationDataSource({ organizationId: context.organizationId })
|
getOrganizationDataSource({ organizationId: context.organizationId })
|
||||||
);
|
);
|
||||||
|
|
||||||
const dashboardFilesPromise = messageContextPromise.then((context) =>
|
|
||||||
getChatDashboardFiles({ chatId: context.chatId })
|
|
||||||
);
|
|
||||||
|
|
||||||
// Fetch user's datasets as soon as we have the userId
|
// Fetch user's datasets as soon as we have the userId
|
||||||
const datasetsPromise = messageContextPromise.then(async (context) => {
|
const datasetsPromise = messageContextPromise.then(async (context) => {
|
||||||
try {
|
try {
|
||||||
|
@ -316,14 +311,12 @@ export const analystAgentTask: ReturnType<
|
||||||
messageContext,
|
messageContext,
|
||||||
conversationHistory,
|
conversationHistory,
|
||||||
dataSource,
|
dataSource,
|
||||||
dashboardFiles,
|
|
||||||
datasets,
|
datasets,
|
||||||
braintrustMetadata,
|
braintrustMetadata,
|
||||||
] = await Promise.all([
|
] = await Promise.all([
|
||||||
messageContextPromise,
|
messageContextPromise,
|
||||||
conversationHistoryPromise,
|
conversationHistoryPromise,
|
||||||
dataSourcePromise,
|
dataSourcePromise,
|
||||||
dashboardFilesPromise,
|
|
||||||
datasetsPromise,
|
datasetsPromise,
|
||||||
braintrustMetadataPromise,
|
braintrustMetadataPromise,
|
||||||
]);
|
]);
|
||||||
|
@ -338,14 +331,6 @@ export const analystAgentTask: ReturnType<
|
||||||
organizationId: messageContext.organizationId,
|
organizationId: messageContext.organizationId,
|
||||||
dataSourceId: dataSource.dataSourceId,
|
dataSourceId: dataSource.dataSourceId,
|
||||||
dataSourceSyntax: dataSource.dataSourceSyntax,
|
dataSourceSyntax: dataSource.dataSourceSyntax,
|
||||||
dashboardFilesCount: dashboardFiles.length,
|
|
||||||
dashboardFiles: dashboardFiles.map((d) => ({
|
|
||||||
id: d.id,
|
|
||||||
name: d.name,
|
|
||||||
versionNumber: d.versionNumber,
|
|
||||||
metricIdsCount: d.metricIds.length,
|
|
||||||
metricIds: d.metricIds,
|
|
||||||
})),
|
|
||||||
datasetsCount: datasets.length,
|
datasetsCount: datasets.length,
|
||||||
datasets: datasets.map((d) => ({
|
datasets: datasets.map((d) => ({
|
||||||
id: d.id,
|
id: d.id,
|
||||||
|
@ -358,7 +343,7 @@ export const analystAgentTask: ReturnType<
|
||||||
// Log performance after data loading
|
// Log performance after data loading
|
||||||
logPerformanceMetrics('post-data-load', payload.message_id, taskStartTime, resourceTracker);
|
logPerformanceMetrics('post-data-load', payload.message_id, taskStartTime, resourceTracker);
|
||||||
|
|
||||||
// Task 4: Prepare workflow input with conversation history and dashboard files
|
// Task 4: Prepare workflow input with conversation history
|
||||||
// Convert conversation history to messages format expected by the workflow
|
// Convert conversation history to messages format expected by the workflow
|
||||||
const messages =
|
const messages =
|
||||||
conversationHistory.length > 0
|
conversationHistory.length > 0
|
||||||
|
@ -384,8 +369,6 @@ export const analystAgentTask: ReturnType<
|
||||||
logger.log('Workflow input prepared', {
|
logger.log('Workflow input prepared', {
|
||||||
messageId: payload.message_id,
|
messageId: payload.message_id,
|
||||||
messagesCount: workflowInput.messages.length,
|
messagesCount: workflowInput.messages.length,
|
||||||
hasDashboardFiles: dashboardFiles.length > 0,
|
|
||||||
dashboardFilesCount: dashboardFiles.length,
|
|
||||||
totalPrepTimeMs: Date.now() - dataLoadStart,
|
totalPrepTimeMs: Date.now() - dataLoadStart,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -88,16 +88,16 @@ describe('ReAskStrategy', () => {
|
||||||
toolCallType: 'function',
|
toolCallType: 'function',
|
||||||
toolCallId: 'call123',
|
toolCallId: 'call123',
|
||||||
toolName: 'correctTool',
|
toolName: 'correctTool',
|
||||||
input: JSON.stringify(correctedToolCall.input),
|
input: correctedToolCall.input,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Verify the tool input is properly formatted as JSON in the messages
|
// Verify the tool input is properly formatted as an object in the messages
|
||||||
const calls = mockGenerateText.mock.calls[0];
|
const calls = mockGenerateText.mock.calls[0];
|
||||||
const messages = calls?.[0]?.messages;
|
const messages = calls?.[0]?.messages;
|
||||||
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
|
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
|
||||||
const content = assistantMessage?.content?.[0];
|
const content = assistantMessage?.content?.[0];
|
||||||
if (content && typeof content === 'object' && 'input' in content) {
|
if (content && typeof content === 'object' && 'input' in content) {
|
||||||
expect(content.input).toBe('{"param":"value"}');
|
expect(content.input).toEqual({ param: 'value' });
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(mockGenerateText).toHaveBeenCalledWith(
|
expect(mockGenerateText).toHaveBeenCalledWith(
|
||||||
|
@ -303,7 +303,7 @@ describe('ReAskStrategy', () => {
|
||||||
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
|
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
|
||||||
const content = assistantMessage?.content?.[0];
|
const content = assistantMessage?.content?.[0];
|
||||||
if (content && typeof content === 'object' && 'input' in content) {
|
if (content && typeof content === 'object' && 'input' in content) {
|
||||||
expect(content.input).toBe('{"value":"plain text input"}');
|
expect(content.input).toEqual({ value: 'plain text input' });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -335,13 +335,13 @@ describe('ReAskStrategy', () => {
|
||||||
|
|
||||||
await strategy.repair(context);
|
await strategy.repair(context);
|
||||||
|
|
||||||
// Verify the valid JSON string was left as-is
|
// Verify the valid JSON string was parsed to an object
|
||||||
const calls = mockGenerateText.mock.calls[0];
|
const calls = mockGenerateText.mock.calls[0];
|
||||||
const messages = calls?.[0]?.messages;
|
const messages = calls?.[0]?.messages;
|
||||||
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
|
const assistantMessage = messages?.find((m: any) => m.role === 'assistant');
|
||||||
const content = assistantMessage?.content?.[0];
|
const content = assistantMessage?.content?.[0];
|
||||||
if (content && typeof content === 'object' && 'input' in content) {
|
if (content && typeof content === 'object' && 'input' in content) {
|
||||||
expect(content.input).toBe('{"already":"valid"}');
|
expect(content.input).toEqual({ already: 'valid' });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -17,22 +17,22 @@ export class ReAskStrategy implements RepairStrategy {
|
||||||
const errorMessage = this.buildErrorMessage(context);
|
const errorMessage = this.buildErrorMessage(context);
|
||||||
|
|
||||||
// Create the tool-result message with the error
|
// Create the tool-result message with the error
|
||||||
// Ensure input is properly formatted
|
// Ensure input is properly formatted as an object
|
||||||
let toolInput = context.toolCall.input;
|
let toolInput: unknown;
|
||||||
if (typeof toolInput === 'string') {
|
if (typeof context.toolCall.input === 'string') {
|
||||||
try {
|
try {
|
||||||
// Try to parse it if it's a JSON string
|
// Try to parse it if it's a JSON string
|
||||||
JSON.parse(toolInput);
|
toolInput = JSON.parse(context.toolCall.input);
|
||||||
} catch {
|
} catch {
|
||||||
// If it's not valid JSON, wrap it in an object
|
// If it's not valid JSON, wrap it in an object
|
||||||
toolInput = JSON.stringify({ value: toolInput });
|
toolInput = { value: context.toolCall.input };
|
||||||
}
|
}
|
||||||
} else if (toolInput && typeof toolInput === 'object') {
|
} else if (context.toolCall.input && typeof context.toolCall.input === 'object') {
|
||||||
// If it's already an object, stringify it
|
// If it's already an object, use it as-is
|
||||||
toolInput = JSON.stringify(toolInput);
|
toolInput = context.toolCall.input;
|
||||||
} else {
|
} else {
|
||||||
// Default to empty object
|
// Default to empty object
|
||||||
toolInput = '{}';
|
toolInput = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
const healingMessages: ModelMessage[] = [
|
const healingMessages: ModelMessage[] = [
|
||||||
|
@ -96,8 +96,8 @@ export class ReAskStrategy implements RepairStrategy {
|
||||||
toolCallType: 'function' as const,
|
toolCallType: 'function' as const,
|
||||||
toolCallId: context.toolCall.toolCallId,
|
toolCallId: context.toolCall.toolCallId,
|
||||||
toolName: newToolCall.toolName,
|
toolName: newToolCall.toolName,
|
||||||
input: JSON.stringify(newToolCall.input),
|
input: newToolCall.input,
|
||||||
} as unknown as LanguageModelV2ToolCall;
|
} as LanguageModelV2ToolCall;
|
||||||
}
|
}
|
||||||
|
|
||||||
console.warn('Re-ask strategy did not produce a valid tool call', {
|
console.warn('Re-ask strategy did not produce a valid tool call', {
|
||||||
|
|
|
@ -44,9 +44,9 @@ describe('StructuredOutputStrategy', () => {
|
||||||
const { generateObject } = await import('ai');
|
const { generateObject } = await import('ai');
|
||||||
const mockGenerateObject = vi.mocked(generateObject);
|
const mockGenerateObject = vi.mocked(generateObject);
|
||||||
|
|
||||||
const repairedArgs = { field1: 'value1', field2: 123 };
|
const repairedInput = { field1: 'value1', field2: 123 };
|
||||||
mockGenerateObject.mockResolvedValueOnce({
|
mockGenerateObject.mockResolvedValueOnce({
|
||||||
object: repairedArgs,
|
object: repairedInput,
|
||||||
warnings: [],
|
warnings: [],
|
||||||
usage: {},
|
usage: {},
|
||||||
} as any);
|
} as any);
|
||||||
|
@ -56,7 +56,7 @@ describe('StructuredOutputStrategy', () => {
|
||||||
toolCallType: 'function',
|
toolCallType: 'function',
|
||||||
toolCallId: 'call123',
|
toolCallId: 'call123',
|
||||||
toolName: 'testTool',
|
toolName: 'testTool',
|
||||||
args: { field1: 'invalid', field2: 'not-a-number' },
|
input: { field1: 'invalid', field2: 'not-a-number' },
|
||||||
} as any,
|
} as any,
|
||||||
tools: {
|
tools: {
|
||||||
testTool: {
|
testTool: {
|
||||||
|
@ -84,7 +84,7 @@ describe('StructuredOutputStrategy', () => {
|
||||||
toolCallType: 'function',
|
toolCallType: 'function',
|
||||||
toolCallId: 'call123',
|
toolCallId: 'call123',
|
||||||
toolName: 'testTool',
|
toolName: 'testTool',
|
||||||
args: repairedArgs,
|
input: repairedInput,
|
||||||
});
|
});
|
||||||
|
|
||||||
const tool = context.tools.testTool as any;
|
const tool = context.tools.testTool as any;
|
||||||
|
@ -101,7 +101,7 @@ describe('StructuredOutputStrategy', () => {
|
||||||
toolCallType: 'function',
|
toolCallType: 'function',
|
||||||
toolCallId: 'call123',
|
toolCallId: 'call123',
|
||||||
toolName: 'nonExistentTool',
|
toolName: 'nonExistentTool',
|
||||||
args: {},
|
input: {},
|
||||||
} as any,
|
} as any,
|
||||||
tools: {} as any,
|
tools: {} as any,
|
||||||
error: new InvalidToolInputError({
|
error: new InvalidToolInputError({
|
||||||
|
@ -123,7 +123,7 @@ describe('StructuredOutputStrategy', () => {
|
||||||
toolCallType: 'function',
|
toolCallType: 'function',
|
||||||
toolCallId: 'call123',
|
toolCallId: 'call123',
|
||||||
toolName: 'testTool',
|
toolName: 'testTool',
|
||||||
args: {},
|
input: {},
|
||||||
} as any,
|
} as any,
|
||||||
tools: {
|
tools: {
|
||||||
testTool: {},
|
testTool: {},
|
||||||
|
@ -152,7 +152,7 @@ describe('StructuredOutputStrategy', () => {
|
||||||
toolCallType: 'function',
|
toolCallType: 'function',
|
||||||
toolCallId: 'call123',
|
toolCallId: 'call123',
|
||||||
toolName: 'testTool',
|
toolName: 'testTool',
|
||||||
args: {},
|
input: {},
|
||||||
} as any,
|
} as any,
|
||||||
tools: {
|
tools: {
|
||||||
testTool: {
|
testTool: {
|
||||||
|
@ -172,5 +172,107 @@ describe('StructuredOutputStrategy', () => {
|
||||||
'Failed to repair tool call "testTool": Generation failed'
|
'Failed to repair tool call "testTool": Generation failed'
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle string input that is valid JSON', async () => {
|
||||||
|
const { generateObject } = await import('ai');
|
||||||
|
const mockGenerateObject = vi.mocked(generateObject);
|
||||||
|
|
||||||
|
const repairedInput = { field1: 'value1', field2: 123 };
|
||||||
|
mockGenerateObject.mockResolvedValueOnce({
|
||||||
|
object: repairedInput,
|
||||||
|
warnings: [],
|
||||||
|
usage: {},
|
||||||
|
} as any);
|
||||||
|
|
||||||
|
const context: RepairContext = {
|
||||||
|
toolCall: {
|
||||||
|
toolCallType: 'function',
|
||||||
|
toolCallId: 'call123',
|
||||||
|
toolName: 'testTool',
|
||||||
|
input: '{"field1": "invalid", "field2": "not-a-number"}',
|
||||||
|
} as any,
|
||||||
|
tools: {
|
||||||
|
testTool: {
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
field1: { type: 'string' },
|
||||||
|
field2: { type: 'number' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} as any,
|
||||||
|
error: new InvalidToolInputError({
|
||||||
|
toolName: 'testTool',
|
||||||
|
toolInput: 'invalid',
|
||||||
|
cause: new Error('validation failed'),
|
||||||
|
}),
|
||||||
|
messages: [],
|
||||||
|
system: '',
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await strategy.repair(context);
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
toolCallType: 'function',
|
||||||
|
toolCallId: 'call123',
|
||||||
|
toolName: 'testTool',
|
||||||
|
input: repairedInput,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle string input that is not valid JSON', async () => {
|
||||||
|
const { generateObject } = await import('ai');
|
||||||
|
const mockGenerateObject = vi.mocked(generateObject);
|
||||||
|
|
||||||
|
const repairedInput = { value: 'parsed correctly' };
|
||||||
|
mockGenerateObject.mockResolvedValueOnce({
|
||||||
|
object: repairedInput,
|
||||||
|
warnings: [],
|
||||||
|
usage: {},
|
||||||
|
} as any);
|
||||||
|
|
||||||
|
const context: RepairContext = {
|
||||||
|
toolCall: {
|
||||||
|
toolCallType: 'function',
|
||||||
|
toolCallId: 'call123',
|
||||||
|
toolName: 'testTool',
|
||||||
|
input: 'plain text input',
|
||||||
|
} as any,
|
||||||
|
tools: {
|
||||||
|
testTool: {
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
value: { type: 'string' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} as any,
|
||||||
|
error: new InvalidToolInputError({
|
||||||
|
toolName: 'testTool',
|
||||||
|
toolInput: 'invalid',
|
||||||
|
cause: new Error('validation failed'),
|
||||||
|
}),
|
||||||
|
messages: [],
|
||||||
|
system: '',
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await strategy.repair(context);
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
toolCallType: 'function',
|
||||||
|
toolCallId: 'call123',
|
||||||
|
toolName: 'testTool',
|
||||||
|
input: repairedInput,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify the prompt contains the wrapped input
|
||||||
|
expect(mockGenerateObject).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
prompt: expect.stringContaining('{"value":"plain text input"}'),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -4,10 +4,6 @@ import { wrapTraced } from 'braintrust';
|
||||||
import { Sonnet4 } from '../../../llm';
|
import { Sonnet4 } from '../../../llm';
|
||||||
import type { RepairContext, RepairStrategy } from '../types';
|
import type { RepairContext, RepairStrategy } from '../types';
|
||||||
|
|
||||||
interface ToolCallWithArgs extends LanguageModelV2ToolCall {
|
|
||||||
args?: unknown;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class StructuredOutputStrategy implements RepairStrategy {
|
export class StructuredOutputStrategy implements RepairStrategy {
|
||||||
canHandle(error: Error): boolean {
|
canHandle(error: Error): boolean {
|
||||||
return error instanceof InvalidToolInputError;
|
return error instanceof InvalidToolInputError;
|
||||||
|
@ -28,17 +24,29 @@ export class StructuredOutputStrategy implements RepairStrategy {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type assertion to access args property
|
// Parse input if it's a string, otherwise use as-is
|
||||||
const toolCallWithArgs = context.toolCall as ToolCallWithArgs;
|
const toolCallInput = context.toolCall.input;
|
||||||
|
let parsedInput: unknown;
|
||||||
|
|
||||||
|
if (typeof toolCallInput === 'string') {
|
||||||
|
try {
|
||||||
|
parsedInput = JSON.parse(toolCallInput);
|
||||||
|
} catch {
|
||||||
|
// If it's not valid JSON, wrap it in an object
|
||||||
|
parsedInput = { value: toolCallInput };
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
parsedInput = toolCallInput || {};
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const { object: repairedArgs } = await generateObject({
|
const { object: repairedInput } = await generateObject({
|
||||||
model: Sonnet4,
|
model: Sonnet4,
|
||||||
schema: tool.inputSchema,
|
schema: tool.inputSchema,
|
||||||
prompt: [
|
prompt: [
|
||||||
`The model tried to call the tool "${context.toolCall.toolName}"`,
|
`The model tried to call the tool "${context.toolCall.toolName}"`,
|
||||||
`with the following arguments:`,
|
`with the following arguments:`,
|
||||||
JSON.stringify(toolCallWithArgs.args),
|
JSON.stringify(parsedInput),
|
||||||
`The tool accepts the following schema:`,
|
`The tool accepts the following schema:`,
|
||||||
JSON.stringify(tool.inputSchema),
|
JSON.stringify(tool.inputSchema),
|
||||||
'Please fix the arguments.',
|
'Please fix the arguments.',
|
||||||
|
@ -47,11 +55,11 @@ export class StructuredOutputStrategy implements RepairStrategy {
|
||||||
|
|
||||||
console.info('Successfully repaired tool arguments', {
|
console.info('Successfully repaired tool arguments', {
|
||||||
toolName: context.toolCall.toolName,
|
toolName: context.toolCall.toolName,
|
||||||
originalArgs: toolCallWithArgs.args,
|
originalInput: parsedInput,
|
||||||
repairedArgs,
|
repairedInput,
|
||||||
});
|
});
|
||||||
|
|
||||||
return { ...context.toolCall, args: repairedArgs } as LanguageModelV2ToolCall;
|
return { ...context.toolCall, input: repairedInput } as LanguageModelV2ToolCall;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to repair tool arguments with structured output:', error);
|
console.error('Failed to repair tool arguments with structured output:', error);
|
||||||
console.error('Tool call that failed:', context.toolCall);
|
console.error('Tool call that failed:', context.toolCall);
|
||||||
|
|
Loading…
Reference in New Issue