analysis type router

This commit is contained in:
dal 2025-08-12 09:36:40 -06:00
parent 77bc071f4b
commit 9d9690bb35
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
4 changed files with 156 additions and 82 deletions

View File

@ -0,0 +1,80 @@
import { describe, expect, it, vi } from 'vitest';
import type { ModelMessage } from 'ai';
import { runAnalysisTypeRouterStep } from './analysis-type-router-step';
// Mock the GPT5Mini model so no real LLM client is constructed during tests.
vi.mock('../../../llm/gpt-5-mini', () => ({
  GPT5Mini: 'mock-gpt5-mini',
}));

// Mock generateObject while keeping the rest of the 'ai' module real.
// vi.fn() with no implementation resolves to undefined, so the step's
// internal `const { object } = await generateObject(...)` destructuring
// throws, which is what drives it into its 'standard' fallback path below.
vi.mock('ai', async () => {
  const actual = await vi.importActual('ai');
  return {
    ...actual,
    generateObject: vi.fn(),
  };
});

// Mock wrapTraced as a typed identity function (no Braintrust tracing in
// tests). A generic signature replaces the banned `Function` type, which
// erases parameter/return typing and is flagged by @typescript-eslint.
vi.mock('braintrust', () => ({
  wrapTraced: <T extends (...args: unknown[]) => unknown>(fn: T): T => fn,
}));
describe('runAnalysisTypeRouterStep', () => {
  it('should route to standard mode for simple queries', async () => {
    const history: ModelMessage[] = [
      { role: 'user', content: 'Show me total sales for last month' },
    ];

    const outcome = await runAnalysisTypeRouterStep({ messages: history });

    // The mocked generateObject never yields an object, so the step is
    // expected to fall back to its default routing decision.
    expect(outcome.analysisType).toBe('standard');
    expect(outcome.reasoning).toBeDefined();
  });

  it('should handle empty messages array', async () => {
    const outcome = await runAnalysisTypeRouterStep({ messages: [] });

    expect(outcome.analysisType).toBe('standard');
    expect(outcome.reasoning).toContain('Defaulting to standard');
  });

  it('should handle multiple messages in conversation history', async () => {
    const history: ModelMessage[] = [
      { role: 'user', content: 'What is causing the revenue drop?' },
      { role: 'assistant', content: 'Let me analyze the revenue trends...' },
      { role: 'user', content: 'Can you dig deeper into the Q3 anomaly?' },
    ];

    const outcome = await runAnalysisTypeRouterStep({ messages: history });

    // With a multi-turn history we only assert that a decision and a
    // rationale come back; the specific route depends on the (mocked) LLM.
    expect(outcome.analysisType).toBeDefined();
    expect(outcome.reasoning).toBeDefined();
  });

  it('should handle errors gracefully', async () => {
    // Deliberately pass null (cast through unknown) to exercise the
    // step's catch-and-default error path.
    const broken = null as unknown as ModelMessage[];

    const outcome = await runAnalysisTypeRouterStep({ messages: broken });

    expect(outcome.analysisType).toBe('standard');
    expect(outcome.reasoning).toContain('Defaulting to standard');
  });
});

View File

