mirror of https://github.com/buster-so/buster.git
Integrate dataset handling into Analyst and Think and Prep agents
- Added support for fetching and managing user-specific datasets in the `analystAgentTask`, enhancing the context available for processing. - Updated the `AnalystAgent` and `ThinkAndPrepAgent` schemas to include datasets, ensuring they are passed correctly in system messages. - Implemented error handling for dataset retrieval to prevent workflow interruptions. - Adjusted integration tests to accommodate the new datasets structure, ensuring comprehensive coverage. These changes improve the agents' ability to utilize relevant datasets, enhancing their functionality and user experience.
This commit is contained in:
parent
cc6c407023
commit
855e7b1a55
|
@ -14,6 +14,9 @@ import {
|
|||
getOrganizationDataSource,
|
||||
} from '@buster/database';
|
||||
|
||||
// Access control imports
|
||||
import { type PermissionedDataset, getPermissionedDatasets } from '@buster/access-controls';
|
||||
|
||||
// AI package imports
|
||||
import { type AnalystWorkflowInput, runAnalystWorkflow } from '@buster/ai';
|
||||
|
||||
|
@ -286,16 +289,40 @@ export const analystAgentTask: ReturnType<
|
|||
getChatDashboardFiles({ chatId: context.chatId })
|
||||
);
|
||||
|
||||
// Fetch user's datasets as soon as we have the userId
|
||||
const datasetsPromise = messageContextPromise.then(async (context) => {
|
||||
try {
|
||||
// Using the existing access control function
|
||||
const datasets = await getPermissionedDatasets(context.userId, 0, 1000);
|
||||
return datasets;
|
||||
} catch (error) {
|
||||
logger.error('Failed to fetch datasets for user', {
|
||||
userId: context.userId,
|
||||
messageId: payload.message_id,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
});
|
||||
// Return empty array on error to not block the workflow
|
||||
return [] as PermissionedDataset[];
|
||||
}
|
||||
});
|
||||
|
||||
// Fetch Braintrust metadata in parallel
|
||||
const braintrustMetadataPromise = getBraintrustMetadata({ messageId: payload.message_id });
|
||||
|
||||
// Wait for all operations to complete
|
||||
const [messageContext, conversationHistory, dataSource, dashboardFiles, braintrustMetadata] =
|
||||
await Promise.all([
|
||||
const [
|
||||
messageContext,
|
||||
conversationHistory,
|
||||
dataSource,
|
||||
dashboardFiles,
|
||||
datasets,
|
||||
braintrustMetadata,
|
||||
] = await Promise.all([
|
||||
messageContextPromise,
|
||||
conversationHistoryPromise,
|
||||
dataSourcePromise,
|
||||
dashboardFilesPromise,
|
||||
datasetsPromise,
|
||||
braintrustMetadataPromise,
|
||||
]);
|
||||
|
||||
|
@ -317,6 +344,11 @@ export const analystAgentTask: ReturnType<
|
|||
metricIdsCount: d.metricIds.length,
|
||||
metricIds: d.metricIds,
|
||||
})),
|
||||
datasetsCount: datasets.length,
|
||||
datasets: datasets.map((d) => ({
|
||||
id: d.id,
|
||||
name: d.name,
|
||||
})),
|
||||
dataLoadTimeMs: dataLoadTime,
|
||||
braintrustMetadata, // Log the metadata to verify it's working
|
||||
});
|
||||
|
@ -344,6 +376,7 @@ export const analystAgentTask: ReturnType<
|
|||
organizationId: messageContext.organizationId,
|
||||
dataSourceId: dataSource.dataSourceId,
|
||||
dataSourceSyntax: dataSource.dataSourceSyntax,
|
||||
datasets,
|
||||
};
|
||||
|
||||
logger.log('Workflow input prepared', {
|
||||
|
|
|
@ -14,6 +14,7 @@ describe('Analyst Agent Integration Tests', () => {
|
|||
dataSourceSyntax: 'postgresql',
|
||||
organizationId: '123',
|
||||
messageId: '123',
|
||||
datasets: [],
|
||||
});
|
||||
|
||||
const streamResult = await analystAgent.stream({
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import type { PermissionedDataset } from '@buster/access-controls';
|
||||
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
|
||||
import { wrapTraced } from 'braintrust';
|
||||
import z from 'zod';
|
||||
|
@ -31,6 +32,7 @@ export const AnalystAgentOptionsSchema = z.object({
|
|||
dataSourceSyntax: z.string(),
|
||||
organizationId: z.string(),
|
||||
messageId: z.string(),
|
||||
datasets: z.array(z.custom<PermissionedDataset>()),
|
||||
});
|
||||
|
||||
export const AnalystStreamOptionsSchema = z.object({
|
||||
|
@ -43,12 +45,28 @@ export type AnalystAgentOptions = z.infer<typeof AnalystAgentOptionsSchema>;
|
|||
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
|
||||
|
||||
export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
||||
const { datasets } = analystAgentOptions;
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax),
|
||||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||
} as ModelMessage;
|
||||
|
||||
// Create second system message with datasets information
|
||||
const datasetsContent = datasets
|
||||
.filter((d) => d.ymlFile)
|
||||
.map((d) => d.ymlFile)
|
||||
.join('\n\n');
|
||||
|
||||
const datasetsSystemMessage = {
|
||||
role: 'system',
|
||||
content: datasetsContent
|
||||
? `<database_context>\n${datasetsContent}\n</database_context>`
|
||||
: '<database_context>\nNo datasets available\n</database_context>',
|
||||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||
} as ModelMessage;
|
||||
|
||||
async function stream({ messages }: AnalystStreamOptions) {
|
||||
const maxRetries = 2;
|
||||
let attempt = 0;
|
||||
|
@ -79,7 +97,7 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
modifyReports,
|
||||
doneTool,
|
||||
},
|
||||
messages: [systemMessage, ...currentMessages],
|
||||
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
||||
stopWhen: STOP_CONDITIONS,
|
||||
toolChoice: 'required',
|
||||
maxOutputTokens: 10000,
|
||||
|
|
|
@ -17,6 +17,7 @@ describe('Think and Prep Agent Integration Tests', () => {
|
|||
userId: 'test-user-123',
|
||||
dataSourceId: 'test-data-source-123',
|
||||
dataSourceSyntax: 'postgresql',
|
||||
datasets: [],
|
||||
});
|
||||
|
||||
const streamResult = await thinkAndPrepAgent.stream({
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import type { PermissionedDataset } from '@buster/access-controls';
|
||||
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
|
||||
import { wrapTraced } from 'braintrust';
|
||||
import z from 'zod';
|
||||
|
@ -37,6 +38,9 @@ export const ThinkAndPrepAgentOptionsSchema = z.object({
|
|||
dataSourceId: z.string().describe('The data source ID for tracking tool execution.'),
|
||||
dataSourceSyntax: z.string().describe('The data source syntax for tracking tool execution.'),
|
||||
userId: z.string().describe('The user ID for tracking tool execution.'),
|
||||
datasets: z
|
||||
.array(z.custom<PermissionedDataset>())
|
||||
.describe('The datasets available to the user.'),
|
||||
analysisMode: z
|
||||
.enum(['standard', 'investigation'])
|
||||
.default('standard')
|
||||
|
@ -54,7 +58,7 @@ export type ThinkAndPrepAgentOptions = z.infer<typeof ThinkAndPrepAgentOptionsSc
|
|||
export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>;
|
||||
|
||||
export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAgentOptions) {
|
||||
const { messageId } = thinkAndPrepAgentSchema;
|
||||
const { messageId, datasets } = thinkAndPrepAgentSchema;
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
|
@ -65,6 +69,20 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
|
|||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||
} as ModelMessage;
|
||||
|
||||
// Create second system message with datasets information
|
||||
const datasetsContent = datasets
|
||||
.filter((d) => d.ymlFile)
|
||||
.map((d) => d.ymlFile)
|
||||
.join('\n\n');
|
||||
|
||||
const datasetsSystemMessage = {
|
||||
role: 'system',
|
||||
content: datasetsContent
|
||||
? `<database_context>\n${datasetsContent}\n</database_context>`
|
||||
: '<database_context>\nNo datasets available\n</database_context>',
|
||||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||
} as ModelMessage;
|
||||
|
||||
async function stream({ messages }: ThinkAndPrepStreamOptions) {
|
||||
const maxRetries = 2;
|
||||
let attempt = 0;
|
||||
|
@ -94,7 +112,7 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
|
|||
submitThoughts,
|
||||
messageUserClarifyingQuestion,
|
||||
},
|
||||
messages: [systemMessage, ...currentMessages],
|
||||
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
||||
stopWhen: STOP_CONDITIONS,
|
||||
toolChoice: 'required',
|
||||
maxOutputTokens: 10000,
|
||||
|
|
|
@ -41,6 +41,7 @@ describe('runAnalystAgentStep', () => {
|
|||
userId: 'test-user-id',
|
||||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -75,6 +76,7 @@ describe('runAnalystAgentStep', () => {
|
|||
userId: 'test-user-id',
|
||||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -104,6 +106,7 @@ describe('runAnalystAgentStep', () => {
|
|||
userId: 'test-user-id',
|
||||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -131,6 +134,7 @@ describe('runAnalystAgentStep', () => {
|
|||
userId: 'test-user-id',
|
||||
dataSourceId: 'test-ds-id',
|
||||
dataSourceSyntax: 'postgres',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
|
|
@ -84,7 +84,7 @@ async function generateTodosWithLLM(
|
|||
// Start streaming
|
||||
await onStreamStart();
|
||||
|
||||
const { object, textStream } = await streamObject({
|
||||
const { object, textStream } = streamObject({
|
||||
model: Sonnet4,
|
||||
schema: llmOutputSchema,
|
||||
messages: todosMessages,
|
||||
|
|
|
@ -33,7 +33,7 @@ export function createTodosReasoningMessage(
|
|||
return {
|
||||
id,
|
||||
type: 'files',
|
||||
title: 'Analysis Plan',
|
||||
title: todosState.is_complete ? 'Broke down your request' : 'Breaking down your request...',
|
||||
status: todosState.is_complete ? 'completed' : 'loading',
|
||||
secondary_title: undefined,
|
||||
file_ids: [id],
|
||||
|
|
|
@ -44,6 +44,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
organizationId: 'test-organization-id',
|
||||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -81,6 +82,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
organizationId: 'test-organization-id',
|
||||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -115,6 +117,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
organizationId: 'test-organization-id',
|
||||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
@ -147,6 +150,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
|||
organizationId: 'test-organization-id',
|
||||
dataSourceId: 'test-data-source-id',
|
||||
dataSourceSyntax: 'test-data-source-syntax',
|
||||
datasets: [],
|
||||
},
|
||||
streamOptions: {
|
||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||
|
|
|
@ -57,7 +57,7 @@ export async function runGetRepositoryTreeStep(
|
|||
}
|
||||
|
||||
// Get the tree structure with gitignore option enabled
|
||||
let treeResult: unknown;
|
||||
let treeResult: { success: boolean; output?: string; error?: unknown; command?: string };
|
||||
try {
|
||||
treeResult = await getRepositoryTree(sandbox, '.', {
|
||||
gitignore: true,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// input for the workflow
|
||||
|
||||
import type { PermissionedDataset } from '@buster/access-controls';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
|
@ -20,6 +21,7 @@ const AnalystWorkflowInputSchema = z.object({
|
|||
organizationId: z.string().uuid(),
|
||||
dataSourceId: z.string().uuid(),
|
||||
dataSourceSyntax: z.string(),
|
||||
datasets: z.array(z.custom<PermissionedDataset>()),
|
||||
});
|
||||
|
||||
export type AnalystWorkflowInput = z.infer<typeof AnalystWorkflowInputSchema>;
|
||||
|
@ -44,6 +46,7 @@ export async function runAnalystWorkflow(input: AnalystWorkflowInput) {
|
|||
dataSourceSyntax: input.dataSourceSyntax,
|
||||
userId: input.userId,
|
||||
sql_dialect_guidance: input.dataSourceSyntax,
|
||||
datasets: input.datasets,
|
||||
},
|
||||
streamOptions: {
|
||||
messages,
|
||||
|
@ -60,6 +63,7 @@ export async function runAnalystWorkflow(input: AnalystWorkflowInput) {
|
|||
dataSourceId: input.dataSourceId,
|
||||
dataSourceSyntax: input.dataSourceSyntax,
|
||||
userId: input.userId,
|
||||
datasets: input.datasets,
|
||||
},
|
||||
streamOptions: {
|
||||
messages,
|
||||
|
|
Loading…
Reference in New Issue