analysis type router

dal 2025-08-12 09:36:40 -06:00
parent 77bc071f4b
commit 9d9690bb35
4 changed files with 156 additions and 82 deletions

analyst-agent-steps/analysis-type-router-step/analysis-type-router-step.test.ts

@@ -0,0 +1,80 @@
import { describe, expect, it, vi } from 'vitest';
import type { ModelMessage } from 'ai';
import { runAnalysisTypeRouterStep } from './analysis-type-router-step';

// Mock the GPT5Mini model
vi.mock('../../../llm/gpt-5-mini', () => ({
  GPT5Mini: 'mock-gpt5-mini',
}));

// Mock generateObject
vi.mock('ai', async () => {
  const actual = await vi.importActual('ai');
  return {
    ...actual,
    generateObject: vi.fn(),
  };
});

// Mock wrapTraced
vi.mock('braintrust', () => ({
  wrapTraced: (fn: Function) => fn,
}));

describe('runAnalysisTypeRouterStep', () => {
  it('should route to standard mode for simple queries', async () => {
    const messages: ModelMessage[] = [
      {
        role: 'user',
        content: 'Show me total sales for last month',
      },
    ];

    const result = await runAnalysisTypeRouterStep({ messages });

    // The mocked generateObject resolves to undefined, so the step falls back to standard
    expect(result.analysisType).toBe('standard');
    expect(result.reasoning).toBeDefined();
  });

  it('should handle empty messages array', async () => {
    const messages: ModelMessage[] = [];

    const result = await runAnalysisTypeRouterStep({ messages });

    expect(result.analysisType).toBe('standard');
    expect(result.reasoning).toContain('Defaulting to standard');
  });

  it('should handle multiple messages in conversation history', async () => {
    const messages: ModelMessage[] = [
      {
        role: 'user',
        content: 'What is causing the revenue drop?',
      },
      {
        role: 'assistant',
        content: 'Let me analyze the revenue trends...',
      },
      {
        role: 'user',
        content: 'Can you dig deeper into the Q3 anomaly?',
      },
    ];

    const result = await runAnalysisTypeRouterStep({ messages });

    expect(result.analysisType).toBeDefined();
    expect(result.reasoning).toBeDefined();
  });

  it('should handle errors gracefully', async () => {
    // Force an error by passing invalid data
    const messages = null as unknown as ModelMessage[];

    const result = await runAnalysisTypeRouterStep({ messages });

    expect(result.analysisType).toBe('standard');
    expect(result.reasoning).toContain('Defaulting to standard');
  });
});
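
Note on coverage: generateObject is stubbed with a bare vi.fn(), so it resolves to undefined, the { object } destructure throws, and every test above exercises the 'standard' fallback path. A sketch of a test that would cover the 'investigation' branch (not part of this commit; it assumes import { generateObject } from 'ai'; at the top of the file):

it('should route to investigation when the mock returns that choice', async () => {
  // The value shape follows the step's llmOutputSchema; the cast is needed
  // because generateObject's real return type carries extra fields.
  vi.mocked(generateObject).mockResolvedValueOnce({
    object: {
      choice: 'investigation',
      reasoning: 'Root-cause question needs a deeper dive',
    },
  } as never);

  const result = await runAnalysisTypeRouterStep({
    messages: [{ role: 'user', content: 'Why did churn spike in Q3?' }],
  });

  expect(result.analysisType).toBe('investigation');
});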

analyst-agent-steps/analysis-type-router-step/analysis-type-router-step.ts

@@ -1,82 +1,65 @@
-import { createStep } from '@mastra/core';
-import type { RuntimeContext } from '@mastra/core/runtime-context';
-import type { CoreMessage } from 'ai';
 import { generateObject } from 'ai';
+import type { ModelMessage } from 'ai';
 import { wrapTraced } from 'braintrust';
 import { z } from 'zod';
 import { GPT5Mini } from '../../../llm/gpt-5-mini';
-import { thinkAndPrepWorkflowInputSchema } from '../../../schemas/workflow-schemas';
-import { appendToConversation, standardizeMessages } from '../../../utils/standardizeMessages';
-import type { AnalystRuntimeContext } from '../../../workflows/analyst-workflow';
 import { formatAnalysisTypeRouterPrompt } from './format-analysis-type-router-prompt';
-const inputSchema = thinkAndPrepWorkflowInputSchema;
+// Zod schemas first - following Zod-first approach
+export const analysisTypeRouterParamsSchema = z.object({
+  messages: z.array(z.custom<ModelMessage>()).describe('The conversation history'),
+});
-// Define the analysis type choices
-const AnalysisTypeEnum = z.enum(['standard', 'investigation']);
-// Define the structure for the AI response
-const analysisTypeSchema = z.object({
-  choice: AnalysisTypeEnum.describe('The type of analysis to perform'),
+export const analysisTypeRouterResultSchema = z.object({
+  analysisType: z.enum(['standard', 'investigation']).describe('The chosen analysis type'),
   reasoning: z.string().describe('Explanation for why this analysis type was chosen'),
 });
-const outputSchema = z.object({
-  analysisType: analysisTypeSchema,
-  conversationHistory: z.array(z.any()),
-  // Pass through dashboard context
-  dashboardFiles: z
-    .array(
-      z.object({
-        id: z.string(),
-        name: z.string(),
-        versionNumber: z.number(),
-        metricIds: z.array(z.string()),
-      })
-    )
-    .optional(),
-});
+// Export types from schemas
+export type AnalysisTypeRouterParams = z.infer<typeof analysisTypeRouterParamsSchema>;
+export type AnalysisTypeRouterResult = z.infer<typeof analysisTypeRouterResultSchema>;
+// Schema for what the LLM returns
+const llmOutputSchema = z.object({
+  choice: z.enum(['standard', 'investigation']).describe('The type of analysis to perform'),
+  reasoning: z.string().describe('Explanation for why this analysis type was chosen'),
+});
-const execution = async ({
-  inputData,
-}: {
-  inputData: z.infer<typeof inputSchema>;
-  runtimeContext: RuntimeContext<AnalystRuntimeContext>;
-}): Promise<z.infer<typeof outputSchema>> => {
+/**
+ * Generates the analysis type decision using the LLM
+ */
+async function generateAnalysisTypeWithLLM(messages: ModelMessage[]): Promise<{
+  choice: 'standard' | 'investigation';
+  reasoning: string;
+}> {
   try {
-    // Use the input data directly
-    const prompt = inputData.prompt;
-    const conversationHistory = inputData.conversationHistory;
+    // Get the last user message for context
+    const lastUserMessage = messages
+      .slice()
+      .reverse()
+      .find((m) => m.role === 'user');
+    const userPrompt = lastUserMessage?.content?.toString() || '';
-    // Prepare messages for the agent
-    let messages: CoreMessage[];
-    if (conversationHistory && conversationHistory.length > 0) {
-      // Use conversation history as context + append new user message
-      messages = appendToConversation(conversationHistory as CoreMessage[], prompt);
-    } else {
-      // Otherwise, use just the prompt
-      messages = standardizeMessages(prompt);
-    }
-    // Format the prompt using the helper function
+    // Format the system prompt
     const systemPrompt = formatAnalysisTypeRouterPrompt({
-      userPrompt: prompt,
-      ...(conversationHistory && { conversationHistory: conversationHistory as CoreMessage[] }),
+      userPrompt,
+      conversationHistory: messages.length > 1 ? messages : [],
     });
-    // Generate the analysis type decision
+    // Prepare messages for the LLM
+    const systemMessage: ModelMessage = {
+      role: 'system',
+      content: systemPrompt,
+    };
+    const llmMessages = [systemMessage, ...messages];
     const tracedAnalysisType = wrapTraced(
       async () => {
         const { object } = await generateObject({
           model: GPT5Mini,
-          schema: analysisTypeSchema,
-          messages: [
-            {
-              role: 'system',
-              content: systemPrompt,
-            },
-            ...messages,
-          ],
+          schema: llmOutputSchema,
+          messages: llmMessages,
           temperature: 1,
           providerOptions: {
             openai: {
@@ -97,36 +80,46 @@ const execution = async ({
       }
     );
-    const analysisType = await tracedAnalysisType();
+    const result = await tracedAnalysisType();
+    return {
+      choice: result.choice,
+      reasoning: result.reasoning,
+    };
+  } catch (llmError) {
+    console.warn('[AnalysisTypeRouter] LLM failed to generate valid response:', {
+      error: llmError instanceof Error ? llmError.message : 'Unknown error',
+      errorType: llmError instanceof Error ? llmError.name : 'Unknown',
+    });
+    // Default to standard analysis on error
+    return {
+      choice: 'standard',
+      reasoning: 'Defaulting to standard analysis due to routing error',
+    };
+  }
+}
+export async function runAnalysisTypeRouterStep(
+  params: AnalysisTypeRouterParams
+): Promise<AnalysisTypeRouterResult> {
+  try {
+    const result = await generateAnalysisTypeWithLLM(params.messages);
     console.info('[Analysis Type Router] Decision:', {
-      choice: analysisType.choice,
-      reasoning: analysisType.reasoning,
+      choice: result.choice,
+      reasoning: result.reasoning,
     });
     return {
-      analysisType,
-      conversationHistory: messages,
-      dashboardFiles: inputData.dashboardFiles, // Pass through dashboard context
+      analysisType: result.choice,
+      reasoning: result.reasoning,
     };
   } catch (error) {
-    console.error('[Analysis Type Router] Error:', error);
+    console.error('[analysis-type-router-step] Unexpected error:', error);
     // Default to standard analysis on error
     return {
-      analysisType: {
-        choice: 'standard',
-        reasoning: 'Defaulting to standard analysis due to routing error',
-      },
-      conversationHistory: inputData.conversationHistory || [],
-      dashboardFiles: inputData.dashboardFiles, // Pass through dashboard context
+      analysisType: 'standard',
+      reasoning: 'Defaulting to standard analysis due to routing error',
     };
   }
-};
-export const analysisTypeRouterStep = createStep({
-  id: 'analysis-type-router',
-  description: 'Determines whether to use standard or investigation analysis based on the query',
-  inputSchema,
-  outputSchema,
-  execute: execution,
-});
+}
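
With this change the router is a plain exported async function rather than a Mastra createStep, so a caller wires it into the workflow directly. A minimal usage sketch, assuming only what the new module exports:

import type { ModelMessage } from 'ai';
import { runAnalysisTypeRouterStep } from './analysis-type-router-step';

const messages: ModelMessage[] = [
  { role: 'user', content: 'What is driving the drop in weekly active users?' },
];

// Resolves to { analysisType: 'standard' | 'investigation', reasoning: string };
// on any routing failure it degrades to 'standard' instead of throwing.
const { analysisType, reasoning } = await runAnalysisTypeRouterStep({ messages });
console.info(analysisType, reasoning);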

analyst-agent-steps/analysis-type-router-step/format-analysis-type-router-prompt.ts

@@ -1,9 +1,9 @@
-import type { CoreMessage } from 'ai';
+import type { ModelMessage } from 'ai';
 // Define the required template parameters
 interface AnalysisTypeRouterTemplateParams {
   userPrompt: string;
-  conversationHistory?: CoreMessage[];
+  conversationHistory?: ModelMessage[];
 }
 /**

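The formatter's body is unchanged by this commit and elided by the diff view above. For orientation, a hypothetical implementation consistent with the params interface (the real template text is not shown in this diff):

// Hypothetical sketch only; the actual prompt template is not part of this commit.
export function formatAnalysisTypeRouterPrompt({
  userPrompt,
  conversationHistory,
}: AnalysisTypeRouterTemplateParams): string {
  // Flatten the history into role-prefixed lines; non-string content is serialized.
  const history = (conversationHistory ?? [])
    .map((m) => `${m.role}: ${typeof m.content === 'string' ? m.content : JSON.stringify(m.content)}`)
    .join('\n');
  return [
    'Decide whether this request needs a standard analysis or a deeper investigation.',
    history && `Conversation so far:\n${history}`,
    `Latest user request: ${userPrompt}`,
  ]
    .filter(Boolean)
    .join('\n\n');
}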
index.ts

@@ -5,6 +5,7 @@ export * from './analyst-agent-steps/mark-message-complete-step/mark-message-com
 export * from './analyst-agent-steps/generate-chat-title-step/generate-chat-title-step';
 export * from './analyst-agent-steps/extract-values-step/extract-values-search-step';
 export * from './analyst-agent-steps/create-todos-step/create-todos-step';
+export * from './analyst-agent-steps/analysis-type-router-step/analysis-type-router-step';
 // Docs agent steps
 export * from './docs-agent-steps/create-docs-todos-step';