migrating over to sdk v5

This commit is contained in:
dal 2025-08-05 09:40:05 -06:00
parent 5883fc8762
commit fcbe1838a1
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
5 changed files with 147 additions and 408 deletions

View File

@ -0,0 +1,30 @@
import { describe, it, expect } from 'vitest';
import { getThinkAndPrepAgentSystemPrompt } from './get-think-and-prep-agent-system-prompt';
describe('getThinkAndPrepAgentSystemPrompt', () => {
it('should return system prompt with SQL dialect guidance', () => {
const sqlDialectGuidance = 'PostgreSQL specific guidance';
const result = getThinkAndPrepAgentSystemPrompt(sqlDialectGuidance);
expect(result).toContain('You are Buster, a specialized AI agent');
expect(result).toContain('PostgreSQL specific guidance');
expect(result).toContain("Today's date is");
});
it('should include all necessary sections', () => {
const sqlDialectGuidance = 'MySQL specific guidance';
const result = getThinkAndPrepAgentSystemPrompt(sqlDialectGuidance);
// Check for key sections
expect(result).toContain('<intro>');
expect(result).toContain('<prep_mode_capability>');
expect(result).toContain('<event_stream>');
expect(result).toContain('<agent_loop>');
expect(result).toContain('<todo_list>');
expect(result).toContain('<todo_rules>');
expect(result).toContain('<tool_use_rules>');
expect(result).toContain('<sequential_thinking_rules>');
expect(result).toContain('<execute_sql_rules>');
expect(result).toContain('<sql_best_practices>');
});
});

View File

