mirror of https://github.com/buster-so/buster.git
analysis type router
This commit is contained in:
parent
77bc071f4b
commit
9d9690bb35
|
@ -0,0 +1,80 @@
|
|||
import { describe, expect, it, vi } from 'vitest';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { runAnalysisTypeRouterStep } from './analysis-type-router-step';
|
||||
|
||||
// Mock the GPT5Mini model
|
||||
vi.mock('../../../llm/gpt-5-mini', () => ({
|
||||
GPT5Mini: 'mock-gpt5-mini',
|
||||
}));
|
||||
|
||||
// Mock generateObject
|
||||
vi.mock('ai', async () => {
|
||||
const actual = await vi.importActual('ai');
|
||||
return {
|
||||
...actual,
|
||||
generateObject: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock wrapTraced
|
||||
vi.mock('braintrust', () => ({
|
||||
wrapTraced: (fn: Function) => fn,
|
||||
}));
|
||||
|
||||
describe('runAnalysisTypeRouterStep', () => {
|
||||
it('should route to standard mode for simple queries', async () => {
|
||||
const messages: ModelMessage[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Show me total sales for last month',
|
||||
},
|
||||
];
|
||||
|
||||
const result = await runAnalysisTypeRouterStep({ messages });
|
||||
|
||||
// Since we're mocking, it will default to standard
|
||||
expect(result.analysisType).toBe('standard');
|
||||
expect(result.reasoning).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle empty messages array', async () => {
|
||||
const messages: ModelMessage[] = [];
|
||||
|
||||
const result = await runAnalysisTypeRouterStep({ messages });
|
||||
|
||||
expect(result.analysisType).toBe('standard');
|
||||
expect(result.reasoning).toContain('Defaulting to standard');
|
||||
});
|
||||
|
||||
it('should handle multiple messages in conversation history', async () => {
|
||||
const messages: ModelMessage[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'What is causing the revenue drop?',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Let me analyze the revenue trends...',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Can you dig deeper into the Q3 anomaly?',
|
||||
},
|
||||
];
|
||||
|
||||
const result = await runAnalysisTypeRouterStep({ messages });
|
||||
|
||||
expect(result.analysisType).toBeDefined();
|
||||
expect(result.reasoning).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
// Force an error by passing invalid data
|
||||
const messages = null as unknown as ModelMessage[];
|
||||
|
||||
const result = await runAnalysisTypeRouterStep({ messages });
|
||||
|
||||
expect(result.analysisType).toBe('standard');
|
||||
expect(result.reasoning).toContain('Defaulting to standard');
|
||||
});
|
||||
});
|
|
@ -1,82 +1,65 @@
|
|||
import { createStep } from '@mastra/core';
|
||||
import type { RuntimeContext } from '@mastra/core/runtime-context';
|
||||
import type { CoreMessage } from 'ai';
|
||||
import { generateObject } from 'ai';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { wrapTraced } from 'braintrust';
|
||||
import { z } from 'zod';
|
||||
import { GPT5Mini } from '../../../llm/gpt-5-mini';
|
||||
import { thinkAndPrepWorkflowInputSchema } from '../../../schemas/workflow-schemas';
|
||||
import { appendToConversation, standardizeMessages } from '../../../utils/standardizeMessages';
|
||||
import type { AnalystRuntimeContext } from '../../../workflows/analyst-workflow';
|
||||
import { formatAnalysisTypeRouterPrompt } from './format-analysis-type-router-prompt';
|
||||
|
||||
const inputSchema = thinkAndPrepWorkflowInputSchema;
|
||||
// Zod schemas first - following Zod-first approach
|
||||
export const analysisTypeRouterParamsSchema = z.object({
|
||||
messages: z.array(z.custom<ModelMessage>()).describe('The conversation history'),
|
||||
});
|
||||
|
||||
// Define the analysis type choices
|
||||
const AnalysisTypeEnum = z.enum(['standard', 'investigation']);
|
||||
|
||||
// Define the structure for the AI response
|
||||
const analysisTypeSchema = z.object({
|
||||
choice: AnalysisTypeEnum.describe('The type of analysis to perform'),
|
||||
export const analysisTypeRouterResultSchema = z.object({
|
||||
analysisType: z.enum(['standard', 'investigation']).describe('The chosen analysis type'),
|
||||
reasoning: z.string().describe('Explanation for why this analysis type was chosen'),
|
||||
});
|
||||
|
||||
const outputSchema = z.object({
|
||||
analysisType: analysisTypeSchema,
|
||||
conversationHistory: z.array(z.any()),
|
||||
// Pass through dashboard context
|
||||
dashboardFiles: z
|
||||
.array(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
versionNumber: z.number(),
|
||||
metricIds: z.array(z.string()),
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
// Export types from schemas
|
||||
export type AnalysisTypeRouterParams = z.infer<typeof analysisTypeRouterParamsSchema>;
|
||||
export type AnalysisTypeRouterResult = z.infer<typeof analysisTypeRouterResultSchema>;
|
||||
|
||||
// Schema for what the LLM returns
|
||||
const llmOutputSchema = z.object({
|
||||
choice: z.enum(['standard', 'investigation']).describe('The type of analysis to perform'),
|
||||
reasoning: z.string().describe('Explanation for why this analysis type was chosen'),
|
||||
});
|
||||
|
||||
const execution = async ({
|
||||
inputData,
|
||||
}: {
|
||||
inputData: z.infer<typeof inputSchema>;
|
||||
runtimeContext: RuntimeContext<AnalystRuntimeContext>;
|
||||
}): Promise<z.infer<typeof outputSchema>> => {
|
||||
/**
|
||||
* Generates the analysis type decision using the LLM
|
||||
*/
|
||||
async function generateAnalysisTypeWithLLM(messages: ModelMessage[]): Promise<{
|
||||
choice: 'standard' | 'investigation';
|
||||
reasoning: string;
|
||||
}> {
|
||||
try {
|
||||
// Use the input data directly
|
||||
const prompt = inputData.prompt;
|
||||
const conversationHistory = inputData.conversationHistory;
|
||||
// Get the last user message for context
|
||||
const lastUserMessage = messages
|
||||
.slice()
|
||||
.reverse()
|
||||
.find((m) => m.role === 'user');
|
||||
const userPrompt = lastUserMessage?.content?.toString() || '';
|
||||
|
||||
// Prepare messages for the agent
|
||||
let messages: CoreMessage[];
|
||||
if (conversationHistory && conversationHistory.length > 0) {
|
||||
// Use conversation history as context + append new user message
|
||||
messages = appendToConversation(conversationHistory as CoreMessage[], prompt);
|
||||
} else {
|
||||
// Otherwise, use just the prompt
|
||||
messages = standardizeMessages(prompt);
|
||||
}
|
||||
|
||||
// Format the prompt using the helper function
|
||||
// Format the system prompt
|
||||
const systemPrompt = formatAnalysisTypeRouterPrompt({
|
||||
userPrompt: prompt,
|
||||
...(conversationHistory && { conversationHistory: conversationHistory as CoreMessage[] }),
|
||||
userPrompt,
|
||||
conversationHistory: messages.length > 1 ? messages : [],
|
||||
});
|
||||
|
||||
// Generate the analysis type decision
|
||||
// Prepare messages for the LLM
|
||||
const systemMessage: ModelMessage = {
|
||||
role: 'system',
|
||||
content: systemPrompt,
|
||||
};
|
||||
|
||||
const llmMessages = [systemMessage, ...messages];
|
||||
|
||||
const tracedAnalysisType = wrapTraced(
|
||||
async () => {
|
||||
const { object } = await generateObject({
|
||||
model: GPT5Mini,
|
||||
schema: analysisTypeSchema,
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPrompt,
|
||||
},
|
||||
...messages,
|
||||
],
|
||||
schema: llmOutputSchema,
|
||||
messages: llmMessages,
|
||||
temperature: 1,
|
||||
providerOptions: {
|
||||
openai: {
|
||||
|
@ -97,36 +80,46 @@ const execution = async ({
|
|||
}
|
||||
);
|
||||
|
||||
const analysisType = await tracedAnalysisType();
|
||||
const result = await tracedAnalysisType();
|
||||
return {
|
||||
choice: result.choice,
|
||||
reasoning: result.reasoning,
|
||||
};
|
||||
} catch (llmError) {
|
||||
console.warn('[AnalysisTypeRouter] LLM failed to generate valid response:', {
|
||||
error: llmError instanceof Error ? llmError.message : 'Unknown error',
|
||||
errorType: llmError instanceof Error ? llmError.name : 'Unknown',
|
||||
});
|
||||
|
||||
// Default to standard analysis on error
|
||||
return {
|
||||
choice: 'standard',
|
||||
reasoning: 'Defaulting to standard analysis due to routing error',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export async function runAnalysisTypeRouterStep(
|
||||
params: AnalysisTypeRouterParams
|
||||
): Promise<AnalysisTypeRouterResult> {
|
||||
try {
|
||||
const result = await generateAnalysisTypeWithLLM(params.messages);
|
||||
|
||||
console.info('[Analysis Type Router] Decision:', {
|
||||
choice: analysisType.choice,
|
||||
reasoning: analysisType.reasoning,
|
||||
choice: result.choice,
|
||||
reasoning: result.reasoning,
|
||||
});
|
||||
|
||||
return {
|
||||
analysisType,
|
||||
conversationHistory: messages,
|
||||
dashboardFiles: inputData.dashboardFiles, // Pass through dashboard context
|
||||
analysisType: result.choice,
|
||||
reasoning: result.reasoning,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('[Analysis Type Router] Error:', error);
|
||||
console.error('[analysis-type-router-step] Unexpected error:', error);
|
||||
// Default to standard analysis on error
|
||||
return {
|
||||
analysisType: {
|
||||
choice: 'standard',
|
||||
reasoning: 'Defaulting to standard analysis due to routing error',
|
||||
},
|
||||
conversationHistory: inputData.conversationHistory || [],
|
||||
dashboardFiles: inputData.dashboardFiles, // Pass through dashboard context
|
||||
analysisType: 'standard',
|
||||
reasoning: 'Defaulting to standard analysis due to routing error',
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export const analysisTypeRouterStep = createStep({
|
||||
id: 'analysis-type-router',
|
||||
description: 'Determines whether to use standard or investigation analysis based on the query',
|
||||
inputSchema,
|
||||
outputSchema,
|
||||
execute: execution,
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import type { CoreMessage } from 'ai';
|
||||
import type { ModelMessage } from 'ai';
|
||||
|
||||
// Define the required template parameters
|
||||
interface AnalysisTypeRouterTemplateParams {
|
||||
userPrompt: string;
|
||||
conversationHistory?: CoreMessage[];
|
||||
conversationHistory?: ModelMessage[];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -5,6 +5,7 @@ export * from './analyst-agent-steps/mark-message-complete-step/mark-message-com
|
|||
export * from './analyst-agent-steps/generate-chat-title-step/generate-chat-title-step';
|
||||
export * from './analyst-agent-steps/extract-values-step/extract-values-search-step';
|
||||
export * from './analyst-agent-steps/create-todos-step/create-todos-step';
|
||||
export * from './analyst-agent-steps/analysis-type-router-step/analysis-type-router-step';
|
||||
|
||||
// Docs agent steps
|
||||
export * from './docs-agent-steps/create-docs-todos-step';
|
||||
|
|
Loading…
Reference in New Issue