From 9214075cad3e44fe9bc1a7d331c898b66034023b Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 5 Aug 2025 00:54:23 -0600 Subject: [PATCH] analyst agent is clean --- .../analyst-agent/analyst-agent.int.test.ts | 156 +++--------------- .../src/agents/analyst-agent/analyst-agent.ts | 118 +++++++------ .../get-analyst-agent-system-prompt.test.ts | 126 +++++++------- 3 files changed, 155 insertions(+), 245 deletions(-) diff --git a/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts b/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts index f30c85719..508408ea6 100644 --- a/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts +++ b/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts @@ -1,136 +1,32 @@ -import { RuntimeContext } from '@mastra/core/runtime-context'; -import type { CoreMessage } from 'ai'; -import { initLogger, wrapTraced } from 'braintrust'; -import { afterAll, beforeAll, describe, expect, test } from 'vitest'; -import type { AnalystRuntimeContext } from '../../workflows/analyst-workflow'; -import { analystAgent } from './analyst-agent'; +import type { ModelMessage } from "ai"; +import { describe, expect, test } from "vitest"; +import { createAnalystAgent } from "./analyst-agent"; -describe('Analyst Agent Integration Tests', () => { - beforeAll(async () => { - initLogger({ - apiKey: process.env.BRAINTRUST_KEY, - projectName: 'ANALYST-AGENT', - }); - }); +describe("Analyst Agent Integration Tests", () => { + test("should generate response for data analysis query with conversation history", async () => { + const messages: ModelMessage[] = []; - afterAll(async () => { - // Cleanup if needed - // Wait 500ms before finishing - await new Promise((resolve) => setTimeout(resolve, 500)); - }); + try { + const analystAgent = createAnalystAgent({ + sql_dialect_guidance: "postgresql", + }); - test('should generate response for data analysis query with conversation history', async () => { - // Stubbed conversation history - to be filled in later - const conversationHistory: CoreMessage[] = [ - // TODO: Add stubbed conversation history here - ]; + const streamResult = await analystAgent.stream({ + messages, + }); - const tracedAgentWorkflow = wrapTraced( - async (messages: CoreMessage[]) => { - // Step 1: Generate response with analyst agent using conversation history - try { - const chatId = 'da05b6fb-01b2-4c1c-bc7f-7e55029a5c75'; - const resourceId = 'c2dd64cd-f7f3-4884-bc91-d46ae431901e'; + let response = ""; + for await (const chunk of streamResult.fullStream) { + if (chunk.type === "text-delta") { + response += chunk.text; + } + } - // Create runtime context with required properties - const runtimeContext = new RuntimeContext([ - ['userId', resourceId], - ['chatId', chatId], - ['dataSourceId', 'cc3ef3bc-44ec-4a43-8dc4-681cae5c996a'], - ['dataSourceSyntax', 'postgres'], - ['organizationId', 'bf58d19a-8bb9-4f1d-a257-2d2105e7f1ce'], - // Note: No messageId provided for testing scenario - ]); - - // Use stream with conversation history instead of single prompt - const stream = await analystAgent.stream(messages, { - maxSteps: 15, - runtimeContext, - onStepFinish: async (step) => { - console.log('\n=== onStepFinish callback (with history) ==='); - console.log('Step structure:', JSON.stringify(step, null, 2)); - console.log('Tool calls:', step.toolCalls); - console.log('Response messages:', step.response.messages); - // Response text is not directly available on step.response - console.log('===========================\n'); - }, - }); - - let response = ''; - for await (const chunk of stream.fullStream) { - if (chunk.type === 'text-delta') { - response += chunk.textDelta; - } - } - - return response; - } catch (error) { - console.error('Error during agent execution:', error); - throw error; - } - }, - { name: 'analyst-agent-with-history' } - ); - - // Test with conversation history (stubbed for now) - const result = await tracedAgentWorkflow( - conversationHistory.length > 0 - ? (conversationHistory as CoreMessage[]) - : [{ role: 'user', content: 'What are the top 5 customers by revenue?' }] - ); - - expect(result).toBeDefined(); - expect(typeof result).toBe('string'); - expect(result.length).toBeGreaterThan(0); - // Should have generated some analysis response - expect(result).not.toBe(''); - console.log('Final result:', result); - }, 300000); - - test('should generate response for analysis query', async () => { - const tracedAgentWorkflow = wrapTraced( - async (input: string) => { - // Step 1: Generate response with analyst agent - try { - const chatId = 'da05b6fb-01b2-4c1c-bc7f-7e55029a5c75'; - const resourceId = 'c2dd64cd-f7f3-4884-bc91-d46ae431901e'; - - // Create runtime context with required properties - const runtimeContext = new RuntimeContext([ - ['userId', resourceId], - ['chatId', chatId], - ['dataSourceId', 'cc3ef3bc-44ec-4a43-8dc4-681cae5c996a'], - ['dataSourceSyntax', 'postgresql'], - ['organizationId', 'bf58d19a-8bb9-4f1d-a257-2d2105e7f1ce'], - ]); - - const stream = await analystAgent.stream(input, { - maxSteps: 15, - runtimeContext, - }); - - let responseText = ''; - for await (const chunk of stream.fullStream) { - if (chunk.type === 'text-delta') { - responseText += chunk.textDelta; - } - } - - return responseText; - } catch (error) { - console.error(error); - throw error; - } - }, - { name: 'AnalystAgentWorkflow' } - ); - - // Execute the workflow - const response = await tracedAgentWorkflow('please continue with the analysis'); - - // Verify response structure - expect(response).toBeDefined(); - expect(typeof response).toBe('string'); - expect(response.length).toBeGreaterThan(0); - }, 300000); + expect(response).toBeDefined(); + expect(typeof response).toBe("string"); + } catch (error) { + console.error("Error during agent execution:", error); + throw error; + } + }, 300000); }); diff --git a/packages/ai/src/agents/analyst-agent/analyst-agent.ts b/packages/ai/src/agents/analyst-agent/analyst-agent.ts index 53aa31be4..e48641d08 100644 --- a/packages/ai/src/agents/analyst-agent/analyst-agent.ts +++ b/packages/ai/src/agents/analyst-agent/analyst-agent.ts @@ -1,68 +1,82 @@ -import { type ModelMessage, hasToolCall, stepCountIs, streamText } from 'ai'; -import { wrapTraced } from 'braintrust'; -import z from 'zod'; +import { hasToolCall, type ModelMessage, stepCountIs, streamText } from "ai"; +import { wrapTraced } from "braintrust"; +import z from "zod"; import { - createDashboards, - createMetrics, - doneTool, - modifyDashboards, - modifyMetrics, -} from '../../tools'; -import { Sonnet4 } from '../../utils/models/sonnet-4'; -import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt'; + createDashboards, + createMetrics, + doneTool, + modifyDashboards, + modifyMetrics, +} from "../../tools"; +import { Sonnet4 } from "../../utils/models/sonnet-4"; +import { getAnalystAgentSystemPrompt } from "./get-analyst-agent-system-prompt"; const DEFAULT_CACHE_OPTIONS = { - anthropic: { cacheControl: { type: 'ephemeral', ttl: '1h' } }, + anthropic: { cacheControl: { type: "ephemeral", ttl: "1h" } }, }; -const STOP_CONDITIONS = [stepCountIs(18), hasToolCall('doneTool')]; +const STOP_CONDITIONS = [stepCountIs(18), hasToolCall("doneTool")]; -const AnalystAgentSchema = z.object({ - sql_dialect_guidance: z.string().describe('The SQL dialect guidance for the analyst agent.'), +const AnalystAgentOptionsSchema = z.object({ + sql_dialect_guidance: z + .string() + .describe("The SQL dialect guidance for the analyst agent."), }); -const AnalystStreamSchema = z.object({ - messages: z - .array(z.custom()) - .describe('The messages to send to the analyst agent.'), +const AnalystStreamOptionsSchema = z.object({ + messages: z + .array(z.custom()) + .describe("The messages to send to the analyst agent."), }); -export type AnalystAgentSchema = z.infer; -export type AnalystStreamSchema = z.infer; +export type AnalystAgentOptionsSchema = z.infer< + typeof AnalystAgentOptionsSchema +>; +export type AnalystStreamOptions = z.infer; -export function createAnalystAgent(analystAgentSchema: AnalystAgentSchema) { - const steps: never[] = []; +export function createAnalystAgent( + analystAgentSchema: AnalystAgentOptionsSchema, +) { + const steps: never[] = []; - const systemMessage = { - role: 'system', - content: getAnalystAgentSystemPrompt(analystAgentSchema.sql_dialect_guidance), - providerOptions: DEFAULT_CACHE_OPTIONS, - } as ModelMessage; + const systemMessage = { + role: "system", + content: getAnalystAgentSystemPrompt( + analystAgentSchema.sql_dialect_guidance, + ), + providerOptions: DEFAULT_CACHE_OPTIONS, + } as ModelMessage; - async function stream({ messages }: AnalystStreamSchema) { - return wrapTraced( - async () => - streamText({ - model: Sonnet4, - tools: { createMetrics, modifyMetrics, createDashboards, modifyDashboards, doneTool }, - messages: [systemMessage, ...messages], - stopWhen: STOP_CONDITIONS, - toolChoice: 'required', - maxOutputTokens: 10000, - temperature: 0, - }), - { - name: 'Analyst Agent', - } - ); - } + async function stream({ messages }: AnalystStreamOptions) { + return wrapTraced( + () => + streamText({ + model: Sonnet4, + tools: { + createMetrics, + modifyMetrics, + createDashboards, + modifyDashboards, + doneTool, + }, + messages: [systemMessage, ...messages], + stopWhen: STOP_CONDITIONS, + toolChoice: "required", + maxOutputTokens: 10000, + temperature: 0, + }), + { + name: "Analyst Agent", + }, + )(); + } - async function getSteps() { - return steps; - } + async function getSteps() { + return steps; + } - return { - stream, - getSteps, - }; + return { + stream, + getSteps, + }; } diff --git a/packages/ai/src/agents/analyst-agent/get-analyst-agent-system-prompt.test.ts b/packages/ai/src/agents/analyst-agent/get-analyst-agent-system-prompt.test.ts index 257ba2510..dc849289a 100644 --- a/packages/ai/src/agents/analyst-agent/get-analyst-agent-system-prompt.test.ts +++ b/packages/ai/src/agents/analyst-agent/get-analyst-agent-system-prompt.test.ts @@ -1,81 +1,81 @@ -import * as fs from 'node:fs'; -import * as path from 'node:path'; -import { describe, expect, it } from 'vitest'; -import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt'; +import * as fs from "node:fs"; +import * as path from "node:path"; +import { describe, expect, it } from "vitest"; +import { getAnalystAgentSystemPrompt } from "./get-analyst-agent-system-prompt"; -describe('Analyst Agent Instructions', () => { - it('should validate template file contains expected variables', () => { - const promptPath = path.join(__dirname, 'analyst-agent-prompt.txt'); - const content = fs.readFileSync(promptPath, 'utf-8'); +describe("Analyst Agent Instructions", () => { + it("should validate template file contains expected variables", () => { + const promptPath = path.join(__dirname, "analyst-agent-prompt.txt"); + const content = fs.readFileSync(promptPath, "utf-8"); - // Expected template variables - const expectedVariables = ['sql_dialect_guidance', 'date']; + // Expected template variables + const expectedVariables = ["sql_dialect_guidance", "date"]; - // Find all template variables in the file - const templateVariablePattern = /\{\{([^}]+)\}\}/g; - const foundVariables = new Set(); + // Find all template variables in the file + const templateVariablePattern = /\{\{([^}]+)\}\}/g; + const foundVariables = new Set(); - const matches = Array.from(content.matchAll(templateVariablePattern)); - for (const match of matches) { - if (match[1] && match[1] !== 'variable') { - foundVariables.add(match[1]); - } - } + const matches = Array.from(content.matchAll(templateVariablePattern)); + for (const match of matches) { + if (match[1] && match[1] !== "variable") { + foundVariables.add(match[1]); + } + } - // Convert to arrays for easier comparison - const foundVariablesArray = Array.from(foundVariables).sort(); - const expectedVariablesArray = expectedVariables.sort(); + // Convert to arrays for easier comparison + const foundVariablesArray = Array.from(foundVariables).sort(); + const expectedVariablesArray = expectedVariables.sort(); - // Check that we have exactly the expected variables - expect(foundVariablesArray).toEqual(expectedVariablesArray); + // Check that we have exactly the expected variables + expect(foundVariablesArray).toEqual(expectedVariablesArray); - // Also verify each expected variable exists - for (const variable of expectedVariables) { - expect(content).toMatch(new RegExp(`\\{\\{${variable}\\}\\}`)); - } + // Also verify each expected variable exists + for (const variable of expectedVariables) { + expect(content).toMatch(new RegExp(`\\{\\{${variable}\\}\\}`)); + } - // Ensure no unexpected variables exist - expect(foundVariables.size).toBe(expectedVariables.length); - }); + // Ensure no unexpected variables exist + expect(foundVariables.size).toBe(expectedVariables.length); + }); - it('should load and process the prompt template correctly', () => { - const sqlDialectGuidance = 'Test SQL guidance for PostgreSQL'; - const result = getAnalystInstructions(sqlDialectGuidance); + it("should load and process the prompt template correctly", () => { + const sqlDialectGuidance = "Test SQL guidance for PostgreSQL"; + const result = getAnalystAgentSystemPrompt(sqlDialectGuidance); - expect(result).toBeDefined(); - expect(typeof result).toBe('string'); - expect(result.length).toBeGreaterThan(0); + expect(result).toBeDefined(); + expect(typeof result).toBe("string"); + expect(result.length).toBeGreaterThan(0); - // Should contain the SQL guidance we provided - expect(result).toContain(sqlDialectGuidance); + // Should contain the SQL guidance we provided + expect(result).toContain(sqlDialectGuidance); - // Should not contain any unreplaced template variables - expect(result).not.toMatch(/\{\{sql_dialect_guidance\}\}/); - expect(result).not.toMatch(/\{\{date\}\}/); + // Should not contain any unreplaced template variables + expect(result).not.toMatch(/\{\{sql_dialect_guidance\}\}/); + expect(result).not.toMatch(/\{\{date\}\}/); - // Should contain the current date in YYYY-MM-DD format - const currentDate = new Date().toISOString().split('T')[0]; - expect(result).toContain(currentDate); - }); + // Should contain the current date in YYYY-MM-DD format + const currentDate = new Date().toISOString().split("T")[0]; + expect(result).toContain(currentDate); + }); - it('should contain expected sections from the prompt template', () => { - const result = getAnalystInstructions('Test guidance'); + it("should contain expected sections from the prompt template", () => { + const result = getAnalystAgentSystemPrompt("Test guidance"); - // Check for key sections that should be in the prompt - expect(result).toContain(''); - expect(result).toContain(''); - expect(result).toContain(''); - expect(result).toContain(''); - expect(result).toContain('You are a Buster'); - }); + // Check for key sections that should be in the prompt + expect(result).toContain(""); + expect(result).toContain(""); + expect(result).toContain(""); + expect(result).toContain(""); + expect(result).toContain("You are a Buster"); + }); - it('should throw an error for empty SQL dialect guidance', () => { - expect(() => { - getAnalystInstructions(''); - }).toThrow('SQL dialect guidance is required'); + it("should throw an error for empty SQL dialect guidance", () => { + expect(() => { + getAnalystAgentSystemPrompt(""); + }).toThrow("SQL dialect guidance is required"); - expect(() => { - getAnalystInstructions(' '); // whitespace only - }).toThrow('SQL dialect guidance is required'); - }); + expect(() => { + getAnalystAgentSystemPrompt(" "); // whitespace only + }).toThrow("SQL dialect guidance is required"); + }); });