mirror of https://github.com/buster-so/buster.git
analyst agent is clean
This commit is contained in:
parent
0450b8c350
commit
9214075cad
|
@ -1,136 +1,32 @@
|
||||||
import { RuntimeContext } from '@mastra/core/runtime-context';
|
import type { ModelMessage } from "ai";
|
||||||
import type { CoreMessage } from 'ai';
|
import { describe, expect, test } from "vitest";
|
||||||
import { initLogger, wrapTraced } from 'braintrust';
|
import { createAnalystAgent } from "./analyst-agent";
|
||||||
import { afterAll, beforeAll, describe, expect, test } from 'vitest';
|
|
||||||
import type { AnalystRuntimeContext } from '../../workflows/analyst-workflow';
|
|
||||||
import { analystAgent } from './analyst-agent';
|
|
||||||
|
|
||||||
describe('Analyst Agent Integration Tests', () => {
|
describe("Analyst Agent Integration Tests", () => {
|
||||||
beforeAll(async () => {
|
test("should generate response for data analysis query with conversation history", async () => {
|
||||||
initLogger({
|
const messages: ModelMessage[] = [];
|
||||||
apiKey: process.env.BRAINTRUST_KEY,
|
|
||||||
projectName: 'ANALYST-AGENT',
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
afterAll(async () => {
|
try {
|
||||||
// Cleanup if needed
|
const analystAgent = createAnalystAgent({
|
||||||
// Wait 500ms before finishing
|
sql_dialect_guidance: "postgresql",
|
||||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
});
|
||||||
});
|
|
||||||
|
|
||||||
test('should generate response for data analysis query with conversation history', async () => {
|
const streamResult = await analystAgent.stream({
|
||||||
// Stubbed conversation history - to be filled in later
|
messages,
|
||||||
const conversationHistory: CoreMessage[] = [
|
});
|
||||||
// TODO: Add stubbed conversation history here
|
|
||||||
];
|
|
||||||
|
|
||||||
const tracedAgentWorkflow = wrapTraced(
|
let response = "";
|
||||||
async (messages: CoreMessage[]) => {
|
for await (const chunk of streamResult.fullStream) {
|
||||||
// Step 1: Generate response with analyst agent using conversation history
|
if (chunk.type === "text-delta") {
|
||||||
try {
|
response += chunk.text;
|
||||||
const chatId = 'da05b6fb-01b2-4c1c-bc7f-7e55029a5c75';
|
}
|
||||||
const resourceId = 'c2dd64cd-f7f3-4884-bc91-d46ae431901e';
|
}
|
||||||
|
|
||||||
// Create runtime context with required properties
|
expect(response).toBeDefined();
|
||||||
const runtimeContext = new RuntimeContext<AnalystRuntimeContext>([
|
expect(typeof response).toBe("string");
|
||||||
['userId', resourceId],
|
} catch (error) {
|
||||||
['chatId', chatId],
|
console.error("Error during agent execution:", error);
|
||||||
['dataSourceId', 'cc3ef3bc-44ec-4a43-8dc4-681cae5c996a'],
|
throw error;
|
||||||
['dataSourceSyntax', 'postgres'],
|
}
|
||||||
['organizationId', 'bf58d19a-8bb9-4f1d-a257-2d2105e7f1ce'],
|
}, 300000);
|
||||||
// Note: No messageId provided for testing scenario
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Use stream with conversation history instead of single prompt
|
|
||||||
const stream = await analystAgent.stream(messages, {
|
|
||||||
maxSteps: 15,
|
|
||||||
runtimeContext,
|
|
||||||
onStepFinish: async (step) => {
|
|
||||||
console.log('\n=== onStepFinish callback (with history) ===');
|
|
||||||
console.log('Step structure:', JSON.stringify(step, null, 2));
|
|
||||||
console.log('Tool calls:', step.toolCalls);
|
|
||||||
console.log('Response messages:', step.response.messages);
|
|
||||||
// Response text is not directly available on step.response
|
|
||||||
console.log('===========================\n');
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
let response = '';
|
|
||||||
for await (const chunk of stream.fullStream) {
|
|
||||||
if (chunk.type === 'text-delta') {
|
|
||||||
response += chunk.textDelta;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return response;
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error during agent execution:', error);
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{ name: 'analyst-agent-with-history' }
|
|
||||||
);
|
|
||||||
|
|
||||||
// Test with conversation history (stubbed for now)
|
|
||||||
const result = await tracedAgentWorkflow(
|
|
||||||
conversationHistory.length > 0
|
|
||||||
? (conversationHistory as CoreMessage[])
|
|
||||||
: [{ role: 'user', content: 'What are the top 5 customers by revenue?' }]
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(result).toBeDefined();
|
|
||||||
expect(typeof result).toBe('string');
|
|
||||||
expect(result.length).toBeGreaterThan(0);
|
|
||||||
// Should have generated some analysis response
|
|
||||||
expect(result).not.toBe('');
|
|
||||||
console.log('Final result:', result);
|
|
||||||
}, 300000);
|
|
||||||
|
|
||||||
test('should generate response for analysis query', async () => {
|
|
||||||
const tracedAgentWorkflow = wrapTraced(
|
|
||||||
async (input: string) => {
|
|
||||||
// Step 1: Generate response with analyst agent
|
|
||||||
try {
|
|
||||||
const chatId = 'da05b6fb-01b2-4c1c-bc7f-7e55029a5c75';
|
|
||||||
const resourceId = 'c2dd64cd-f7f3-4884-bc91-d46ae431901e';
|
|
||||||
|
|
||||||
// Create runtime context with required properties
|
|
||||||
const runtimeContext = new RuntimeContext<AnalystRuntimeContext>([
|
|
||||||
['userId', resourceId],
|
|
||||||
['chatId', chatId],
|
|
||||||
['dataSourceId', 'cc3ef3bc-44ec-4a43-8dc4-681cae5c996a'],
|
|
||||||
['dataSourceSyntax', 'postgresql'],
|
|
||||||
['organizationId', 'bf58d19a-8bb9-4f1d-a257-2d2105e7f1ce'],
|
|
||||||
]);
|
|
||||||
|
|
||||||
const stream = await analystAgent.stream(input, {
|
|
||||||
maxSteps: 15,
|
|
||||||
runtimeContext,
|
|
||||||
});
|
|
||||||
|
|
||||||
let responseText = '';
|
|
||||||
for await (const chunk of stream.fullStream) {
|
|
||||||
if (chunk.type === 'text-delta') {
|
|
||||||
responseText += chunk.textDelta;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return responseText;
|
|
||||||
} catch (error) {
|
|
||||||
console.error(error);
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{ name: 'AnalystAgentWorkflow' }
|
|
||||||
);
|
|
||||||
|
|
||||||
// Execute the workflow
|
|
||||||
const response = await tracedAgentWorkflow('please continue with the analysis');
|
|
||||||
|
|
||||||
// Verify response structure
|
|
||||||
expect(response).toBeDefined();
|
|
||||||
expect(typeof response).toBe('string');
|
|
||||||
expect(response.length).toBeGreaterThan(0);
|
|
||||||
}, 300000);
|
|
||||||
});
|
});
|
||||||
|
|
|
@ -1,68 +1,82 @@
|
||||||
import { type ModelMessage, hasToolCall, stepCountIs, streamText } from 'ai';
|
import { hasToolCall, type ModelMessage, stepCountIs, streamText } from "ai";
|
||||||
import { wrapTraced } from 'braintrust';
|
import { wrapTraced } from "braintrust";
|
||||||
import z from 'zod';
|
import z from "zod";
|
||||||
import {
|
import {
|
||||||
createDashboards,
|
createDashboards,
|
||||||
createMetrics,
|
createMetrics,
|
||||||
doneTool,
|
doneTool,
|
||||||
modifyDashboards,
|
modifyDashboards,
|
||||||
modifyMetrics,
|
modifyMetrics,
|
||||||
} from '../../tools';
|
} from "../../tools";
|
||||||
import { Sonnet4 } from '../../utils/models/sonnet-4';
|
import { Sonnet4 } from "../../utils/models/sonnet-4";
|
||||||
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
|
import { getAnalystAgentSystemPrompt } from "./get-analyst-agent-system-prompt";
|
||||||
|
|
||||||
const DEFAULT_CACHE_OPTIONS = {
|
const DEFAULT_CACHE_OPTIONS = {
|
||||||
anthropic: { cacheControl: { type: 'ephemeral', ttl: '1h' } },
|
anthropic: { cacheControl: { type: "ephemeral", ttl: "1h" } },
|
||||||
};
|
};
|
||||||
|
|
||||||
const STOP_CONDITIONS = [stepCountIs(18), hasToolCall('doneTool')];
|
const STOP_CONDITIONS = [stepCountIs(18), hasToolCall("doneTool")];
|
||||||
|
|
||||||
const AnalystAgentSchema = z.object({
|
const AnalystAgentOptionsSchema = z.object({
|
||||||
sql_dialect_guidance: z.string().describe('The SQL dialect guidance for the analyst agent.'),
|
sql_dialect_guidance: z
|
||||||
|
.string()
|
||||||
|
.describe("The SQL dialect guidance for the analyst agent."),
|
||||||
});
|
});
|
||||||
|
|
||||||
const AnalystStreamSchema = z.object({
|
const AnalystStreamOptionsSchema = z.object({
|
||||||
messages: z
|
messages: z
|
||||||
.array(z.custom<ModelMessage>())
|
.array(z.custom<ModelMessage>())
|
||||||
.describe('The messages to send to the analyst agent.'),
|
.describe("The messages to send to the analyst agent."),
|
||||||
});
|
});
|
||||||
|
|
||||||
export type AnalystAgentSchema = z.infer<typeof AnalystAgentSchema>;
|
export type AnalystAgentOptionsSchema = z.infer<
|
||||||
export type AnalystStreamSchema = z.infer<typeof AnalystStreamSchema>;
|
typeof AnalystAgentOptionsSchema
|
||||||
|
>;
|
||||||
|
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
|
||||||
|
|
||||||
export function createAnalystAgent(analystAgentSchema: AnalystAgentSchema) {
|
export function createAnalystAgent(
|
||||||
const steps: never[] = [];
|
analystAgentSchema: AnalystAgentOptionsSchema,
|
||||||
|
) {
|
||||||
|
const steps: never[] = [];
|
||||||
|
|
||||||
const systemMessage = {
|
const systemMessage = {
|
||||||
role: 'system',
|
role: "system",
|
||||||
content: getAnalystAgentSystemPrompt(analystAgentSchema.sql_dialect_guidance),
|
content: getAnalystAgentSystemPrompt(
|
||||||
providerOptions: DEFAULT_CACHE_OPTIONS,
|
analystAgentSchema.sql_dialect_guidance,
|
||||||
} as ModelMessage;
|
),
|
||||||
|
providerOptions: DEFAULT_CACHE_OPTIONS,
|
||||||
|
} as ModelMessage;
|
||||||
|
|
||||||
async function stream({ messages }: AnalystStreamSchema) {
|
async function stream({ messages }: AnalystStreamOptions) {
|
||||||
return wrapTraced(
|
return wrapTraced(
|
||||||
async () =>
|
() =>
|
||||||
streamText({
|
streamText({
|
||||||
model: Sonnet4,
|
model: Sonnet4,
|
||||||
tools: { createMetrics, modifyMetrics, createDashboards, modifyDashboards, doneTool },
|
tools: {
|
||||||
messages: [systemMessage, ...messages],
|
createMetrics,
|
||||||
stopWhen: STOP_CONDITIONS,
|
modifyMetrics,
|
||||||
toolChoice: 'required',
|
createDashboards,
|
||||||
maxOutputTokens: 10000,
|
modifyDashboards,
|
||||||
temperature: 0,
|
doneTool,
|
||||||
}),
|
},
|
||||||
{
|
messages: [systemMessage, ...messages],
|
||||||
name: 'Analyst Agent',
|
stopWhen: STOP_CONDITIONS,
|
||||||
}
|
toolChoice: "required",
|
||||||
);
|
maxOutputTokens: 10000,
|
||||||
}
|
temperature: 0,
|
||||||
|
}),
|
||||||
|
{
|
||||||
|
name: "Analyst Agent",
|
||||||
|
},
|
||||||
|
)();
|
||||||
|
}
|
||||||
|
|
||||||
async function getSteps() {
|
async function getSteps() {
|
||||||
return steps;
|
return steps;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
stream,
|
stream,
|
||||||
getSteps,
|
getSteps,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,81 +1,81 @@
|
||||||
import * as fs from 'node:fs';
|
import * as fs from "node:fs";
|
||||||
import * as path from 'node:path';
|
import * as path from "node:path";
|
||||||
import { describe, expect, it } from 'vitest';
|
import { describe, expect, it } from "vitest";
|
||||||
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
|
import { getAnalystAgentSystemPrompt } from "./get-analyst-agent-system-prompt";
|
||||||
|
|
||||||
describe('Analyst Agent Instructions', () => {
|
describe("Analyst Agent Instructions", () => {
|
||||||
it('should validate template file contains expected variables', () => {
|
it("should validate template file contains expected variables", () => {
|
||||||
const promptPath = path.join(__dirname, 'analyst-agent-prompt.txt');
|
const promptPath = path.join(__dirname, "analyst-agent-prompt.txt");
|
||||||
const content = fs.readFileSync(promptPath, 'utf-8');
|
const content = fs.readFileSync(promptPath, "utf-8");
|
||||||
|
|
||||||
// Expected template variables
|
// Expected template variables
|
||||||
const expectedVariables = ['sql_dialect_guidance', 'date'];
|
const expectedVariables = ["sql_dialect_guidance", "date"];
|
||||||
|
|
||||||
// Find all template variables in the file
|
// Find all template variables in the file
|
||||||
const templateVariablePattern = /\{\{([^}]+)\}\}/g;
|
const templateVariablePattern = /\{\{([^}]+)\}\}/g;
|
||||||
const foundVariables = new Set<string>();
|
const foundVariables = new Set<string>();
|
||||||
|
|
||||||
const matches = Array.from(content.matchAll(templateVariablePattern));
|
const matches = Array.from(content.matchAll(templateVariablePattern));
|
||||||
for (const match of matches) {
|
for (const match of matches) {
|
||||||
if (match[1] && match[1] !== 'variable') {
|
if (match[1] && match[1] !== "variable") {
|
||||||
foundVariables.add(match[1]);
|
foundVariables.add(match[1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to arrays for easier comparison
|
// Convert to arrays for easier comparison
|
||||||
const foundVariablesArray = Array.from(foundVariables).sort();
|
const foundVariablesArray = Array.from(foundVariables).sort();
|
||||||
const expectedVariablesArray = expectedVariables.sort();
|
const expectedVariablesArray = expectedVariables.sort();
|
||||||
|
|
||||||
// Check that we have exactly the expected variables
|
// Check that we have exactly the expected variables
|
||||||
expect(foundVariablesArray).toEqual(expectedVariablesArray);
|
expect(foundVariablesArray).toEqual(expectedVariablesArray);
|
||||||
|
|
||||||
// Also verify each expected variable exists
|
// Also verify each expected variable exists
|
||||||
for (const variable of expectedVariables) {
|
for (const variable of expectedVariables) {
|
||||||
expect(content).toMatch(new RegExp(`\\{\\{${variable}\\}\\}`));
|
expect(content).toMatch(new RegExp(`\\{\\{${variable}\\}\\}`));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure no unexpected variables exist
|
// Ensure no unexpected variables exist
|
||||||
expect(foundVariables.size).toBe(expectedVariables.length);
|
expect(foundVariables.size).toBe(expectedVariables.length);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should load and process the prompt template correctly', () => {
|
it("should load and process the prompt template correctly", () => {
|
||||||
const sqlDialectGuidance = 'Test SQL guidance for PostgreSQL';
|
const sqlDialectGuidance = "Test SQL guidance for PostgreSQL";
|
||||||
const result = getAnalystInstructions(sqlDialectGuidance);
|
const result = getAnalystAgentSystemPrompt(sqlDialectGuidance);
|
||||||
|
|
||||||
expect(result).toBeDefined();
|
expect(result).toBeDefined();
|
||||||
expect(typeof result).toBe('string');
|
expect(typeof result).toBe("string");
|
||||||
expect(result.length).toBeGreaterThan(0);
|
expect(result.length).toBeGreaterThan(0);
|
||||||
|
|
||||||
// Should contain the SQL guidance we provided
|
// Should contain the SQL guidance we provided
|
||||||
expect(result).toContain(sqlDialectGuidance);
|
expect(result).toContain(sqlDialectGuidance);
|
||||||
|
|
||||||
// Should not contain any unreplaced template variables
|
// Should not contain any unreplaced template variables
|
||||||
expect(result).not.toMatch(/\{\{sql_dialect_guidance\}\}/);
|
expect(result).not.toMatch(/\{\{sql_dialect_guidance\}\}/);
|
||||||
expect(result).not.toMatch(/\{\{date\}\}/);
|
expect(result).not.toMatch(/\{\{date\}\}/);
|
||||||
|
|
||||||
// Should contain the current date in YYYY-MM-DD format
|
// Should contain the current date in YYYY-MM-DD format
|
||||||
const currentDate = new Date().toISOString().split('T')[0];
|
const currentDate = new Date().toISOString().split("T")[0];
|
||||||
expect(result).toContain(currentDate);
|
expect(result).toContain(currentDate);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should contain expected sections from the prompt template', () => {
|
it("should contain expected sections from the prompt template", () => {
|
||||||
const result = getAnalystInstructions('Test guidance');
|
const result = getAnalystAgentSystemPrompt("Test guidance");
|
||||||
|
|
||||||
// Check for key sections that should be in the prompt
|
// Check for key sections that should be in the prompt
|
||||||
expect(result).toContain('<intro>');
|
expect(result).toContain("<intro>");
|
||||||
expect(result).toContain('<analysis_mode_capability>');
|
expect(result).toContain("<analysis_mode_capability>");
|
||||||
expect(result).toContain('<sql_best_practices>');
|
expect(result).toContain("<sql_best_practices>");
|
||||||
expect(result).toContain('<visualization_and_charting_guidelines>');
|
expect(result).toContain("<visualization_and_charting_guidelines>");
|
||||||
expect(result).toContain('You are a Buster');
|
expect(result).toContain("You are a Buster");
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should throw an error for empty SQL dialect guidance', () => {
|
it("should throw an error for empty SQL dialect guidance", () => {
|
||||||
expect(() => {
|
expect(() => {
|
||||||
getAnalystInstructions('');
|
getAnalystAgentSystemPrompt("");
|
||||||
}).toThrow('SQL dialect guidance is required');
|
}).toThrow("SQL dialect guidance is required");
|
||||||
|
|
||||||
expect(() => {
|
expect(() => {
|
||||||
getAnalystInstructions(' '); // whitespace only
|
getAnalystAgentSystemPrompt(" "); // whitespace only
|
||||||
}).toThrow('SQL dialect guidance is required');
|
}).toThrow("SQL dialect guidance is required");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue