mirror of https://github.com/buster-so/buster.git
Integrate dataset handling into the Analyst and Think-and-Prep agents
- Added support for fetching and managing user-specific datasets in the `analystAgentTask`, enhancing the context available for processing.
- Updated the `AnalystAgent` and `ThinkAndPrepAgent` schemas to include datasets, which are passed to the model in a second system message.
- Implemented error handling for dataset retrieval: a failure is logged and an empty dataset list is used, so the workflow is not interrupted.
- Adjusted integration tests to accommodate the new datasets structure, keeping test coverage intact.

These changes improve the agents' ability to utilize relevant datasets, enhancing their functionality and user experience.
This commit is contained in:
parent
cc6c407023
commit
855e7b1a55
|
@ -14,6 +14,9 @@ import {
|
||||||
getOrganizationDataSource,
|
getOrganizationDataSource,
|
||||||
} from '@buster/database';
|
} from '@buster/database';
|
||||||
|
|
||||||
|
// Access control imports
|
||||||
|
import { type PermissionedDataset, getPermissionedDatasets } from '@buster/access-controls';
|
||||||
|
|
||||||
// AI package imports
|
// AI package imports
|
||||||
import { type AnalystWorkflowInput, runAnalystWorkflow } from '@buster/ai';
|
import { type AnalystWorkflowInput, runAnalystWorkflow } from '@buster/ai';
|
||||||
|
|
||||||
|
@ -286,16 +289,40 @@ export const analystAgentTask: ReturnType<
|
||||||
getChatDashboardFiles({ chatId: context.chatId })
|
getChatDashboardFiles({ chatId: context.chatId })
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Fetch user's datasets as soon as we have the userId
|
||||||
|
const datasetsPromise = messageContextPromise.then(async (context) => {
|
||||||
|
try {
|
||||||
|
// Using the existing access control function
|
||||||
|
const datasets = await getPermissionedDatasets(context.userId, 0, 1000);
|
||||||
|
return datasets;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to fetch datasets for user', {
|
||||||
|
userId: context.userId,
|
||||||
|
messageId: payload.message_id,
|
||||||
|
error: error instanceof Error ? error.message : 'Unknown error',
|
||||||
|
});
|
||||||
|
// Return empty array on error to not block the workflow
|
||||||
|
return [] as PermissionedDataset[];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Fetch Braintrust metadata in parallel
|
// Fetch Braintrust metadata in parallel
|
||||||
const braintrustMetadataPromise = getBraintrustMetadata({ messageId: payload.message_id });
|
const braintrustMetadataPromise = getBraintrustMetadata({ messageId: payload.message_id });
|
||||||
|
|
||||||
// Wait for all operations to complete
|
// Wait for all operations to complete
|
||||||
const [messageContext, conversationHistory, dataSource, dashboardFiles, braintrustMetadata] =
|
const [
|
||||||
await Promise.all([
|
messageContext,
|
||||||
|
conversationHistory,
|
||||||
|
dataSource,
|
||||||
|
dashboardFiles,
|
||||||
|
datasets,
|
||||||
|
braintrustMetadata,
|
||||||
|
] = await Promise.all([
|
||||||
messageContextPromise,
|
messageContextPromise,
|
||||||
conversationHistoryPromise,
|
conversationHistoryPromise,
|
||||||
dataSourcePromise,
|
dataSourcePromise,
|
||||||
dashboardFilesPromise,
|
dashboardFilesPromise,
|
||||||
|
datasetsPromise,
|
||||||
braintrustMetadataPromise,
|
braintrustMetadataPromise,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
@ -317,6 +344,11 @@ export const analystAgentTask: ReturnType<
|
||||||
metricIdsCount: d.metricIds.length,
|
metricIdsCount: d.metricIds.length,
|
||||||
metricIds: d.metricIds,
|
metricIds: d.metricIds,
|
||||||
})),
|
})),
|
||||||
|
datasetsCount: datasets.length,
|
||||||
|
datasets: datasets.map((d) => ({
|
||||||
|
id: d.id,
|
||||||
|
name: d.name,
|
||||||
|
})),
|
||||||
dataLoadTimeMs: dataLoadTime,
|
dataLoadTimeMs: dataLoadTime,
|
||||||
braintrustMetadata, // Log the metadata to verify it's working
|
braintrustMetadata, // Log the metadata to verify it's working
|
||||||
});
|
});
|
||||||
|
@ -344,6 +376,7 @@ export const analystAgentTask: ReturnType<
|
||||||
organizationId: messageContext.organizationId,
|
organizationId: messageContext.organizationId,
|
||||||
dataSourceId: dataSource.dataSourceId,
|
dataSourceId: dataSource.dataSourceId,
|
||||||
dataSourceSyntax: dataSource.dataSourceSyntax,
|
dataSourceSyntax: dataSource.dataSourceSyntax,
|
||||||
|
datasets,
|
||||||
};
|
};
|
||||||
|
|
||||||
logger.log('Workflow input prepared', {
|
logger.log('Workflow input prepared', {
|
||||||
|
|
|
@ -14,6 +14,7 @@ describe('Analyst Agent Integration Tests', () => {
|
||||||
dataSourceSyntax: 'postgresql',
|
dataSourceSyntax: 'postgresql',
|
||||||
organizationId: '123',
|
organizationId: '123',
|
||||||
messageId: '123',
|
messageId: '123',
|
||||||
|
datasets: [],
|
||||||
});
|
});
|
||||||
|
|
||||||
const streamResult = await analystAgent.stream({
|
const streamResult = await analystAgent.stream({
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import type { PermissionedDataset } from '@buster/access-controls';
|
||||||
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
|
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
|
||||||
import { wrapTraced } from 'braintrust';
|
import { wrapTraced } from 'braintrust';
|
||||||
import z from 'zod';
|
import z from 'zod';
|
||||||
|
@ -31,6 +32,7 @@ export const AnalystAgentOptionsSchema = z.object({
|
||||||
dataSourceSyntax: z.string(),
|
dataSourceSyntax: z.string(),
|
||||||
organizationId: z.string(),
|
organizationId: z.string(),
|
||||||
messageId: z.string(),
|
messageId: z.string(),
|
||||||
|
datasets: z.array(z.custom<PermissionedDataset>()),
|
||||||
});
|
});
|
||||||
|
|
||||||
export const AnalystStreamOptionsSchema = z.object({
|
export const AnalystStreamOptionsSchema = z.object({
|
||||||
|
@ -43,12 +45,28 @@ export type AnalystAgentOptions = z.infer<typeof AnalystAgentOptionsSchema>;
|
||||||
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
|
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
|
||||||
|
|
||||||
export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
||||||
|
const { datasets } = analystAgentOptions;
|
||||||
|
|
||||||
const systemMessage = {
|
const systemMessage = {
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax),
|
content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax),
|
||||||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||||
} as ModelMessage;
|
} as ModelMessage;
|
||||||
|
|
||||||
|
// Create second system message with datasets information
|
||||||
|
const datasetsContent = datasets
|
||||||
|
.filter((d) => d.ymlFile)
|
||||||
|
.map((d) => d.ymlFile)
|
||||||
|
.join('\n\n');
|
||||||
|
|
||||||
|
const datasetsSystemMessage = {
|
||||||
|
role: 'system',
|
||||||
|
content: datasetsContent
|
||||||
|
? `<database_context>\n${datasetsContent}\n</database_context>`
|
||||||
|
: '<database_context>\nNo datasets available\n</database_context>',
|
||||||
|
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||||
|
} as ModelMessage;
|
||||||
|
|
||||||
async function stream({ messages }: AnalystStreamOptions) {
|
async function stream({ messages }: AnalystStreamOptions) {
|
||||||
const maxRetries = 2;
|
const maxRetries = 2;
|
||||||
let attempt = 0;
|
let attempt = 0;
|
||||||
|
@ -79,7 +97,7 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
||||||
modifyReports,
|
modifyReports,
|
||||||
doneTool,
|
doneTool,
|
||||||
},
|
},
|
||||||
messages: [systemMessage, ...currentMessages],
|
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
||||||
stopWhen: STOP_CONDITIONS,
|
stopWhen: STOP_CONDITIONS,
|
||||||
toolChoice: 'required',
|
toolChoice: 'required',
|
||||||
maxOutputTokens: 10000,
|
maxOutputTokens: 10000,
|
||||||
|
|
|
@ -17,6 +17,7 @@ describe('Think and Prep Agent Integration Tests', () => {
|
||||||
userId: 'test-user-123',
|
userId: 'test-user-123',
|
||||||
dataSourceId: 'test-data-source-123',
|
dataSourceId: 'test-data-source-123',
|
||||||
dataSourceSyntax: 'postgresql',
|
dataSourceSyntax: 'postgresql',
|
||||||
|
datasets: [],
|
||||||
});
|
});
|
||||||
|
|
||||||
const streamResult = await thinkAndPrepAgent.stream({
|
const streamResult = await thinkAndPrepAgent.stream({
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import type { PermissionedDataset } from '@buster/access-controls';
|
||||||
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
|
import { type ModelMessage, NoSuchToolError, hasToolCall, stepCountIs, streamText } from 'ai';
|
||||||
import { wrapTraced } from 'braintrust';
|
import { wrapTraced } from 'braintrust';
|
||||||
import z from 'zod';
|
import z from 'zod';
|
||||||
|
@ -37,6 +38,9 @@ export const ThinkAndPrepAgentOptionsSchema = z.object({
|
||||||
dataSourceId: z.string().describe('The data source ID for tracking tool execution.'),
|
dataSourceId: z.string().describe('The data source ID for tracking tool execution.'),
|
||||||
dataSourceSyntax: z.string().describe('The data source syntax for tracking tool execution.'),
|
dataSourceSyntax: z.string().describe('The data source syntax for tracking tool execution.'),
|
||||||
userId: z.string().describe('The user ID for tracking tool execution.'),
|
userId: z.string().describe('The user ID for tracking tool execution.'),
|
||||||
|
datasets: z
|
||||||
|
.array(z.custom<PermissionedDataset>())
|
||||||
|
.describe('The datasets available to the user.'),
|
||||||
analysisMode: z
|
analysisMode: z
|
||||||
.enum(['standard', 'investigation'])
|
.enum(['standard', 'investigation'])
|
||||||
.default('standard')
|
.default('standard')
|
||||||
|
@ -54,7 +58,7 @@ export type ThinkAndPrepAgentOptions = z.infer<typeof ThinkAndPrepAgentOptionsSc
|
||||||
export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>;
|
export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>;
|
||||||
|
|
||||||
export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAgentOptions) {
|
export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAgentOptions) {
|
||||||
const { messageId } = thinkAndPrepAgentSchema;
|
const { messageId, datasets } = thinkAndPrepAgentSchema;
|
||||||
|
|
||||||
const systemMessage = {
|
const systemMessage = {
|
||||||
role: 'system',
|
role: 'system',
|
||||||
|
@ -65,6 +69,20 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
|
||||||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||||
} as ModelMessage;
|
} as ModelMessage;
|
||||||
|
|
||||||
|
// Create second system message with datasets information
|
||||||
|
const datasetsContent = datasets
|
||||||
|
.filter((d) => d.ymlFile)
|
||||||
|
.map((d) => d.ymlFile)
|
||||||
|
.join('\n\n');
|
||||||
|
|
||||||
|
const datasetsSystemMessage = {
|
||||||
|
role: 'system',
|
||||||
|
content: datasetsContent
|
||||||
|
? `<database_context>\n${datasetsContent}\n</database_context>`
|
||||||
|
: '<database_context>\nNo datasets available\n</database_context>',
|
||||||
|
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||||
|
} as ModelMessage;
|
||||||
|
|
||||||
async function stream({ messages }: ThinkAndPrepStreamOptions) {
|
async function stream({ messages }: ThinkAndPrepStreamOptions) {
|
||||||
const maxRetries = 2;
|
const maxRetries = 2;
|
||||||
let attempt = 0;
|
let attempt = 0;
|
||||||
|
@ -94,7 +112,7 @@ export function createThinkAndPrepAgent(thinkAndPrepAgentSchema: ThinkAndPrepAge
|
||||||
submitThoughts,
|
submitThoughts,
|
||||||
messageUserClarifyingQuestion,
|
messageUserClarifyingQuestion,
|
||||||
},
|
},
|
||||||
messages: [systemMessage, ...currentMessages],
|
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
||||||
stopWhen: STOP_CONDITIONS,
|
stopWhen: STOP_CONDITIONS,
|
||||||
toolChoice: 'required',
|
toolChoice: 'required',
|
||||||
maxOutputTokens: 10000,
|
maxOutputTokens: 10000,
|
||||||
|
|
|
@ -41,6 +41,7 @@ describe('runAnalystAgentStep', () => {
|
||||||
userId: 'test-user-id',
|
userId: 'test-user-id',
|
||||||
dataSourceId: 'test-ds-id',
|
dataSourceId: 'test-ds-id',
|
||||||
dataSourceSyntax: 'postgres',
|
dataSourceSyntax: 'postgres',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
@ -75,6 +76,7 @@ describe('runAnalystAgentStep', () => {
|
||||||
userId: 'test-user-id',
|
userId: 'test-user-id',
|
||||||
dataSourceId: 'test-ds-id',
|
dataSourceId: 'test-ds-id',
|
||||||
dataSourceSyntax: 'postgres',
|
dataSourceSyntax: 'postgres',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
@ -104,6 +106,7 @@ describe('runAnalystAgentStep', () => {
|
||||||
userId: 'test-user-id',
|
userId: 'test-user-id',
|
||||||
dataSourceId: 'test-ds-id',
|
dataSourceId: 'test-ds-id',
|
||||||
dataSourceSyntax: 'postgres',
|
dataSourceSyntax: 'postgres',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
@ -131,6 +134,7 @@ describe('runAnalystAgentStep', () => {
|
||||||
userId: 'test-user-id',
|
userId: 'test-user-id',
|
||||||
dataSourceId: 'test-ds-id',
|
dataSourceId: 'test-ds-id',
|
||||||
dataSourceSyntax: 'postgres',
|
dataSourceSyntax: 'postgres',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
|
|
@ -84,7 +84,7 @@ async function generateTodosWithLLM(
|
||||||
// Start streaming
|
// Start streaming
|
||||||
await onStreamStart();
|
await onStreamStart();
|
||||||
|
|
||||||
const { object, textStream } = await streamObject({
|
const { object, textStream } = streamObject({
|
||||||
model: Sonnet4,
|
model: Sonnet4,
|
||||||
schema: llmOutputSchema,
|
schema: llmOutputSchema,
|
||||||
messages: todosMessages,
|
messages: todosMessages,
|
||||||
|
|
|
@ -33,7 +33,7 @@ export function createTodosReasoningMessage(
|
||||||
return {
|
return {
|
||||||
id,
|
id,
|
||||||
type: 'files',
|
type: 'files',
|
||||||
title: 'Analysis Plan',
|
title: todosState.is_complete ? 'Broke down your request' : 'Breaking down your request...',
|
||||||
status: todosState.is_complete ? 'completed' : 'loading',
|
status: todosState.is_complete ? 'completed' : 'loading',
|
||||||
secondary_title: undefined,
|
secondary_title: undefined,
|
||||||
file_ids: [id],
|
file_ids: [id],
|
||||||
|
|
|
@ -44,6 +44,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
||||||
organizationId: 'test-organization-id',
|
organizationId: 'test-organization-id',
|
||||||
dataSourceId: 'test-data-source-id',
|
dataSourceId: 'test-data-source-id',
|
||||||
dataSourceSyntax: 'test-data-source-syntax',
|
dataSourceSyntax: 'test-data-source-syntax',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
@ -81,6 +82,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
||||||
organizationId: 'test-organization-id',
|
organizationId: 'test-organization-id',
|
||||||
dataSourceId: 'test-data-source-id',
|
dataSourceId: 'test-data-source-id',
|
||||||
dataSourceSyntax: 'test-data-source-syntax',
|
dataSourceSyntax: 'test-data-source-syntax',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
@ -115,6 +117,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
||||||
organizationId: 'test-organization-id',
|
organizationId: 'test-organization-id',
|
||||||
dataSourceId: 'test-data-source-id',
|
dataSourceId: 'test-data-source-id',
|
||||||
dataSourceSyntax: 'test-data-source-syntax',
|
dataSourceSyntax: 'test-data-source-syntax',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
@ -147,6 +150,7 @@ describe('runThinkAndPrepAgentStep', () => {
|
||||||
organizationId: 'test-organization-id',
|
organizationId: 'test-organization-id',
|
||||||
dataSourceId: 'test-data-source-id',
|
dataSourceId: 'test-data-source-id',
|
||||||
dataSourceSyntax: 'test-data-source-syntax',
|
dataSourceSyntax: 'test-data-source-syntax',
|
||||||
|
datasets: [],
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages: [{ role: 'user', content: 'Test prompt' }],
|
messages: [{ role: 'user', content: 'Test prompt' }],
|
||||||
|
|
|
@ -57,7 +57,7 @@ export async function runGetRepositoryTreeStep(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the tree structure with gitignore option enabled
|
// Get the tree structure with gitignore option enabled
|
||||||
let treeResult: unknown;
|
let treeResult: { success: boolean; output?: string; error?: unknown; command?: string };
|
||||||
try {
|
try {
|
||||||
treeResult = await getRepositoryTree(sandbox, '.', {
|
treeResult = await getRepositoryTree(sandbox, '.', {
|
||||||
gitignore: true,
|
gitignore: true,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
// input for the workflow
|
// input for the workflow
|
||||||
|
|
||||||
|
import type { PermissionedDataset } from '@buster/access-controls';
|
||||||
import type { ModelMessage } from 'ai';
|
import type { ModelMessage } from 'ai';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import {
|
import {
|
||||||
|
@ -20,6 +21,7 @@ const AnalystWorkflowInputSchema = z.object({
|
||||||
organizationId: z.string().uuid(),
|
organizationId: z.string().uuid(),
|
||||||
dataSourceId: z.string().uuid(),
|
dataSourceId: z.string().uuid(),
|
||||||
dataSourceSyntax: z.string(),
|
dataSourceSyntax: z.string(),
|
||||||
|
datasets: z.array(z.custom<PermissionedDataset>()),
|
||||||
});
|
});
|
||||||
|
|
||||||
export type AnalystWorkflowInput = z.infer<typeof AnalystWorkflowInputSchema>;
|
export type AnalystWorkflowInput = z.infer<typeof AnalystWorkflowInputSchema>;
|
||||||
|
@ -44,6 +46,7 @@ export async function runAnalystWorkflow(input: AnalystWorkflowInput) {
|
||||||
dataSourceSyntax: input.dataSourceSyntax,
|
dataSourceSyntax: input.dataSourceSyntax,
|
||||||
userId: input.userId,
|
userId: input.userId,
|
||||||
sql_dialect_guidance: input.dataSourceSyntax,
|
sql_dialect_guidance: input.dataSourceSyntax,
|
||||||
|
datasets: input.datasets,
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages,
|
messages,
|
||||||
|
@ -60,6 +63,7 @@ export async function runAnalystWorkflow(input: AnalystWorkflowInput) {
|
||||||
dataSourceId: input.dataSourceId,
|
dataSourceId: input.dataSourceId,
|
||||||
dataSourceSyntax: input.dataSourceSyntax,
|
dataSourceSyntax: input.dataSourceSyntax,
|
||||||
userId: input.userId,
|
userId: input.userId,
|
||||||
|
datasets: input.datasets,
|
||||||
},
|
},
|
||||||
streamOptions: {
|
streamOptions: {
|
||||||
messages,
|
messages,
|
||||||
|
|
Loading…
Reference in New Issue