Integrate dataset handling into Analyst and Think and Prep agents

- Added support for fetching and managing user-specific datasets in the `analystAgentTask`, enhancing the context available for processing.
- Updated the `AnalystAgent` and `ThinkAndPrepAgent` schemas to include datasets, ensuring they are passed correctly in system messages.
- Implemented error handling for dataset retrieval: fetch failures are logged and an empty dataset list is returned, so the workflow is never blocked.
- Adjusted integration tests to accommodate the new datasets structure, ensuring comprehensive coverage.

These changes let the agents make use of the datasets relevant to each user, improving the quality and context-awareness of their responses.
This commit is contained in:
dal 2025-08-12 19:31:25 -06:00
parent cc6c407023
commit 855e7b1a55
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
11 changed files with 97 additions and 14 deletions

View File

@ -14,6 +14,9 @@ import {
getOrganizationDataSource, getOrganizationDataSource,
} from '@buster/database'; } from '@buster/database';
// Access control imports
import { type PermissionedDataset, getPermissionedDatasets } from '@buster/access-controls';
// AI package imports // AI package imports
import { type AnalystWorkflowInput, runAnalystWorkflow } from '@buster/ai'; import { type AnalystWorkflowInput, runAnalystWorkflow } from '@buster/ai';
@ -286,16 +289,40 @@ export const analystAgentTask: ReturnType<
getChatDashboardFiles({ chatId: context.chatId }) getChatDashboardFiles({ chatId: context.chatId })
); );
// Fetch user's datasets as soon as we have the userId
const datasetsPromise = messageContextPromise.then(async (context) => {
try {
// Using the existing access control function
const datasets = await getPermissionedDatasets(context.userId, 0, 1000);
return datasets;
} catch (error) {
logger.error('Failed to fetch datasets for user', {
userId: context.userId,
messageId: payload.message_id,
error: error instanceof Error ? error.message : 'Unknown error',
});
// Return empty array on error to not block the workflow
return [] as PermissionedDataset[];
}
});
// Fetch Braintrust metadata in parallel // Fetch Braintrust metadata in parallel
const braintrustMetadataPromise = getBraintrustMetadata({ messageId: payload.message_id }); const braintrustMetadataPromise = getBraintrustMetadata({ messageId: payload.message_id });
// Wait for all operations to complete // Wait for all operations to complete
const [messageContext, conversationHistory, dataSource, dashboardFiles, braintrustMetadata] = const [
await Promise.all([ messageContext,
conversationHistory,
dataSource,
dashboardFiles,
datasets,
braintrustMetadata,
] = await Promise.all([
messageContextPromise, messageContextPromise,
conversationHistoryPromise, conversationHistoryPromise,
dataSourcePromise, dataSourcePromise,
dashboardFilesPromise, dashboardFilesPromise,
datasetsPromise,
braintrustMetadataPromise, braintrustMetadataPromise,
]); ]);
@ -317,6 +344,11 @@ export const analystAgentTask: ReturnType<
metricIdsCount: d.metricIds.length, metricIdsCount: d.metricIds.length,
metricIds: d.metricIds, metricIds: d.metricIds,
})), })),
datasetsCount: datasets.length,
datasets: datasets.map((d) => ({
id: d.id,
name: d.name,
})),
dataLoadTimeMs: dataLoadTime, dataLoadTimeMs: dataLoadTime,
braintrustMetadata, // Log the metadata to verify it's working braintrustMetadata, // Log the metadata to verify it's working
}); });
@ -344,6 +376,7 @@ export const analystAgentTask: ReturnType<
organizationId: messageContext.organizationId, organizationId: messageContext.organizationId,
dataSourceId: dataSource.dataSourceId, dataSourceId: dataSource.dataSourceId,
dataSourceSyntax: dataSource.dataSourceSyntax, dataSourceSyntax: dataSource.dataSourceSyntax,
datasets,
}; };
logger.log('Workflow input prepared', { logger.log('Workflow input prepared', {

View File

@ -14,6 +14,7 @@ describe('Analyst Agent Integration Tests', () => {
dataSourceSyntax: 'postgresql', dataSourceSyntax: 'postgresql',
organizationId: '123', organizationId: '123',
messageId: '123', messageId: '123',
datasets: [],
}); });
const streamResult = await analystAgent.stream({ const streamResult = await analystAgent.stream({

View File

@ -1,3 +1,4 @@
import type { PermissionedDataset } from '@buster/access-controls';
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai'; import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import z from 'zod'; import z from 'zod';
@ -31,6 +32,7 @@ export const AnalystAgentOptionsSchema = z.object({
dataSourceSyntax: z.string(), dataSourceSyntax: z.string(),
organizationId: z.string(), organizationId: z.string(),
messageId: z.string(), messageId: z.string(),
datasets: z.array(z.custom<PermissionedDataset>()),
}); });
export const AnalystStreamOptionsSchema = z.object({ export const AnalystStreamOptionsSchema = z.object({
@ -43,12 +45,28 @@ export type AnalystAgentOptions = z.infer<typeof AnalystAgentOptionsSchema>;
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>; export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) { export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
const { datasets } = analystAgentOptions;
const systemMessage = { const systemMessage = {
role: 'system', role: 'system',
content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax), content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax),
providerOptions: DEFAULT_CACHE_OPTIONS, providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage; } as ModelMessage;
// Create second system message with datasets information
const datasetsContent = datasets
.filter((d) => d.ymlFile)
.map((d) => d.ymlFile)
.join('\n\n');
const datasetsSystemMessage = {
role: 'system',
content: datasetsContent
? `<database_context>\n${datasetsContent}\n</database_context>`
: '<database_context>\nNo datasets available\n</database_context>',
providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage;
async function stream({ messages }: AnalystStreamOptions) { async function stream({ messages }: AnalystStreamOptions) {
const maxRetries = 2; const maxRetries = 2;
let attempt = 0; let attempt = 0;
@ -79,7 +97,7 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
modifyReports, modifyReports,
doneTool, doneTool,
}, },
messages: [systemMessage, ...currentMessages], messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
stopWhen: STOP_CONDITIONS, stopWhen: STOP_CONDITIONS,
toolChoice: 'required', toolChoice: 'required',
maxOutputTokens: 10000, maxOutputTokens: 10000,

View File

@ -17,6 +17,7 @@ describe('Think and Prep Agent Integration Tests', () => {
userId: 'test-user-123', userId: 'test-user-123',
dataSourceId: 'test-data-source-123', dataSourceId: 'test-data-source-123',
dataSourceSyntax: 'postgresql', dataSourceSyntax: 'postgresql',
datasets: [],
}); });
const streamResult = await thinkAndPrepAgent.stream({ const streamResult = await thinkAndPrepAgent.stream({

View File

@ -1,3 +1,4 @@
import type { PermissionedDataset } from '@buster/access-controls';
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai'; import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import z from 'zod'; import z from 'zod';
@ -37,6 +38,9 @@ export const ThinkAndPrepAgentOptionsSchema = z.object({
dataSourceId: z.string().describe('The data source ID for tracking tool execution.'), dataSourceId: z.string().describe('The data source ID for tracking tool execution.'),
dataSourceSyntax: z.string().describe('The data source syntax for tracking tool execution.'), dataSourceSyntax: z.string().describe('The data source syntax for tracking tool execution.'),
userId: z.string().describe('The user ID for tracking tool execution.'), userId: z.string().describe('The user ID for tracking tool execution.'),
datasets: z
.array(z.custom<PermissionedDataset>())
.describe('The datasets available to the user.'),
analysisMode: z analysisMode: z
.enum(['standard', 'investigation']) .enum(['standard', 'investigation'])
.default('standard') .default('standard')
@ -54,7 +58,7 @@ export type ThinkAndPrepAgentOptions = z.infer<typeof ThinkAndPrepAgentOptionsSc
export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>; export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>;
export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAgentOptions) { export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAgentOptions) {
const { messageId } = thinkAndPrepAgentSchema; const { messageId, datasets } = thinkAndPrepAgentSchema;
const systemMessage = { const systemMessage = {
role: 'system', role: 'system',
@ -65,6 +69,20 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
providerOptions: DEFAULT_CACHE_OPTIONS, providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage; } as ModelMessage;
// Create second system message with datasets information
const datasetsContent = datasets
.filter((d) => d.ymlFile)
.map((d) => d.ymlFile)
.join('\n\n');
const datasetsSystemMessage = {
role: 'system',
content: datasetsContent
? `<database_context>\n${datasetsContent}\n</database_context>`
: '<database_context>\nNo datasets available\n</database_context>',
providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage;
async function stream({ messages }: ThinkAndPrepStreamOptions) { async function stream({ messages }: ThinkAndPrepStreamOptions) {
const maxRetries = 2; const maxRetries = 2;
let attempt = 0; let attempt = 0;
@ -94,7 +112,7 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
submitThoughts, submitThoughts,
messageUserClarifyingQuestion, messageUserClarifyingQuestion,
}, },
messages: [systemMessage, ...currentMessages], messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
stopWhen: STOP_CONDITIONS, stopWhen: STOP_CONDITIONS,
toolChoice: 'required', toolChoice: 'required',
maxOutputTokens: 10000, maxOutputTokens: 10000,

View File

@ -41,6 +41,7 @@ describe('runAnalystAgentStep', () => {
userId: 'test-user-id', userId: 'test-user-id',
dataSourceId: 'test-ds-id', dataSourceId: 'test-ds-id',
dataSourceSyntax: 'postgres', dataSourceSyntax: 'postgres',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],
@ -75,6 +76,7 @@ describe('runAnalystAgentStep', () => {
userId: 'test-user-id', userId: 'test-user-id',
dataSourceId: 'test-ds-id', dataSourceId: 'test-ds-id',
dataSourceSyntax: 'postgres', dataSourceSyntax: 'postgres',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],
@ -104,6 +106,7 @@ describe('runAnalystAgentStep', () => {
userId: 'test-user-id', userId: 'test-user-id',
dataSourceId: 'test-ds-id', dataSourceId: 'test-ds-id',
dataSourceSyntax: 'postgres', dataSourceSyntax: 'postgres',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],
@ -131,6 +134,7 @@ describe('runAnalystAgentStep', () => {
userId: 'test-user-id', userId: 'test-user-id',
dataSourceId: 'test-ds-id', dataSourceId: 'test-ds-id',
dataSourceSyntax: 'postgres', dataSourceSyntax: 'postgres',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],

View File

@ -84,7 +84,7 @@ async function generateTodosWithLLM(
// Start streaming // Start streaming
await onStreamStart(); await onStreamStart();
const { object, textStream } = await streamObject({ const { object, textStream } = streamObject({
model: Sonnet4, model: Sonnet4,
schema: llmOutputSchema, schema: llmOutputSchema,
messages: todosMessages, messages: todosMessages,

View File

@ -33,7 +33,7 @@ export function createTodosReasoningMessage(
return { return {
id, id,
type: 'files', type: 'files',
title: 'Analysis Plan', title: todosState.is_complete ? 'Broke down your request' : 'Breaking down your request...',
status: todosState.is_complete ? 'completed' : 'loading', status: todosState.is_complete ? 'completed' : 'loading',
secondary_title: undefined, secondary_title: undefined,
file_ids: [id], file_ids: [id],

View File

@ -44,6 +44,7 @@ describe('runThinkAndPrepAgentStep', () => {
organizationId: 'test-organization-id', organizationId: 'test-organization-id',
dataSourceId: 'test-data-source-id', dataSourceId: 'test-data-source-id',
dataSourceSyntax: 'test-data-source-syntax', dataSourceSyntax: 'test-data-source-syntax',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],
@ -81,6 +82,7 @@ describe('runThinkAndPrepAgentStep', () => {
organizationId: 'test-organization-id', organizationId: 'test-organization-id',
dataSourceId: 'test-data-source-id', dataSourceId: 'test-data-source-id',
dataSourceSyntax: 'test-data-source-syntax', dataSourceSyntax: 'test-data-source-syntax',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],
@ -115,6 +117,7 @@ describe('runThinkAndPrepAgentStep', () => {
organizationId: 'test-organization-id', organizationId: 'test-organization-id',
dataSourceId: 'test-data-source-id', dataSourceId: 'test-data-source-id',
dataSourceSyntax: 'test-data-source-syntax', dataSourceSyntax: 'test-data-source-syntax',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],
@ -147,6 +150,7 @@ describe('runThinkAndPrepAgentStep', () => {
organizationId: 'test-organization-id', organizationId: 'test-organization-id',
dataSourceId: 'test-data-source-id', dataSourceId: 'test-data-source-id',
dataSourceSyntax: 'test-data-source-syntax', dataSourceSyntax: 'test-data-source-syntax',
datasets: [],
}, },
streamOptions: { streamOptions: {
messages: [{ role: 'user', content: 'Test prompt' }], messages: [{ role: 'user', content: 'Test prompt' }],

View File

@ -57,7 +57,7 @@ export async function runGetRepositoryTreeStep(
} }
// Get the tree structure with gitignore option enabled // Get the tree structure with gitignore option enabled
let treeResult: unknown; let treeResult: { success: boolean; output?: string; error?: unknown; command?: string };
try { try {
treeResult = await getRepositoryTree(sandbox, '.', { treeResult = await getRepositoryTree(sandbox, '.', {
gitignore: true, gitignore: true,

View File

@ -1,5 +1,6 @@
// input for the workflow // input for the workflow
import type { PermissionedDataset } from '@buster/access-controls';
import type { ModelMessage } from 'ai'; import type { ModelMessage } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { import {
@ -20,6 +21,7 @@ const AnalystWorkflowInputSchema = z.object({
organizationId: z.string().uuid(), organizationId: z.string().uuid(),
dataSourceId: z.string().uuid(), dataSourceId: z.string().uuid(),
dataSourceSyntax: z.string(), dataSourceSyntax: z.string(),
datasets: z.array(z.custom<PermissionedDataset>()),
}); });
export type AnalystWorkflowInput = z.infer<typeof AnalystWorkflowInputSchema>; export type AnalystWorkflowInput = z.infer<typeof AnalystWorkflowInputSchema>;
@ -44,6 +46,7 @@ export async function runAnalystWorkflow(input: AnalystWorkflowInput) {
dataSourceSyntax: input.dataSourceSyntax, dataSourceSyntax: input.dataSourceSyntax,
userId: input.userId, userId: input.userId,
sql_dialect_guidance: input.dataSourceSyntax, sql_dialect_guidance: input.dataSourceSyntax,
datasets: input.datasets,
}, },
streamOptions: { streamOptions: {
messages, messages,
@ -60,6 +63,7 @@ export async function runAnalystWorkflow(input: AnalystWorkflowInput) {
dataSourceId: input.dataSourceId, dataSourceId: input.dataSourceId,
dataSourceSyntax: input.dataSourceSyntax, dataSourceSyntax: input.dataSourceSyntax,
userId: input.userId, userId: input.userId,
datasets: input.datasets,
}, },
streamOptions: { streamOptions: {
messages, messages,