diff --git a/packages/ai/src/agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt.test.ts b/packages/ai/src/agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt.test.ts new file mode 100644 index 000000000..ae5de9b9a --- /dev/null +++ b/packages/ai/src/agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt.test.ts @@ -0,0 +1,30 @@ +import { describe, it, expect } from 'vitest'; +import { getThinkAndPrepAgentSystemPrompt } from './get-think-and-prep-agent-system-prompt'; + +describe('getThinkAndPrepAgentSystemPrompt', () => { + it('should return system prompt with SQL dialect guidance', () => { + const sqlDialectGuidance = 'PostgreSQL specific guidance'; + const result = getThinkAndPrepAgentSystemPrompt(sqlDialectGuidance); + + expect(result).toContain('You are Buster, a specialized AI agent'); + expect(result).toContain('PostgreSQL specific guidance'); + expect(result).toContain("Today's date is"); + }); + + it('should include all necessary sections', () => { + const sqlDialectGuidance = 'MySQL specific guidance'; + const result = getThinkAndPrepAgentSystemPrompt(sqlDialectGuidance); + + // Check for key sections + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + expect(result).toContain(''); + }); +}); \ No newline at end of file diff --git a/packages/ai/src/agents/think-and-prep-agent/think-and-prep-instructions.ts b/packages/ai/src/agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt.ts similarity index 96% rename from packages/ai/src/agents/think-and-prep-agent/think-and-prep-instructions.ts rename to packages/ai/src/agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt.ts index e27edee60..cd965f871 100644 --- a/packages/ai/src/agents/think-and-prep-agent/think-and-prep-instructions.ts +++ b/packages/ai/src/agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt.ts @@ -1,16 +1,4 @@ -import { getPermissionedDatasets } from '@buster/access-controls'; -import type { RuntimeContext } from '@mastra/core/runtime-context'; -import type { AnalystRuntimeContext } from '../../workflows/analyst-workflow'; -import { getSqlDialectGuidance } from '../shared/sql-dialect-guidance'; - -// Define the required template parameters -interface ThinkAndPrepTemplateParams { - databaseContext: string; - sqlDialectGuidance: string; -} - -// Template string as a function that requires parameters -const createThinkAndPrepInstructions = (params: ThinkAndPrepTemplateParams): string => { +export const getThinkAndPrepAgentSystemPrompt = (sqlDialectGuidance: string): string => { return ` You are Buster, a specialized AI agent within an AI-powered data analyst system. @@ -398,7 +386,7 @@ Once all TODO list items are addressed and submitted for review, the system will - Current SQL Dialect Guidance: -${params.sqlDialectGuidance} +${sqlDialectGuidance} - Keep Queries Simple: Strive for simplicity and clarity in your SQL. Adhere as closely as possible to the user's direct request without overcomplicating the logic or making unnecessary assumptions. - Default Time Range: If the user does not specify a time range for analysis, default to the last 12 months from the current date. Clearly state this assumption if making it. - Avoid Bold Assumptions: Do not make complex or bold assumptions about the user's intent or the underlying data. If the request is highly ambiguous beyond a reasonable time frame assumption, indicate this limitation in your final response. @@ -583,46 +571,6 @@ ${params.sqlDialectGuidance} Start by using the \`sequentialThinking\` to immediately start checking off items on your TODO list Today's date is ${new Date().toLocaleDateString()}. - ---- - - -${params.databaseContext} - `; }; -export const getThinkAndPrepInstructions = async ({ - runtimeContext, -}: { runtimeContext: RuntimeContext }): Promise => { - const userId = runtimeContext.get('userId'); - const dataSourceSyntax = runtimeContext.get('dataSourceSyntax'); - - const datasets = await getPermissionedDatasets(userId, 0, 1000); - - // Extract yml_content from each dataset and join with separators - const assembledYmlContent = datasets - .map((dataset: { ymlFile: string | null | undefined }) => dataset.ymlFile) - .filter((content: string | null | undefined) => content !== null && content !== undefined) - .join('\n---\n'); - - // Get dialect-specific guidance - const sqlDialectGuidance = getSqlDialectGuidance(dataSourceSyntax); - - return createThinkAndPrepInstructions({ - databaseContext: assembledYmlContent, - sqlDialectGuidance, - }); -}; - -// Export the template function without dataset context for use in step files -export const createThinkAndPrepInstructionsWithoutDatasets = ( - sqlDialectGuidance: string -): string => { - return createThinkAndPrepInstructions({ - databaseContext: '', - sqlDialectGuidance, - }) - .replace(/[\s\S]*?<\/database_context>/, '') - .trim(); -}; diff --git a/packages/ai/src/agents/think-and-prep-agent/think-and-prep-agent.ts b/packages/ai/src/agents/think-and-prep-agent/think-and-prep-agent.ts index 8ae534050..8ca5b8539 100644 --- a/packages/ai/src/agents/think-and-prep-agent/think-and-prep-agent.ts +++ b/packages/ai/src/agents/think-and-prep-agent/think-and-prep-agent.ts @@ -1,4 +1,6 @@ -import { Agent } from '@mastra/core'; +import { hasToolCall, type ModelMessage, stepCountIs, streamText } from "ai"; +import { wrapTraced } from "braintrust"; +import z from "zod"; import { executeSql, messageUserClarifyingQuestion, @@ -7,29 +9,79 @@ import { submitThoughts, } from '../../tools'; import { Sonnet4 } from '../../utils/models/sonnet-4'; +import { getThinkAndPrepAgentSystemPrompt } from './get-think-and-prep-agent-system-prompt'; -const DEFAULT_OPTIONS = { - maxSteps: 18, - temperature: 0, - maxTokens: 10000, - providerOptions: { - anthropic: { - disableParallelToolCalls: true, - }, - }, +const DEFAULT_CACHE_OPTIONS = { + anthropic: { cacheControl: { type: "ephemeral", ttl: "1h" } }, }; -export const thinkAndPrepAgent = new Agent({ - name: 'Think and Prep Agent', - instructions: '', // We control the system messages in the step at stream instantiation - model: Sonnet4, - tools: { - sequentialThinking, - executeSql, - respondWithoutAssetCreation, - submitThoughts, - messageUserClarifyingQuestion, - }, - defaultGenerateOptions: DEFAULT_OPTIONS, - defaultStreamOptions: DEFAULT_OPTIONS, +const STOP_CONDITIONS = [ + stepCountIs(18), + hasToolCall("submitThoughts"), + hasToolCall("respondWithoutAssetCreation"), + hasToolCall("messageUserClarifyingQuestion") +]; + +const ThinkAndPrepAgentOptionsSchema = z.object({ + sql_dialect_guidance: z + .string() + .describe("The SQL dialect guidance for the think and prep agent."), }); + +const ThinkAndPrepStreamOptionsSchema = z.object({ + messages: z + .array(z.custom()) + .describe("The messages to send to the think and prep agent."), +}); + +export type ThinkAndPrepAgentOptionsSchema = z.infer< + typeof ThinkAndPrepAgentOptionsSchema +>; +export type ThinkAndPrepStreamOptions = z.infer; + +export function createThinkAndPrepAgent( + thinkAndPrepAgentSchema: ThinkAndPrepAgentOptionsSchema, +) { + const steps: never[] = []; + + const systemMessage = { + role: "system", + content: getThinkAndPrepAgentSystemPrompt( + thinkAndPrepAgentSchema.sql_dialect_guidance, + ), + providerOptions: DEFAULT_CACHE_OPTIONS, + } as ModelMessage; + + async function stream({ messages }: ThinkAndPrepStreamOptions) { + return wrapTraced( + () => + streamText({ + model: Sonnet4, + tools: { + sequentialThinking, + executeSql, + respondWithoutAssetCreation, + submitThoughts, + messageUserClarifyingQuestion, + }, + messages: [systemMessage, ...messages], + stopWhen: STOP_CONDITIONS, + toolChoice: "required", + maxOutputTokens: 10000, + temperature: 0, + }), + { + name: "Think and Prep Agent", + }, + )(); + } + + async function getSteps() { + return steps; + } + + return { + stream, + getSteps, + }; +} diff --git a/packages/ai/src/steps/think-and-prep-step.ts b/packages/ai/src/steps/think-and-prep-step.ts index 8b8d112ce..1f6595910 100644 --- a/packages/ai/src/steps/think-and-prep-step.ts +++ b/packages/ai/src/steps/think-and-prep-step.ts @@ -6,8 +6,8 @@ import type { CoreMessage } from 'ai'; import { wrapTraced } from 'braintrust'; import { z } from 'zod'; import { getSqlDialectGuidance } from '../agents/shared/sql-dialect-guidance'; -import { thinkAndPrepAgent } from '../agents/think-and-prep-agent/think-and-prep-agent'; -import { createThinkAndPrepInstructionsWithoutDatasets } from '../agents/think-and-prep-agent/think-and-prep-instructions'; +import { createThinkAndPrepAgent } from '../agents/think-and-prep-agent/think-and-prep-agent'; +import { getThinkAndPrepAgentSystemPrompt } from '../agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt'; import type { thinkAndPrepWorkflowInputSchema } from '../schemas/workflow-schemas'; import { ChunkProcessor } from '../utils/database/chunk-processor'; import { @@ -248,60 +248,48 @@ ${databaseContext} ), }); + // Create the agent instance + const thinkAndPrepAgent = createThinkAndPrepAgent({ + sql_dialect_guidance: sqlDialectGuidance, + }); + const wrappedStream = wrapTraced( async () => { - // Create system messages with dataset context and instructions - const systemMessages: CoreMessage[] = [ - { - role: 'system', - content: createThinkAndPrepInstructionsWithoutDatasets(sqlDialectGuidance), - providerOptions: DEFAULT_CACHE_OPTIONS, - }, - { - role: 'system', - content: createDatasetSystemMessage(assembledYmlContent), - providerOptions: DEFAULT_CACHE_OPTIONS, - }, - ]; + // Create dataset system message + const datasetSystemMessage: CoreMessage = { + role: 'system', + content: createDatasetSystemMessage(assembledYmlContent), + providerOptions: DEFAULT_CACHE_OPTIONS, + }; - // Combine system messages with conversation messages - const messagesWithSystem = [...systemMessages, ...messages]; + // Combine dataset system message with conversation messages + const messagesWithDataset = [datasetSystemMessage, ...messages]; - // Create stream directly without retryableAgentStreamWithHealing - const stream = await thinkAndPrepAgent.stream(messagesWithSystem, { - toolCallStreaming: true, - runtimeContext, - maxRetries: 5, - abortSignal: abortController.signal, - toolChoice: 'required', - onChunk: createOnChunkHandler({ - chunkProcessor, - abortController, - finishingToolNames: [ - 'submitThoughts', - 'respondWithoutAssetCreation', - 'messageUserClarifyingQuestion', - ], - onFinishingTool: () => { - // Set finished = true for respondWithoutAssetCreation and messageUserClarifyingQuestion - // submitThoughts should abort but not finish so workflow can continue - const finishingToolName = chunkProcessor.getFinishingToolName(); - if ( - finishingToolName === 'respondWithoutAssetCreation' || - finishingToolName === 'messageUserClarifyingQuestion' - ) { - finished = true; - } - }, - }), - onError: createRetryOnErrorHandler({ - retryCount, - maxRetries, - workflowContext: { - currentStep: 'think-and-prep', - availableTools, - }, - }), + // Create stream using the new agent pattern + const stream = await thinkAndPrepAgent.stream({ + messages: messagesWithDataset, + }); + + // Handle streaming with chunk processor + stream.onChunk = createOnChunkHandler({ + chunkProcessor, + abortController, + finishingToolNames: [ + 'submitThoughts', + 'respondWithoutAssetCreation', + 'messageUserClarifyingQuestion', + ], + onFinishingTool: () => { + // Set finished = true for respondWithoutAssetCreation and messageUserClarifyingQuestion + // submitThoughts should abort but not finish so workflow can continue + const finishingToolName = chunkProcessor.getFinishingToolName(); + if ( + finishingToolName === 'respondWithoutAssetCreation' || + finishingToolName === 'messageUserClarifyingQuestion' + ) { + finished = true; + } + }, }); return stream; diff --git a/packages/ai/src/utils/memory/ai-sdk-bundling.test.ts b/packages/ai/src/utils/memory/ai-sdk-bundling.test.ts deleted file mode 100644 index 7a23254c2..000000000 --- a/packages/ai/src/utils/memory/ai-sdk-bundling.test.ts +++ /dev/null @@ -1,279 +0,0 @@ -import type { CoreMessage } from 'ai'; -import { describe, expect, test } from 'vitest'; -import { validateArrayAccess } from '../validation-helpers'; -import { extractMessageHistory } from './message-history'; - -describe('AI SDK Message Bundling Issues', () => { - test('identify when AI SDK returns bundled messages', () => { - // The AI SDK tends to bundle multiple tool calls in a single assistant message - // when parallel tool calls are made, even with disableParallelToolCalls - const aiSdkResponse: CoreMessage[] = [ - { - role: 'user', - content: 'Analyze our customer data', - }, - { - role: 'assistant', - content: [ - { - type: 'tool-call', - toolCallId: 'call_ABC123', - toolName: 'sequentialThinking', - args: { thought: 'First, I need to understand the data structure' }, - }, - { - type: 'tool-call', - toolCallId: 'call_DEF456', - toolName: 'executeSql', - args: { statements: ['SELECT COUNT(*) FROM customers'] }, - }, - { - type: 'tool-call', - toolCallId: 'call_GHI789', - toolName: 'submitThoughts', - args: {}, - }, - ], - }, - { - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: 'call_ABC123', - toolName: 'sequentialThinking', - result: { success: true }, - }, - ], - }, - { - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: 'call_DEF456', - toolName: 'executeSql', - result: { results: [{ count: 100 }] }, - }, - ], - }, - { - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: 'call_GHI789', - toolName: 'submitThoughts', - result: {}, - }, - ], - }, - ]; - - // Our extraction should fix this - const fixed = extractMessageHistory(aiSdkResponse); - - // Should be properly interleaved now - expect(fixed).toHaveLength(7); // user + 3*(assistant + tool) - - // Check the pattern - const msg0 = validateArrayAccess(fixed, 0, 'fixed messages'); - const msg1 = validateArrayAccess(fixed, 1, 'fixed messages'); - const msg2 = validateArrayAccess(fixed, 2, 'fixed messages'); - const msg3 = validateArrayAccess(fixed, 3, 'fixed messages'); - const msg4 = validateArrayAccess(fixed, 4, 'fixed messages'); - const msg5 = validateArrayAccess(fixed, 5, 'fixed messages'); - const msg6 = validateArrayAccess(fixed, 6, 'fixed messages'); - - expect(msg0.role).toBe('user'); - expect(msg1.role).toBe('assistant'); - if (msg1.role === 'assistant' && Array.isArray(msg1.content)) { - const content = validateArrayAccess(msg1.content, 0, 'assistant content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('call_ABC123'); - } - } - expect(msg2.role).toBe('tool'); - if (msg2.role === 'tool' && Array.isArray(msg2.content)) { - const content = validateArrayAccess(msg2.content, 0, 'tool content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('call_ABC123'); - } - } - expect(msg3.role).toBe('assistant'); - if (msg3.role === 'assistant' && Array.isArray(msg3.content)) { - const content = validateArrayAccess(msg3.content, 0, 'assistant content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('call_DEF456'); - } - } - expect(msg4.role).toBe('tool'); - if (msg4.role === 'tool' && Array.isArray(msg4.content)) { - const content = validateArrayAccess(msg4.content, 0, 'tool content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('call_DEF456'); - } - } - expect(msg5.role).toBe('assistant'); - if (msg5.role === 'assistant' && Array.isArray(msg5.content)) { - const content = validateArrayAccess(msg5.content, 0, 'assistant content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('call_GHI789'); - } - } - expect(msg6.role).toBe('tool'); - if (msg6.role === 'tool' && Array.isArray(msg6.content)) { - const content = validateArrayAccess(msg6.content, 0, 'tool content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('call_GHI789'); - } - } - }); - - test('handle case where AI SDK partially bundles messages', () => { - // Sometimes the AI SDK might bundle some calls but not others - const partiallyBundled: CoreMessage[] = [ - { - role: 'user', - content: 'Test', - }, - { - role: 'assistant', - content: [ - { - type: 'tool-call', - toolCallId: 'id1', - toolName: 'tool1', - args: {}, - }, - ], - }, - { - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: 'id1', - toolName: 'tool1', - result: {}, - }, - ], - }, - { - role: 'assistant', - content: [ - { - type: 'tool-call', - toolCallId: 'id2', - toolName: 'tool2', - args: {}, - }, - { - type: 'tool-call', - toolCallId: 'id3', - toolName: 'tool3', - args: {}, - }, - ], - }, - { - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: 'id2', - toolName: 'tool2', - result: {}, - }, - ], - }, - { - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: 'id3', - toolName: 'tool3', - result: {}, - }, - ], - }, - ]; - - const fixed = extractMessageHistory(partiallyBundled); - - // Should fix only the bundled part - expect(fixed).toHaveLength(7); - - // First part should remain unchanged - const fixedMsg0 = validateArrayAccess(fixed, 0, 'fixed messages'); - const fixedMsg1 = validateArrayAccess(fixed, 1, 'fixed messages'); - const fixedMsg2 = validateArrayAccess(fixed, 2, 'fixed messages'); - const fixedMsg3 = validateArrayAccess(fixed, 3, 'fixed messages'); - const fixedMsg4 = validateArrayAccess(fixed, 4, 'fixed messages'); - const fixedMsg5 = validateArrayAccess(fixed, 5, 'fixed messages'); - const fixedMsg6 = validateArrayAccess(fixed, 6, 'fixed messages'); - - const partialMsg0 = validateArrayAccess(partiallyBundled, 0, 'partially bundled messages'); - const partialMsg1 = validateArrayAccess(partiallyBundled, 1, 'partially bundled messages'); - const partialMsg2 = validateArrayAccess(partiallyBundled, 2, 'partially bundled messages'); - - expect(fixedMsg0).toEqual(partialMsg0); - expect(fixedMsg1).toEqual(partialMsg1); - expect(fixedMsg2).toEqual(partialMsg2); - - // Second part should be unbundled - if (fixedMsg3.role === 'assistant' && Array.isArray(fixedMsg3.content)) { - const content = validateArrayAccess(fixedMsg3.content, 0, 'assistant content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('id2'); - } - } - if (fixedMsg4.role === 'tool' && Array.isArray(fixedMsg4.content)) { - const content = validateArrayAccess(fixedMsg4.content, 0, 'tool content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('id2'); - } - } - if (fixedMsg5.role === 'assistant' && Array.isArray(fixedMsg5.content)) { - const content = validateArrayAccess(fixedMsg5.content, 0, 'assistant content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('id3'); - } - } - if (fixedMsg6.role === 'tool' && Array.isArray(fixedMsg6.content)) { - const content = validateArrayAccess(fixedMsg6.content, 0, 'tool content'); - if ('toolCallId' in content) { - expect(content.toolCallId).toBe('id3'); - } - } - }); - - test('verify already correct messages pass through unchanged', () => { - const correctlyFormatted: CoreMessage[] = [ - { role: 'user', content: 'Test' }, - { - role: 'assistant', - content: [{ type: 'tool-call', toolCallId: 'id1', toolName: 'tool1', args: {} }], - }, - { - role: 'tool', - content: [{ type: 'tool-result', toolCallId: 'id1', toolName: 'tool1', result: {} }], - }, - { - role: 'assistant', - content: [{ type: 'tool-call', toolCallId: 'id2', toolName: 'tool2', args: {} }], - }, - { - role: 'tool', - content: [{ type: 'tool-result', toolCallId: 'id2', toolName: 'tool2', result: {} }], - }, - ]; - - const result = extractMessageHistory(correctlyFormatted); - - // Should be unchanged - expect(result).toEqual(correctlyFormatted); - expect(result).toHaveLength(5); - }); -});