analyst agent is clean

This commit is contained in:
dal 2025-08-05 00:54:23 -06:00
parent 0450b8c350
commit 9214075cad
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 155 additions and 245 deletions

View File

@ -1,136 +1,32 @@
import { RuntimeContext } from '@mastra/core/runtime-context'; import type { ModelMessage } from "ai";
import type { CoreMessage } from 'ai'; import { describe, expect, test } from "vitest";
import { initLogger, wrapTraced } from 'braintrust'; import { createAnalystAgent } from "./analyst-agent";
import { afterAll, beforeAll, describe, expect, test } from 'vitest';
import type { AnalystRuntimeContext } from '../../workflows/analyst-workflow';
import { analystAgent } from './analyst-agent';
describe('Analyst Agent Integration Tests', () => { describe("Analyst Agent Integration Tests", () => {
beforeAll(async () => { test("should generate response for data analysis query with conversation history", async () => {
initLogger({ const messages: ModelMessage[] = [];
apiKey: process.env.BRAINTRUST_KEY,
projectName: 'ANALYST-AGENT',
});
});
afterAll(async () => { try {
// Cleanup if needed const analystAgent = createAnalystAgent({
// Wait 500ms before finishing sql_dialect_guidance: "postgresql",
await new Promise((resolve) => setTimeout(resolve, 500)); });
});
test('should generate response for data analysis query with conversation history', async () => { const streamResult = await analystAgent.stream({
// Stubbed conversation history - to be filled in later messages,
const conversationHistory: CoreMessage[] = [ });
// TODO: Add stubbed conversation history here
];
const tracedAgentWorkflow = wrapTraced( let response = "";
async (messages: CoreMessage[]) => { for await (const chunk of streamResult.fullStream) {
// Step 1: Generate response with analyst agent using conversation history if (chunk.type === "text-delta") {
try { response += chunk.text;
const chatId = 'da05b6fb-01b2-4c1c-bc7f-7e55029a5c75'; }
const resourceId = 'c2dd64cd-f7f3-4884-bc91-d46ae431901e'; }
// Create runtime context with required properties expect(response).toBeDefined();
const runtimeContext = new RuntimeContext<AnalystRuntimeContext>([ expect(typeof response).toBe("string");
['userId', resourceId], } catch (error) {
['chatId', chatId], console.error("Error during agent execution:", error);
['dataSourceId', 'cc3ef3bc-44ec-4a43-8dc4-681cae5c996a'], throw error;
['dataSourceSyntax', 'postgres'], }
['organizationId', 'bf58d19a-8bb9-4f1d-a257-2d2105e7f1ce'], }, 300000);
// Note: No messageId provided for testing scenario
]);
// Use stream with conversation history instead of single prompt
const stream = await analystAgent.stream(messages, {
maxSteps: 15,
runtimeContext,
onStepFinish: async (step) => {
console.log('\n=== onStepFinish callback (with history) ===');
console.log('Step structure:', JSON.stringify(step, null, 2));
console.log('Tool calls:', step.toolCalls);
console.log('Response messages:', step.response.messages);
// Response text is not directly available on step.response
console.log('===========================\n');
},
});
let response = '';
for await (const chunk of stream.fullStream) {
if (chunk.type === 'text-delta') {
response += chunk.textDelta;
}
}
return response;
} catch (error) {
console.error('Error during agent execution:', error);
throw error;
}
},
{ name: 'analyst-agent-with-history' }
);
// Test with conversation history (stubbed for now)
const result = await tracedAgentWorkflow(
conversationHistory.length > 0
? (conversationHistory as CoreMessage[])
: [{ role: 'user', content: 'What are the top 5 customers by revenue?' }]
);
expect(result).toBeDefined();
expect(typeof result).toBe('string');
expect(result.length).toBeGreaterThan(0);
// Should have generated some analysis response
expect(result).not.toBe('');
console.log('Final result:', result);
}, 300000);
test('should generate response for analysis query', async () => {
const tracedAgentWorkflow = wrapTraced(
async (input: string) => {
// Step 1: Generate response with analyst agent
try {
const chatId = 'da05b6fb-01b2-4c1c-bc7f-7e55029a5c75';
const resourceId = 'c2dd64cd-f7f3-4884-bc91-d46ae431901e';
// Create runtime context with required properties
const runtimeContext = new RuntimeContext<AnalystRuntimeContext>([
['userId', resourceId],
['chatId', chatId],
['dataSourceId', 'cc3ef3bc-44ec-4a43-8dc4-681cae5c996a'],
['dataSourceSyntax', 'postgresql'],
['organizationId', 'bf58d19a-8bb9-4f1d-a257-2d2105e7f1ce'],
]);
const stream = await analystAgent.stream(input, {
maxSteps: 15,
runtimeContext,
});
let responseText = '';
for await (const chunk of stream.fullStream) {
if (chunk.type === 'text-delta') {
responseText += chunk.textDelta;
}
}
return responseText;
} catch (error) {
console.error(error);
throw error;
}
},
{ name: 'AnalystAgentWorkflow' }
);
// Execute the workflow
const response = await tracedAgentWorkflow('please continue with the analysis');
// Verify response structure
expect(response).toBeDefined();
expect(typeof response).toBe('string');
expect(response.length).toBeGreaterThan(0);
}, 300000);
}); });