@ -1,16 +1,4 @@
import { getPermissionedDatasets } from '@buster/access-controls'; export const getThinkAndPrepAgentSystemPrompt = (sqlDialectGuidance: string): string => {
import type { RuntimeContext } from '@mastra/core/runtime-context';
import type { AnalystRuntimeContext } from '../../workflows/analyst-workflow';
import { getSqlDialectGuidance } from '../shared/sql-dialect-guidance';
// Define the required template parameters
interface ThinkAndPrepTemplateParams {
databaseContext: string;
sqlDialectGuidance: string;
}
// Template string as a function that requires parameters
const createThinkAndPrepInstructions = (params: ThinkAndPrepTemplateParams): string => {
return ` return `
You are Buster, a specialized AI agent within an AI-powered data analyst system. You are Buster, a specialized AI agent within an AI-powered data analyst system.
@ -398,7 +386,7 @@ Once all TODO list items are addressed and submitted for review, the system will
<sql_best_practices> <sql_best_practices>
- Current SQL Dialect Guidance: - Current SQL Dialect Guidance:
${params.sqlDialectGuidance} ${sqlDialectGuidance}
- Keep Queries Simple: Strive for simplicity and clarity in your SQL. Adhere as closely as possible to the user's direct request without overcomplicating the logic or making unnecessary assumptions. - Keep Queries Simple: Strive for simplicity and clarity in your SQL. Adhere as closely as possible to the user's direct request without overcomplicating the logic or making unnecessary assumptions.
- Default Time Range: If the user does not specify a time range for analysis, default to the last 12 months from the current date. Clearly state this assumption if making it. - Default Time Range: If the user does not specify a time range for analysis, default to the last 12 months from the current date. Clearly state this assumption if making it.
- Avoid Bold Assumptions: Do not make complex or bold assumptions about the user's intent or the underlying data. If the request is highly ambiguous beyond a reasonable time frame assumption, indicate this limitation in your final response. - Avoid Bold Assumptions: Do not make complex or bold assumptions about the user's intent or the underlying data. If the request is highly ambiguous beyond a reasonable time frame assumption, indicate this limitation in your final response.
@ -583,46 +571,6 @@ ${params.sqlDialectGuidance}
Start by using the \`sequentialThinking\` to immediately start checking off items on your TODO list Start by using the \`sequentialThinking\` to immediately start checking off items on your TODO list
Today's date is ${new Date().toLocaleDateString()}. Today's date is ${new Date().toLocaleDateString()}.
---
<database_context>
${params.databaseContext}
</database_context>
`; `;
}; };
export const getThinkAndPrepInstructions = async ({
runtimeContext,
}: { runtimeContext: RuntimeContext<AnalystRuntimeContext> }): Promise<string> => {
const userId = runtimeContext.get('userId');
const dataSourceSyntax = runtimeContext.get('dataSourceSyntax');
const datasets = await getPermissionedDatasets(userId, 0, 1000);
// Extract yml_content from each dataset and join with separators
const assembledYmlContent = datasets
.map((dataset: { ymlFile: string | null | undefined }) => dataset.ymlFile)
.filter((content: string | null | undefined) => content !== null && content !== undefined)
.join('\n---\n');
// Get dialect-specific guidance
const sqlDialectGuidance = getSqlDialectGuidance(dataSourceSyntax);
return createThinkAndPrepInstructions({
databaseContext: assembledYmlContent,
sqlDialectGuidance,
});
};
// Export the template function without dataset context for use in step files
export const createThinkAndPrepInstructionsWithoutDatasets = (
sqlDialectGuidance: string
): string => {
return createThinkAndPrepInstructions({
databaseContext: '',
sqlDialectGuidance,
})
.replace(/<database_context>[\s\S]*?<\/database_context>/, '')
.trim();
};

View File

@ -1,4 +1,6 @@
import { Agent } from '@mastra/core'; import { hasToolCall, type ModelMessage, stepCountIs, streamText } from "ai";
import { wrapTraced } from "braintrust";
import z from "zod";
import { import {
executeSql, executeSql,
messageUserClarifyingQuestion, messageUserClarifyingQuestion,
@ -7,29 +9,79 @@ import {
submitThoughts, submitThoughts,
} from '../../tools'; } from '../../tools';
import { Sonnet4 } from '../../utils/models/sonnet-4'; import { Sonnet4 } from '../../utils/models/sonnet-4';
import { getThinkAndPrepAgentSystemPrompt } from './get-think-and-prep-agent-system-prompt';
const DEFAULT_OPTIONS = { const DEFAULT_CACHE_OPTIONS = {
maxSteps: 18, anthropic: { cacheControl: { type: "ephemeral", ttl: "1h" } },
temperature: 0,
maxTokens: 10000,
providerOptions: {
anthropic: {
disableParallelToolCalls: true,
},
},
}; };
export const thinkAndPrepAgent = new Agent({ const STOP_CONDITIONS = [
name: 'Think and Prep Agent', stepCountIs(18),
instructions: '', // We control the system messages in the step at stream instantiation hasToolCall("submitThoughts"),
model: Sonnet4, hasToolCall("respondWithoutAssetCreation"),
tools: { hasToolCall("messageUserClarifyingQuestion")
sequentialThinking, ];
executeSql,
respondWithoutAssetCreation, const ThinkAndPrepAgentOptionsSchema = z.object({
submitThoughts, sql_dialect_guidance: z
messageUserClarifyingQuestion, .string()
}, .describe("The SQL dialect guidance for the think and prep agent."),
defaultGenerateOptions: DEFAULT_OPTIONS,
defaultStreamOptions: DEFAULT_OPTIONS,
}); });
const ThinkAndPrepStreamOptionsSchema = z.object({
messages: z
.array(z.custom<ModelMessage>())
.describe("The messages to send to the think and prep agent."),
});
export type ThinkAndPrepAgentOptionsSchema = z.infer<
typeof ThinkAndPrepAgentOptionsSchema
>;
export type ThinkAndPrepStreamOptions = z.infer<typeof ThinkAndPrepStreamOptionsSchema>;
export function createThinkAndPrepAgent(
thinkAndPrepAgentSchema: ThinkAndPrepAgentOptionsSchema,
) {
const steps: never[] = [];
const systemMessage = {
role: "system",
content: getThinkAndPrepAgentSystemPrompt(
thinkAndPrepAgentSchema.sql_dialect_guidance,
),
providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage;
async function stream({ messages }: ThinkAndPrepStreamOptions) {
return wrapTraced(
() =>
streamText({
model: Sonnet4,
tools: {
sequentialThinking,
executeSql,
respondWithoutAssetCreation,
submitThoughts,
messageUserClarifyingQuestion,
},
messages: [systemMessage, ...messages],
stopWhen: STOP_CONDITIONS,
toolChoice: "required",
maxOutputTokens: 10000,
temperature: 0,
}),
{
name: "Think and Prep Agent",
},
)();
}
async function getSteps() {
return steps;
}
return {
stream,
getSteps,
};
}

View File

@ -6,8 +6,8 @@ import type { CoreMessage } from 'ai';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import { z } from 'zod'; import { z } from 'zod';
import { getSqlDialectGuidance } from '../agents/shared/sql-dialect-guidance'; import { getSqlDialectGuidance } from '../agents/shared/sql-dialect-guidance';
import { thinkAndPrepAgent } from '../agents/think-and-prep-agent/think-and-prep-agent'; import { createThinkAndPrepAgent } from '../agents/think-and-prep-agent/think-and-prep-agent';
import { createThinkAndPrepInstructionsWithoutDatasets } from '../agents/think-and-prep-agent/think-and-prep-instructions'; import { getThinkAndPrepAgentSystemPrompt } from '../agents/think-and-prep-agent/get-think-and-prep-agent-system-prompt';
import type { thinkAndPrepWorkflowInputSchema } from '../schemas/workflow-schemas'; import type { thinkAndPrepWorkflowInputSchema } from '../schemas/workflow-schemas';
import { ChunkProcessor } from '../utils/database/chunk-processor'; import { ChunkProcessor } from '../utils/database/chunk-processor';
import { import {
@ -248,60 +248,48 @@ ${databaseContext}
), ),
}); });
// Create the agent instance
const thinkAndPrepAgent = createThinkAndPrepAgent({
sql_dialect_guidance: sqlDialectGuidance,
});
const wrappedStream = wrapTraced( const wrappedStream = wrapTraced(
async () => { async () => {
// Create system messages with dataset context and instructions // Create dataset system message
const systemMessages: CoreMessage[] = [ const datasetSystemMessage: CoreMessage = {
{ role: 'system',
role: 'system', content: createDatasetSystemMessage(assembledYmlContent),
content: createThinkAndPrepInstructionsWithoutDatasets(sqlDialectGuidance), providerOptions: DEFAULT_CACHE_OPTIONS,
providerOptions: DEFAULT_CACHE_OPTIONS, };
},
{
role: 'system',
content: createDatasetSystemMessage(assembledYmlContent),
providerOptions: DEFAULT_CACHE_OPTIONS,
},
];
// Combine system messages with conversation messages // Combine dataset system message with conversation messages
const messagesWithSystem = [...systemMessages, ...messages]; const messagesWithDataset = [datasetSystemMessage, ...messages];
// Create stream directly without retryableAgentStreamWithHealing // Create stream using the new agent pattern
const stream = await thinkAndPrepAgent.stream(messagesWithSystem, { const stream = await thinkAndPrepAgent.stream({
toolCallStreaming: true, messages: messagesWithDataset,
runtimeContext, });
maxRetries: 5,
abortSignal: abortController.signal, // Handle streaming with chunk processor
toolChoice: 'required', stream.onChunk = createOnChunkHandler({
onChunk: createOnChunkHandler({ chunkProcessor,
chunkProcessor, abortController,
abortController, finishingToolNames: [
finishingToolNames: [ 'submitThoughts',
'submitThoughts', 'respondWithoutAssetCreation',
'respondWithoutAssetCreation', 'messageUserClarifyingQuestion',
'messageUserClarifyingQuestion', ],
], onFinishingTool: () => {
onFinishingTool: () => { // Set finished = true for respondWithoutAssetCreation and messageUserClarifyingQuestion
// Set finished = true for respondWithoutAssetCreation and messageUserClarifyingQuestion // submitThoughts should abort but not finish so workflow can continue
// submitThoughts should abort but not finish so workflow can continue const finishingToolName = chunkProcessor.getFinishingToolName();
const finishingToolName = chunkProcessor.getFinishingToolName(); if (
if ( finishingToolName === 'respondWithoutAssetCreation' ||
finishingToolName === 'respondWithoutAssetCreation' || finishingToolName === 'messageUserClarifyingQuestion'
finishingToolName === 'messageUserClarifyingQuestion' ) {
) { finished = true;
finished = true; }
} },
},
}),
onError: createRetryOnErrorHandler({
retryCount,
maxRetries,
workflowContext: {
currentStep: 'think-and-prep',
availableTools,
},
}),
}); });
return stream; return stream;

View File

@ -1,279 +0,0 @@
import type { CoreMessage } from 'ai';
import { describe, expect, test } from 'vitest';
import { validateArrayAccess } from '../validation-helpers';
import { extractMessageHistory } from './message-history';
describe('AI SDK Message Bundling Issues', () => {
test('identify when AI SDK returns bundled messages', () => {
// The AI SDK tends to bundle multiple tool calls in a single assistant message
// when parallel tool calls are made, even with disableParallelToolCalls
const aiSdkResponse: CoreMessage[] = [
{
role: 'user',
content: 'Analyze our customer data',
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'call_ABC123',
toolName: 'sequentialThinking',
args: { thought: 'First, I need to understand the data structure' },
},
{
type: 'tool-call',
toolCallId: 'call_DEF456',
toolName: 'executeSql',
args: { statements: ['SELECT COUNT(*) FROM customers'] },
},
{
type: 'tool-call',
toolCallId: 'call_GHI789',
toolName: 'submitThoughts',
args: {},
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call_ABC123',
toolName: 'sequentialThinking',
result: { success: true },
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call_DEF456',
toolName: 'executeSql',
result: { results: [{ count: 100 }] },
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call_GHI789',
toolName: 'submitThoughts',
result: {},
},
],
},
];
// Our extraction should fix this
const fixed = extractMessageHistory(aiSdkResponse);
// Should be properly interleaved now
expect(fixed).toHaveLength(7); // user + 3*(assistant + tool)
// Check the pattern
const msg0 = validateArrayAccess(fixed, 0, 'fixed messages');
const msg1 = validateArrayAccess(fixed, 1, 'fixed messages');
const msg2 = validateArrayAccess(fixed, 2, 'fixed messages');
const msg3 = validateArrayAccess(fixed, 3, 'fixed messages');
const msg4 = validateArrayAccess(fixed, 4, 'fixed messages');
const msg5 = validateArrayAccess(fixed, 5, 'fixed messages');
const msg6 = validateArrayAccess(fixed, 6, 'fixed messages');
expect(msg0.role).toBe('user');
expect(msg1.role).toBe('assistant');
if (msg1.role === 'assistant' && Array.isArray(msg1.content)) {
const content = validateArrayAccess(msg1.content, 0, 'assistant content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('call_ABC123');
}
}
expect(msg2.role).toBe('tool');
if (msg2.role === 'tool' && Array.isArray(msg2.content)) {
const content = validateArrayAccess(msg2.content, 0, 'tool content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('call_ABC123');
}
}
expect(msg3.role).toBe('assistant');
if (msg3.role === 'assistant' && Array.isArray(msg3.content)) {
const content = validateArrayAccess(msg3.content, 0, 'assistant content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('call_DEF456');
}
}
expect(msg4.role).toBe('tool');
if (msg4.role === 'tool' && Array.isArray(msg4.content)) {
const content = validateArrayAccess(msg4.content, 0, 'tool content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('call_DEF456');
}
}
expect(msg5.role).toBe('assistant');
if (msg5.role === 'assistant' && Array.isArray(msg5.content)) {
const content = validateArrayAccess(msg5.content, 0, 'assistant content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('call_GHI789');
}
}
expect(msg6.role).toBe('tool');
if (msg6.role === 'tool' && Array.isArray(msg6.content)) {
const content = validateArrayAccess(msg6.content, 0, 'tool content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('call_GHI789');
}
}
});
test('handle case where AI SDK partially bundles messages', () => {
// Sometimes the AI SDK might bundle some calls but not others
const partiallyBundled: CoreMessage[] = [
{
role: 'user',
content: 'Test',
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'id1',
toolName: 'tool1',
args: {},
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'id1',
toolName: 'tool1',
result: {},
},
],
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'id2',
toolName: 'tool2',
args: {},
},
{
type: 'tool-call',
toolCallId: 'id3',
toolName: 'tool3',
args: {},
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'id2',
toolName: 'tool2',
result: {},
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'id3',
toolName: 'tool3',
result: {},
},
],
},
];
const fixed = extractMessageHistory(partiallyBundled);
// Should fix only the bundled part
expect(fixed).toHaveLength(7);
// First part should remain unchanged
const fixedMsg0 = validateArrayAccess(fixed, 0, 'fixed messages');
const fixedMsg1 = validateArrayAccess(fixed, 1, 'fixed messages');
const fixedMsg2 = validateArrayAccess(fixed, 2, 'fixed messages');
const fixedMsg3 = validateArrayAccess(fixed, 3, 'fixed messages');
const fixedMsg4 = validateArrayAccess(fixed, 4, 'fixed messages');
const fixedMsg5 = validateArrayAccess(fixed, 5, 'fixed messages');
const fixedMsg6 = validateArrayAccess(fixed, 6, 'fixed messages');
const partialMsg0 = validateArrayAccess(partiallyBundled, 0, 'partially bundled messages');
const partialMsg1 = validateArrayAccess(partiallyBundled, 1, 'partially bundled messages');
const partialMsg2 = validateArrayAccess(partiallyBundled, 2, 'partially bundled messages');
expect(fixedMsg0).toEqual(partialMsg0);
expect(fixedMsg1).toEqual(partialMsg1);
expect(fixedMsg2).toEqual(partialMsg2);
// Second part should be unbundled
if (fixedMsg3.role === 'assistant' && Array.isArray(fixedMsg3.content)) {
const content = validateArrayAccess(fixedMsg3.content, 0, 'assistant content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('id2');
}
}
if (fixedMsg4.role === 'tool' && Array.isArray(fixedMsg4.content)) {
const content = validateArrayAccess(fixedMsg4.content, 0, 'tool content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('id2');
}
}
if (fixedMsg5.role === 'assistant' && Array.isArray(fixedMsg5.content)) {
const content = validateArrayAccess(fixedMsg5.content, 0, 'assistant content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('id3');
}
}
if (fixedMsg6.role === 'tool' && Array.isArray(fixedMsg6.content)) {
const content = validateArrayAccess(fixedMsg6.content, 0, 'tool content');
if ('toolCallId' in content) {
expect(content.toolCallId).toBe('id3');
}
}
});
test('verify already correct messages pass through unchanged', () => {
const correctlyFormatted: CoreMessage[] = [
{ role: 'user', content: 'Test' },
{
role: 'assistant',
content: [{ type: 'tool-call', toolCallId: 'id1', toolName: 'tool1', args: {} }],
},
{
role: 'tool',
content: [{ type: 'tool-result', toolCallId: 'id1', toolName: 'tool1', result: {} }],
},
{
role: 'assistant',
content: [{ type: 'tool-call', toolCallId: 'id2', toolName: 'tool2', args: {} }],
},
{
role: 'tool',
content: [{ type: 'tool-result', toolCallId: 'id2', toolName: 'tool2', result: {} }],
},
];
const result = extractMessageHistory(correctlyFormatted);
// Should be unchanged
expect(result).toEqual(correctlyFormatted);
expect(result).toHaveLength(5);
});
});