@ -1,82 +1,65 @@
import { createStep } from '@mastra/core';
import type { RuntimeContext } from '@mastra/core/runtime-context';
import type { CoreMessage } from 'ai';
import { generateObject } from 'ai'; import { generateObject } from 'ai';
import type { ModelMessage } from 'ai';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import { z } from 'zod'; import { z } from 'zod';
import { GPT5Mini } from '../../../llm/gpt-5-mini'; import { GPT5Mini } from '../../../llm/gpt-5-mini';
import { thinkAndPrepWorkflowInputSchema } from '../../../schemas/workflow-schemas';
import { appendToConversation, standardizeMessages } from '../../../utils/standardizeMessages';
import type { AnalystRuntimeContext } from '../../../workflows/analyst-workflow';
import { formatAnalysisTypeRouterPrompt } from './format-analysis-type-router-prompt'; import { formatAnalysisTypeRouterPrompt } from './format-analysis-type-router-prompt';
const inputSchema = thinkAndPrepWorkflowInputSchema; // Zod schemas first - following Zod-first approach
export const analysisTypeRouterParamsSchema = z.object({
messages: z.array(z.custom<ModelMessage>()).describe('The conversation history'),
});
// Define the analysis type choices export const analysisTypeRouterResultSchema = z.object({
const AnalysisTypeEnum = z.enum(['standard', 'investigation']); analysisType: z.enum(['standard', 'investigation']).describe('The chosen analysis type'),
// Define the structure for the AI response
const analysisTypeSchema = z.object({
choice: AnalysisTypeEnum.describe('The type of analysis to perform'),
reasoning: z.string().describe('Explanation for why this analysis type was chosen'), reasoning: z.string().describe('Explanation for why this analysis type was chosen'),
}); });
const outputSchema = z.object({ // Export types from schemas
analysisType: analysisTypeSchema, export type AnalysisTypeRouterParams = z.infer<typeof analysisTypeRouterParamsSchema>;
conversationHistory: z.array(z.any()), export type AnalysisTypeRouterResult = z.infer<typeof analysisTypeRouterResultSchema>;
// Pass through dashboard context
dashboardFiles: z // Schema for what the LLM returns
.array( const llmOutputSchema = z.object({
z.object({ choice: z.enum(['standard', 'investigation']).describe('The type of analysis to perform'),
id: z.string(), reasoning: z.string().describe('Explanation for why this analysis type was chosen'),
name: z.string(),
versionNumber: z.number(),
metricIds: z.array(z.string()),
})
)
.optional(),
}); });
const execution = async ({ /**
inputData, * Generates the analysis type decision using the LLM
}: { */
inputData: z.infer<typeof inputSchema>; async function generateAnalysisTypeWithLLM(messages: ModelMessage[]): Promise<{
runtimeContext: RuntimeContext<AnalystRuntimeContext>; choice: 'standard' | 'investigation';
}): Promise<z.infer<typeof outputSchema>> => { reasoning: string;
}> {
try { try {
// Use the input data directly // Get the last user message for context
const prompt = inputData.prompt; const lastUserMessage = messages
const conversationHistory = inputData.conversationHistory; .slice()
.reverse()
.find((m) => m.role === 'user');
const userPrompt = lastUserMessage?.content?.toString() || '';
// Prepare messages for the agent // Format the system prompt
let messages: CoreMessage[];
if (conversationHistory && conversationHistory.length > 0) {
// Use conversation history as context + append new user message
messages = appendToConversation(conversationHistory as CoreMessage[], prompt);
} else {
// Otherwise, use just the prompt
messages = standardizeMessages(prompt);
}
// Format the prompt using the helper function
const systemPrompt = formatAnalysisTypeRouterPrompt({ const systemPrompt = formatAnalysisTypeRouterPrompt({
userPrompt: prompt, userPrompt,
...(conversationHistory && { conversationHistory: conversationHistory as CoreMessage[] }), conversationHistory: messages.length > 1 ? messages : [],
}); });
// Generate the analysis type decision // Prepare messages for the LLM
const systemMessage: ModelMessage = {
role: 'system',
content: systemPrompt,
};
const llmMessages = [systemMessage, ...messages];
const tracedAnalysisType = wrapTraced( const tracedAnalysisType = wrapTraced(
async () => { async () => {
const { object } = await generateObject({ const { object } = await generateObject({
model: GPT5Mini, model: GPT5Mini,
schema: analysisTypeSchema, schema: llmOutputSchema,
messages: [ messages: llmMessages,
{
role: 'system',
content: systemPrompt,
},
...messages,
],
temperature: 1, temperature: 1,
providerOptions: { providerOptions: {
openai: { openai: {
@ -97,36 +80,46 @@ const execution = async ({
} }
); );
const analysisType = await tracedAnalysisType(); const result = await tracedAnalysisType();
return {
choice: result.choice,
reasoning: result.reasoning,
};
} catch (llmError) {
console.warn('[AnalysisTypeRouter] LLM failed to generate valid response:', {
error: llmError instanceof Error ? llmError.message : 'Unknown error',
errorType: llmError instanceof Error ? llmError.name : 'Unknown',
});
// Default to standard analysis on error
return {
choice: 'standard',
reasoning: 'Defaulting to standard analysis due to routing error',
};
}
}
export async function runAnalysisTypeRouterStep(
params: AnalysisTypeRouterParams
): Promise<AnalysisTypeRouterResult> {
try {
const result = await generateAnalysisTypeWithLLM(params.messages);
console.info('[Analysis Type Router] Decision:', { console.info('[Analysis Type Router] Decision:', {
choice: analysisType.choice, choice: result.choice,
reasoning: analysisType.reasoning, reasoning: result.reasoning,
}); });
return { return {
analysisType, analysisType: result.choice,
conversationHistory: messages, reasoning: result.reasoning,
dashboardFiles: inputData.dashboardFiles, // Pass through dashboard context
}; };
} catch (error) { } catch (error) {
console.error('[Analysis Type Router] Error:', error); console.error('[analysis-type-router-step] Unexpected error:', error);
// Default to standard analysis on error // Default to standard analysis on error
return { return {
analysisType: { analysisType: 'standard',
choice: 'standard', reasoning: 'Defaulting to standard analysis due to routing error',
reasoning: 'Defaulting to standard analysis due to routing error',
},
conversationHistory: inputData.conversationHistory || [],
dashboardFiles: inputData.dashboardFiles, // Pass through dashboard context
}; };
} }
}; }
export const analysisTypeRouterStep = createStep({
id: 'analysis-type-router',
description: 'Determines whether to use standard or investigation analysis based on the query',
inputSchema,
outputSchema,
execute: execution,
});

View File

@ -1,9 +1,9 @@
import type { CoreMessage } from 'ai'; import type { ModelMessage } from 'ai';
// Define the required template parameters // Define the required template parameters
interface AnalysisTypeRouterTemplateParams { interface AnalysisTypeRouterTemplateParams {
userPrompt: string; userPrompt: string;
conversationHistory?: CoreMessage[]; conversationHistory?: ModelMessage[];
} }
/** /**

View File

@ -5,6 +5,7 @@ export * from './analyst-agent-steps/mark-message-complete-step/mark-message-com
export * from './analyst-agent-steps/generate-chat-title-step/generate-chat-title-step'; export * from './analyst-agent-steps/generate-chat-title-step/generate-chat-title-step';
export * from './analyst-agent-steps/extract-values-step/extract-values-search-step'; export * from './analyst-agent-steps/extract-values-step/extract-values-search-step';
export * from './analyst-agent-steps/create-todos-step/create-todos-step'; export * from './analyst-agent-steps/create-todos-step/create-todos-step';
export * from './analyst-agent-steps/analysis-type-router-step/analysis-type-router-step';
// Docs agent steps // Docs agent steps
export * from './docs-agent-steps/create-docs-todos-step'; export * from './docs-agent-steps/create-docs-todos-step';