View File

@ -1,68 +1,82 @@
import { type ModelMessage, hasToolCall, stepCountIs, streamText } from 'ai'; import { hasToolCall, type ModelMessage, stepCountIs, streamText } from "ai";
import { wrapTraced } from 'braintrust'; import { wrapTraced } from "braintrust";
import z from 'zod'; import z from "zod";
import { import {
createDashboards, createDashboards,
createMetrics, createMetrics,
doneTool, doneTool,
modifyDashboards, modifyDashboards,
modifyMetrics, modifyMetrics,
} from '../../tools'; } from "../../tools";
import { Sonnet4 } from '../../utils/models/sonnet-4'; import { Sonnet4 } from "../../utils/models/sonnet-4";
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt'; import { getAnalystAgentSystemPrompt } from "./get-analyst-agent-system-prompt";
const DEFAULT_CACHE_OPTIONS = { const DEFAULT_CACHE_OPTIONS = {
anthropic: { cacheControl: { type: 'ephemeral', ttl: '1h' } }, anthropic: { cacheControl: { type: "ephemeral", ttl: "1h" } },
}; };
const STOP_CONDITIONS = [stepCountIs(18), hasToolCall('doneTool')]; const STOP_CONDITIONS = [stepCountIs(18), hasToolCall("doneTool")];
const AnalystAgentSchema = z.object({ const AnalystAgentOptionsSchema = z.object({
sql_dialect_guidance: z.string().describe('The SQL dialect guidance for the analyst agent.'), sql_dialect_guidance: z
.string()
.describe("The SQL dialect guidance for the analyst agent."),
}); });
const AnalystStreamSchema = z.object({ const AnalystStreamOptionsSchema = z.object({
messages: z messages: z
.array(z.custom<ModelMessage>()) .array(z.custom<ModelMessage>())
.describe('The messages to send to the analyst agent.'), .describe("The messages to send to the analyst agent."),
}); });
export type AnalystAgentSchema = z.infer<typeof AnalystAgentSchema>; export type AnalystAgentOptionsSchema = z.infer<
export type AnalystStreamSchema = z.infer<typeof AnalystStreamSchema>; typeof AnalystAgentOptionsSchema
>;
export type AnalystStreamOptions = z.infer<typeof AnalystStreamOptionsSchema>;
export function createAnalystAgent(analystAgentSchema: AnalystAgentSchema) { export function createAnalystAgent(
const steps: never[] = []; analystAgentSchema: AnalystAgentOptionsSchema,
) {
const steps: never[] = [];
const systemMessage = { const systemMessage = {
role: 'system', role: "system",
content: getAnalystAgentSystemPrompt(analystAgentSchema.sql_dialect_guidance), content: getAnalystAgentSystemPrompt(
providerOptions: DEFAULT_CACHE_OPTIONS, analystAgentSchema.sql_dialect_guidance,
} as ModelMessage; ),
providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage;
async function stream({ messages }: AnalystStreamSchema) { async function stream({ messages }: AnalystStreamOptions) {
return wrapTraced( return wrapTraced(
async () => () =>
streamText({ streamText({
model: Sonnet4, model: Sonnet4,
tools: { createMetrics, modifyMetrics, createDashboards, modifyDashboards, doneTool }, tools: {
messages: [systemMessage, ...messages], createMetrics,
stopWhen: STOP_CONDITIONS, modifyMetrics,
toolChoice: 'required', createDashboards,
maxOutputTokens: 10000, modifyDashboards,
temperature: 0, doneTool,
}), },
{ messages: [systemMessage, ...messages],
name: 'Analyst Agent', stopWhen: STOP_CONDITIONS,
} toolChoice: "required",
); maxOutputTokens: 10000,
} temperature: 0,
}),
{
name: "Analyst Agent",
},
)();
}
async function getSteps() { async function getSteps() {
return steps; return steps;
} }
return { return {
stream, stream,
getSteps, getSteps,
}; };
} }

View File

@ -1,81 +1,81 @@
import * as fs from 'node:fs'; import * as fs from "node:fs";
import * as path from 'node:path'; import * as path from "node:path";
import { describe, expect, it } from 'vitest'; import { describe, expect, it } from "vitest";
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt'; import { getAnalystAgentSystemPrompt } from "./get-analyst-agent-system-prompt";
describe('Analyst Agent Instructions', () => { describe("Analyst Agent Instructions", () => {
it('should validate template file contains expected variables', () => { it("should validate template file contains expected variables", () => {
const promptPath = path.join(__dirname, 'analyst-agent-prompt.txt'); const promptPath = path.join(__dirname, "analyst-agent-prompt.txt");
const content = fs.readFileSync(promptPath, 'utf-8'); const content = fs.readFileSync(promptPath, "utf-8");
// Expected template variables // Expected template variables
const expectedVariables = ['sql_dialect_guidance', 'date']; const expectedVariables = ["sql_dialect_guidance", "date"];
// Find all template variables in the file // Find all template variables in the file
const templateVariablePattern = /\{\{([^}]+)\}\}/g; const templateVariablePattern = /\{\{([^}]+)\}\}/g;
const foundVariables = new Set<string>(); const foundVariables = new Set<string>();
const matches = Array.from(content.matchAll(templateVariablePattern)); const matches = Array.from(content.matchAll(templateVariablePattern));
for (const match of matches) { for (const match of matches) {
if (match[1] && match[1] !== 'variable') { if (match[1] && match[1] !== "variable") {
foundVariables.add(match[1]); foundVariables.add(match[1]);
} }
} }
// Convert to arrays for easier comparison // Convert to arrays for easier comparison
const foundVariablesArray = Array.from(foundVariables).sort(); const foundVariablesArray = Array.from(foundVariables).sort();
const expectedVariablesArray = expectedVariables.sort(); const expectedVariablesArray = expectedVariables.sort();
// Check that we have exactly the expected variables // Check that we have exactly the expected variables
expect(foundVariablesArray).toEqual(expectedVariablesArray); expect(foundVariablesArray).toEqual(expectedVariablesArray);
// Also verify each expected variable exists // Also verify each expected variable exists
for (const variable of expectedVariables) { for (const variable of expectedVariables) {
expect(content).toMatch(new RegExp(`\\{\\{${variable}\\}\\}`)); expect(content).toMatch(new RegExp(`\\{\\{${variable}\\}\\}`));
} }
// Ensure no unexpected variables exist // Ensure no unexpected variables exist
expect(foundVariables.size).toBe(expectedVariables.length); expect(foundVariables.size).toBe(expectedVariables.length);
}); });
it('should load and process the prompt template correctly', () => { it("should load and process the prompt template correctly", () => {
const sqlDialectGuidance = 'Test SQL guidance for PostgreSQL'; const sqlDialectGuidance = "Test SQL guidance for PostgreSQL";
const result = getAnalystInstructions(sqlDialectGuidance); const result = getAnalystAgentSystemPrompt(sqlDialectGuidance);
expect(result).toBeDefined(); expect(result).toBeDefined();
expect(typeof result).toBe('string'); expect(typeof result).toBe("string");
expect(result.length).toBeGreaterThan(0); expect(result.length).toBeGreaterThan(0);
// Should contain the SQL guidance we provided // Should contain the SQL guidance we provided
expect(result).toContain(sqlDialectGuidance); expect(result).toContain(sqlDialectGuidance);
// Should not contain any unreplaced template variables // Should not contain any unreplaced template variables
expect(result).not.toMatch(/\{\{sql_dialect_guidance\}\}/); expect(result).not.toMatch(/\{\{sql_dialect_guidance\}\}/);
expect(result).not.toMatch(/\{\{date\}\}/); expect(result).not.toMatch(/\{\{date\}\}/);
// Should contain the current date in YYYY-MM-DD format // Should contain the current date in YYYY-MM-DD format
const currentDate = new Date().toISOString().split('T')[0]; const currentDate = new Date().toISOString().split("T")[0];
expect(result).toContain(currentDate); expect(result).toContain(currentDate);
}); });
it('should contain expected sections from the prompt template', () => { it("should contain expected sections from the prompt template", () => {
const result = getAnalystInstructions('Test guidance'); const result = getAnalystAgentSystemPrompt("Test guidance");
// Check for key sections that should be in the prompt // Check for key sections that should be in the prompt
expect(result).toContain('<intro>'); expect(result).toContain("<intro>");
expect(result).toContain('<analysis_mode_capability>'); expect(result).toContain("<analysis_mode_capability>");
expect(result).toContain('<sql_best_practices>'); expect(result).toContain("<sql_best_practices>");
expect(result).toContain('<visualization_and_charting_guidelines>'); expect(result).toContain("<visualization_and_charting_guidelines>");
expect(result).toContain('You are a Buster'); expect(result).toContain("You are a Buster");
}); });
it('should throw an error for empty SQL dialect guidance', () => { it("should throw an error for empty SQL dialect guidance", () => {
expect(() => { expect(() => {
getAnalystInstructions(''); getAnalystAgentSystemPrompt("");
}).toThrow('SQL dialect guidance is required'); }).toThrow("SQL dialect guidance is required");
expect(() => { expect(() => {
getAnalystInstructions(' '); // whitespace only getAnalystAgentSystemPrompt(" "); // whitespace only
}).toThrow('SQL dialect guidance is required'); }).toThrow("SQL dialect guidance is required");
}); });
}); });