optimize dataset caching

This commit is contained in:
dal 2025-07-21 10:55:50 -06:00
parent 4dd59a299d
commit 8bb4e4ad8e
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
6 changed files with 119 additions and 6 deletions

View File

@ -314,3 +314,13 @@ export const getAnalystInstructions = async ({
sqlDialectGuidance, sqlDialectGuidance,
}); });
}; };
// Export the template function without dataset context for use in step files.
// Renders the full analyst instructions with an empty database context, then
// strips the leftover <database_context> wrapper so only instruction text remains.
export const createAnalystInstructionsWithoutDatasets = (sqlDialectGuidance: string): string => {
  const fullInstructions = createAnalystInstructions({
    databaseContext: '',
    sqlDialectGuidance,
  });
  const withoutContext = fullInstructions.replace(
    /<database_context>[\s\S]*?<\/database_context>/,
    ''
  );
  return withoutContext.trim();
};

View File

@ -23,7 +23,7 @@ const DEFAULT_OPTIONS = {
export const analystAgent = new Agent({ export const analystAgent = new Agent({
name: 'Analyst Agent', name: 'Analyst Agent',
instructions: getAnalystInstructions, instructions: '', // We control the system messages in the step at stream instantiation
model: anthropicCachedModel('claude-sonnet-4-20250514'), model: anthropicCachedModel('claude-sonnet-4-20250514'),
tools: { tools: {
createMetrics, createMetrics,

View File

@ -22,7 +22,7 @@ const DEFAULT_OPTIONS = {
export const thinkAndPrepAgent = new Agent({ export const thinkAndPrepAgent = new Agent({
name: 'Think and Prep Agent', name: 'Think and Prep Agent',
instructions: getThinkAndPrepInstructions, instructions: '', // We control the system messages in the step at stream instantiation
model: anthropicCachedModel('claude-sonnet-4-20250514'), model: anthropicCachedModel('claude-sonnet-4-20250514'),
tools: { tools: {
sequentialThinking, sequentialThinking,

View File

@ -567,3 +567,15 @@ export const getThinkAndPrepInstructions = async ({
sqlDialectGuidance, sqlDialectGuidance,
}); });
}; };
// Export the template function without dataset context for use in step files.
// Builds the full think-and-prep instructions with an empty database context and
// removes the empty <database_context> block before returning the trimmed text.
export const createThinkAndPrepInstructionsWithoutDatasets = (
  sqlDialectGuidance: string
): string => {
  const rendered = createThinkAndPrepInstructions({
    databaseContext: '',
    sqlDialectGuidance,
  });
  return rendered.replace(/<database_context>[\s\S]*?<\/database_context>/, '').trim();
};

View File

@ -4,11 +4,14 @@ import type { CoreMessage } from 'ai';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import { z } from 'zod'; import { z } from 'zod';
import { getPermissionedDatasets } from '@buster/access-controls';
import type { import type {
ChatMessageReasoningMessage, ChatMessageReasoningMessage,
ChatMessageResponseMessage, ChatMessageResponseMessage,
} from '@buster/server-shared/chats'; } from '@buster/server-shared/chats';
import { analystAgent } from '../agents/analyst-agent/analyst-agent'; import { analystAgent } from '../agents/analyst-agent/analyst-agent';
import { createAnalystInstructionsWithoutDatasets } from '../agents/analyst-agent/analyst-agent-instructions';
import { getSqlDialectGuidance } from '../agents/shared/sql-dialect-guidance';
import { ChunkProcessor } from '../utils/database/chunk-processor'; import { ChunkProcessor } from '../utils/database/chunk-processor';
import { import {
MessageHistorySchema, MessageHistorySchema,
@ -56,6 +59,10 @@ const outputSchema = z.object({
finalReasoningMessage: z.string().optional(), finalReasoningMessage: z.string().optional(),
}); });
// Anthropic prompt-caching provider options applied to the prepended system messages.
// `as const` keeps `type` as the literal 'ephemeral' instead of widening to `string`,
// so the object satisfies provider-metadata types that expect the literal value.
const DEFAULT_CACHE_OPTIONS = {
  anthropic: { cacheControl: { type: 'ephemeral' } },
} as const;
/** /**
* Transform reasoning/response history to match ChunkProcessor expected types * Transform reasoning/response history to match ChunkProcessor expected types
*/ */
@ -253,6 +260,28 @@ const analystExecution = async ({
let retryCount = 0; let retryCount = 0;
const maxRetries = 5; const maxRetries = 5;
// Get database context and SQL dialect guidance
const userId = runtimeContext.get('userId');
const dataSourceSyntax = runtimeContext.get('dataSourceSyntax');
const datasets = await getPermissionedDatasets(userId, 0, 1000);
// Extract yml_content from each dataset and join with separators
const assembledYmlContent = datasets
.map((dataset: { ymlFile: string | null | undefined }) => dataset.ymlFile)
.filter((content: string | null | undefined) => content !== null && content !== undefined)
.join('\n---\n');
// Get dialect-specific guidance
const sqlDialectGuidance = getSqlDialectGuidance(dataSourceSyntax);
// Wraps the assembled dataset YML in the <database_context> tag the prompt expects.
const createDatasetSystemMessage = (databaseContext: string): string =>
  ['<database_context>', databaseContext, '</database_context>'].join('\n');
// Initialize chunk processor with histories from previous step // Initialize chunk processor with histories from previous step
// IMPORTANT: Pass histories from think-and-prep to accumulate across steps // IMPORTANT: Pass histories from think-and-prep to accumulate across steps
const { reasoningHistory: transformedReasoning, responseHistory: transformedResponse } = const { reasoningHistory: transformedReasoning, responseHistory: transformedResponse } =
@ -359,8 +388,25 @@ const analystExecution = async ({
const wrappedStream = wrapTraced( const wrappedStream = wrapTraced(
async () => { async () => {
// Create system messages with dataset context and instructions
const systemMessages: CoreMessage[] = [
{
role: 'system',
content: createDatasetSystemMessage(assembledYmlContent),
providerOptions: DEFAULT_CACHE_OPTIONS,
},
{
role: 'system',
content: createAnalystInstructionsWithoutDatasets(sqlDialectGuidance),
providerOptions: DEFAULT_CACHE_OPTIONS,
},
];
// Combine system messages with conversation messages
const messagesWithSystem = [...systemMessages, ...messages];
// Create stream directly without retryableAgentStreamWithHealing // Create stream directly without retryableAgentStreamWithHealing
const stream = await analystAgent.stream(messages, { const stream = await analystAgent.stream(messagesWithSystem, {
toolCallStreaming: true, toolCallStreaming: true,
runtimeContext, runtimeContext,
maxRetries: 5, maxRetries: 5,
@ -442,7 +488,7 @@ const analystExecution = async ({
continue; continue;
} }
// Update messages for the retry // Update messages for the retry (without system messages)
messages = healedMessages; messages = healedMessages;
// Update chunk processor with the healed messages // Update chunk processor with the healed messages

View File

@ -1,10 +1,13 @@
import { getPermissionedDatasets } from '@buster/access-controls';
import type { ChatMessageReasoningMessage } from '@buster/server-shared/chats'; import type { ChatMessageReasoningMessage } from '@buster/server-shared/chats';
import { createStep } from '@mastra/core'; import { createStep } from '@mastra/core';
import type { RuntimeContext } from '@mastra/core/runtime-context'; import type { RuntimeContext } from '@mastra/core/runtime-context';
import type { CoreMessage } from 'ai'; import type { CoreMessage } from 'ai';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import { z } from 'zod'; import { z } from 'zod';
import { getSqlDialectGuidance } from '../agents/shared/sql-dialect-guidance';
import { thinkAndPrepAgent } from '../agents/think-and-prep-agent/think-and-prep-agent'; import { thinkAndPrepAgent } from '../agents/think-and-prep-agent/think-and-prep-agent';
import { createThinkAndPrepInstructionsWithoutDatasets } from '../agents/think-and-prep-agent/think-and-prep-instructions';
import type { thinkAndPrepWorkflowInputSchema } from '../schemas/workflow-schemas'; import type { thinkAndPrepWorkflowInputSchema } from '../schemas/workflow-schemas';
import { ChunkProcessor } from '../utils/database/chunk-processor'; import { ChunkProcessor } from '../utils/database/chunk-processor';
import { import {
@ -60,6 +63,10 @@ type BusterChatMessageResponse = z.infer<typeof BusterChatMessageResponseSchema>
const outputSchema = ThinkAndPrepOutputSchema; const outputSchema = ThinkAndPrepOutputSchema;
// Anthropic prompt-caching provider options applied to the prepended system messages.
// `as const` keeps `type` as the literal 'ephemeral' instead of widening to `string`,
// so the object satisfies provider-metadata types that expect the literal value.
const DEFAULT_CACHE_OPTIONS = {
  anthropic: { cacheControl: { type: 'ephemeral' } },
} as const;
// Helper function to create the result object // Helper function to create the result object
const createStepResult = ( const createStepResult = (
finished: boolean, finished: boolean,
@ -150,6 +157,27 @@ const thinkAndPrepExecution = async ({
); );
try { try {
// Get database context and SQL dialect guidance
const userId = runtimeContext.get('userId');
const dataSourceSyntax = runtimeContext.get('dataSourceSyntax');
const datasets = await getPermissionedDatasets(userId, 0, 1000);
// Extract yml_content from each dataset and join with separators
const assembledYmlContent = datasets
.map((dataset: { ymlFile: string | null | undefined }) => dataset.ymlFile)
.filter((content: string | null | undefined) => content !== null && content !== undefined)
.join('\n---\n');
// Get dialect-specific guidance
const sqlDialectGuidance = getSqlDialectGuidance(dataSourceSyntax);
// Wraps the assembled dataset YML in the <database_context> tag the prompt expects.
const createDatasetSystemMessage = (databaseContext: string): string =>
  ['<database_context>', databaseContext, '</database_context>'].join('\n');
const todos = inputData['create-todos'].todos; const todos = inputData['create-todos'].todos;
// Standardize messages from workflow inputs // Standardize messages from workflow inputs
@ -223,8 +251,25 @@ const thinkAndPrepExecution = async ({
const wrappedStream = wrapTraced( const wrappedStream = wrapTraced(
async () => { async () => {
// Create system messages with dataset context and instructions
const systemMessages: CoreMessage[] = [
{
role: 'system',
content: createDatasetSystemMessage(assembledYmlContent),
providerOptions: DEFAULT_CACHE_OPTIONS,
},
{
role: 'system',
content: createThinkAndPrepInstructionsWithoutDatasets(sqlDialectGuidance),
providerOptions: DEFAULT_CACHE_OPTIONS,
},
];
// Combine system messages with conversation messages
const messagesWithSystem = [...systemMessages, ...messages];
// Create stream directly without retryableAgentStreamWithHealing // Create stream directly without retryableAgentStreamWithHealing
const stream = await thinkAndPrepAgent.stream(messages, { const stream = await thinkAndPrepAgent.stream(messagesWithSystem, {
toolCallStreaming: true, toolCallStreaming: true,
runtimeContext, runtimeContext,
maxRetries: 5, maxRetries: 5,
@ -316,7 +361,7 @@ const thinkAndPrepExecution = async ({
continue; continue;
} }
// Update messages for the retry // Update messages for the retry (without system messages)
messages = healedMessages; messages = healedMessages;
// Update chunk processor with the healed messages // Update chunk processor with the healed messages