Refactor the Docs Agent to include additional context parameters and streamline SQL execution. Introduce user, chat, data source, organization, and optional message IDs in the DocsAgentOptions schema. Update createDocsAgent to pass the new context to tools via experimental_context. Remove the obsolete parseStreamingArgs function from execute-sql-docs-agent, since AI SDK v5 handles streaming argument parsing internally.
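
A minimal sketch of the new call pattern (illustrative only: the IDs and folder structure below are placeholder values, and the import path is relative to the caller):

import { createDocsAgent, type DocsAgentOptions } from '../../agents/docs-agent/docs-agent';

// Placeholder values -- real callers pass IDs taken from the workflow input.
const docsAgentOptions: DocsAgentOptions = {
  folder_structure: 'models/\n  staging/\n  marts/', // dbt repository tree
  userId: 'user-id',
  chatId: 'chat-id',
  dataSourceId: 'data-source-id',
  organizationId: 'organization-id',
  messageId: undefined, // optional
};

const docsAgent = createDocsAgent(docsAgentOptions);

The same options object is forwarded to each tool call as experimental_context, so every tool can read just the fields it needs (see the execute-sql-docs-agent changes below).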

dal 2025-08-06 17:10:25 -06:00
parent 04ae594d3a
commit 218bdf8eb3
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
6 changed files with 60 additions and 315 deletions
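
On the tool side, the converted tools follow the AI SDK v5 pattern shown in the diffs below. A minimal sketch (the tool name, schemas, and return value here are hypothetical; only the experimental_context handling mirrors this commit):

import { tool } from 'ai';
import { z } from 'zod';

// Each tool declares a small schema for just the context fields it needs.
const exampleContextSchema = z.object({
  dataSourceId: z.string(),
});

export const exampleDocsTool = tool({
  description: 'Hypothetical tool illustrating the experimental_context pattern',
  inputSchema: z.object({
    statements: z.array(z.string()),
  }),
  // AI SDK v5 passes the agent's experimental_context in the second argument.
  execute: async (input, { experimental_context }) => {
    const { dataSourceId } = exampleContextSchema.parse(experimental_context);
    return { dataSourceId, statementCount: input.statements.length };
  },
});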

View File

@@ -17,6 +17,7 @@ import {
webSearch,
} from '../../tools';
import { Sonnet4 } from '../../utils/models/sonnet-4';
import { healToolWithLlm } from '../../utils/tool-call-repair';
import { getDocsAgentSystemPrompt } from './get-docs-agent-system-prompt';
const DEFAULT_CACHE_OPTIONS = {
@@ -27,6 +28,11 @@ const STOP_CONDITIONS = [stepCountIs(50), hasToolCall('idleTool')];
const DocsAgentOptionsSchema = z.object({
folder_structure: z.string().describe('The file structure of the dbt repository'),
userId: z.string(),
chatId: z.string(),
dataSourceId: z.string(),
organizationId: z.string(),
messageId: z.string().optional(),
});
const DocsStreamOptionsSchema = z.object({
@@ -36,12 +42,12 @@ const DocsStreamOptionsSchema = z.object({
export type DocsAgentOptions = z.infer<typeof DocsAgentOptionsSchema>;
export type DocsStreamOptions = z.infer<typeof DocsStreamOptionsSchema>;
export function createDocsAgent(options: DocsAgentOptions) {
export function createDocsAgent(docsAgentOptions: DocsAgentOptions) {
const steps: never[] = [];
const systemMessage = {
role: 'system',
content: getDocsAgentSystemPrompt(options.folder_structure),
content: getDocsAgentSystemPrompt(docsAgentOptions.folder_structure),
providerOptions: DEFAULT_CACHE_OPTIONS,
} as ModelMessage;
@@ -70,6 +76,8 @@ export function createDocsAgent(options: DocsAgentOptions) {
toolChoice: 'required',
maxOutputTokens: 10000,
temperature: 0,
experimental_context: docsAgentOptions,
experimental_repairToolCall: healToolWithLlm,
}),
{
name: 'Docs Agent',

View File

@@ -93,9 +93,14 @@ const docsAgentExecution = async ({
}
}
// Create the docs agent with folder structure
// Create the docs agent with folder structure and context
const docsAgent = createDocsAgent({
folder_structure: inputData.repositoryTree,
userId: organizationId, // Using organizationId as userId for now
chatId: workflowStartTime?.toString() || 'unknown', // Using workflowStartTime as chatId
dataSourceId: dataSourceId || '',
organizationId: organizationId || '',
messageId: undefined, // Optional field
});
const userMessage = `${inputData.message}`;

View File

@@ -1,9 +1,8 @@
import { type DataSource, withRateLimit } from '@buster/data-source';
import type { RuntimeContext } from '@mastra/core/runtime-context';
import { createTool } from '@mastra/core/tools';
import { tool } from 'ai';
import { wrapTraced } from 'braintrust';
import { z } from 'zod';
import type { DocsAgentContext } from '../../agents/docs-agent/docs-agent-context';
import type { DocsAgentOptions } from '../../agents/docs-agent/docs-agent';
import { getWorkflowDataSourceManager } from '../../utils/data-source-manager';
import { checkQueryIsReadOnly } from '../../utils/sql-permissions/sql-parser-helpers';
@@ -25,6 +24,12 @@ const executeSqlDocsAgentInputSchema = z.object({
),
});
const executeSqlDocsAgentContextSchema = z.object({
dataSourceId: z.string().describe('ID of the data source to execute SQL against'),
});
type ExecuteSqlDocsAgentContext = z.infer<typeof executeSqlDocsAgentContextSchema>;
/**
* Processes a single column value for truncation
*/
@@ -71,81 +76,8 @@ function truncateQueryResults(
});
}
/**
* Optimistic parsing function for streaming execute-sql-docs-agent tool arguments
* Extracts the statements array as it's being built incrementally
*/
export function parseStreamingArgs(
accumulatedText: string
): Partial<z.infer<typeof executeSqlDocsAgentInputSchema>> | null {
// Validate input type
if (typeof accumulatedText !== 'string') {
throw new Error(`parseStreamingArgs expects string input, got ${typeof accumulatedText}`);
}
try {
// First try to parse as complete JSON
const parsed = JSON.parse(accumulatedText);
// Ensure statements is an array if present
if (parsed.statements !== undefined && !Array.isArray(parsed.statements)) {
console.warn('[execute-sql-docs-agent parseStreamingArgs] statements is not an array:', {
type: typeof parsed.statements,
value: parsed.statements,
});
return null; // Return null to indicate invalid parse
}
return {
statements: parsed.statements || undefined,
};
} catch (error) {
// Only catch JSON parse errors - let other errors bubble up
if (error instanceof SyntaxError) {
// JSON parsing failed - try regex extraction for partial content
// If JSON is incomplete, try to extract and reconstruct the statements array
const statementsMatch = accumulatedText.match(/"statements"\s*:\s*\[(.*)/s);
if (statementsMatch && statementsMatch[1] !== undefined) {
const arrayContent = statementsMatch[1];
try {
// Try to parse the array content by adding closing bracket
const testArray = `[${arrayContent}]`;
const parsed = JSON.parse(testArray);
return { statements: parsed };
} catch {
// If that fails, try to extract individual statement strings that are complete
const statements: string[] = [];
// Match complete string statements within the array
const statementMatches = arrayContent.matchAll(/"((?:[^"\\]|\\.)*)"/g);
for (const match of statementMatches) {
if (match[1] !== undefined) {
const statement = match[1].replace(/\\"/g, '"').replace(/\\\\/g, '\\');
statements.push(statement);
}
}
return { statements };
}
}
// Check if we at least have the start of the statements field
const partialMatch = accumulatedText.match(/"statements"\s*:\s*\[/);
if (partialMatch) {
return { statements: [] };
}
return null;
}
// Unexpected error - re-throw with context
throw new Error(
`Unexpected error in parseStreamingArgs: ${error instanceof Error ? error.message : 'Unknown error'}`
);
}
}
// Remove parseStreamingArgs as it's no longer needed with AI SDK v5
// The SDK handles streaming parsing internally
const executeSqlDocsAgentOutputSchema = z.object({
results: z.array(
@@ -167,7 +99,7 @@ const executeSqlDocsAgentOutputSchema = z.object({
const executeSqlDocsAgentStatement = wrapTraced(
async (
params: z.infer<typeof executeSqlDocsAgentInputSchema>,
runtimeContext: RuntimeContext<DocsAgentContext>
context: ExecuteSqlDocsAgentContext
): Promise<z.infer<typeof executeSqlDocsAgentOutputSchema>> => {
let { statements } = params;
@@ -273,7 +205,7 @@ const executeSqlDocsAgentStatement = wrapTraced(
};
}
const dataSourceId = runtimeContext.get('dataSourceId');
const dataSourceId = context.dataSourceId;
// Get data source from workflow manager (reuses existing connections)
const manager = getWorkflowDataSourceManager(dataSourceId);
@@ -455,9 +387,8 @@ async function executeSingleStatement(
};
}
// Export the tool
export const executeSqlDocsAgent = createTool({
id: 'execute-sql-docs-agent',
// Export the tool using AI SDK v5
export const executeSqlDocsAgent = tool({
description: `Use this to run lightweight validation and metadata queries for documentation purposes.
This tool is specifically for the docs agent to gather metadata, validate assumptions, and collect context.
Please limit your queries to 100 rows for performance.
@@ -467,14 +398,14 @@ export const executeSqlDocsAgent = createTool({
referential integrity checks, and match percentage calculations.`,
inputSchema: executeSqlDocsAgentInputSchema,
outputSchema: executeSqlDocsAgentOutputSchema,
execute: async ({
context,
runtimeContext,
}: {
context: z.infer<typeof executeSqlDocsAgentInputSchema>;
runtimeContext: RuntimeContext<DocsAgentContext>;
}) => {
return await executeSqlDocsAgentStatement(context, runtimeContext);
execute: async (input, { experimental_context: context }) => {
const rawContext = context as DocsAgentOptions;
const executeSqlDocsAgentContext = executeSqlDocsAgentContextSchema.parse({
dataSourceId: rawContext.dataSourceId,
});
return await executeSqlDocsAgentStatement(input, executeSqlDocsAgentContext);
},
});

View File

@@ -1,207 +0,0 @@
import { describe, expect, test } from 'vitest';
import { validateArrayAccess } from '../../utils/validation-helpers';
import { parseStreamingArgs } from './execute-sql';
describe('Execute SQL Tool Streaming Parser', () => {
test('should return null for empty or invalid input', () => {
expect(parseStreamingArgs('')).toBeNull();
expect(parseStreamingArgs('{')).toBeNull();
expect(parseStreamingArgs('invalid json')).toBeNull();
expect(parseStreamingArgs('{"other_field":')).toBeNull();
});
test('should parse complete JSON with statements array', () => {
const completeJson = JSON.stringify({
statements: [
'SELECT user_id, name FROM public.users LIMIT 25',
"SELECT COUNT(*) FROM public.orders WHERE created_at >= '2024-01-01'",
],
});
const result = parseStreamingArgs(completeJson);
expect(result).toEqual({
statements: [
'SELECT user_id, name FROM public.users LIMIT 25',
"SELECT COUNT(*) FROM public.orders WHERE created_at >= '2024-01-01'",
],
});
});
test('should extract partial statements array as it builds incrementally', () => {
// Simulate the streaming chunks building up a statements array
const chunks = [
'{"statements"',
'{"statements":',
'{"statements": [',
'{"statements": ["',
'{"statements": ["SELECT',
'{"statements": ["SELECT user_id',
'{"statements": ["SELECT user_id, name',
'{"statements": ["SELECT user_id, name FROM',
'{"statements": ["SELECT user_id, name FROM public.users"',
'{"statements": ["SELECT user_id, name FROM public.users",',
'{"statements": ["SELECT user_id, name FROM public.users", "',
'{"statements": ["SELECT user_id, name FROM public.users", "SELECT COUNT(*)"',
'{"statements": ["SELECT user_id, name FROM public.users", "SELECT COUNT(*)"]}',
];
// Test incremental building
expect(parseStreamingArgs(validateArrayAccess(chunks, 0, 'test chunks'))).toBeNull(); // No colon yet
expect(parseStreamingArgs(validateArrayAccess(chunks, 1, 'test chunks'))).toBeNull(); // No array start yet
expect(parseStreamingArgs(validateArrayAccess(chunks, 2, 'test chunks'))).toEqual({
statements: [],
}); // Empty array detected
expect(parseStreamingArgs(validateArrayAccess(chunks, 3, 'test chunks'))).toEqual({
statements: [],
}); // Incomplete string
expect(parseStreamingArgs(validateArrayAccess(chunks, 4, 'test chunks'))).toEqual({
statements: [],
}); // Still incomplete
expect(parseStreamingArgs(validateArrayAccess(chunks, 5, 'test chunks'))).toEqual({
statements: [],
}); // Still incomplete
expect(parseStreamingArgs(validateArrayAccess(chunks, 6, 'test chunks'))).toEqual({
statements: [],
}); // Still incomplete
expect(parseStreamingArgs(validateArrayAccess(chunks, 7, 'test chunks'))).toEqual({
statements: [],
}); // Still incomplete
expect(parseStreamingArgs(validateArrayAccess(chunks, 8, 'test chunks'))).toEqual({
statements: ['SELECT user_id, name FROM public.users'],
}); // First complete statement
expect(parseStreamingArgs(validateArrayAccess(chunks, 9, 'test chunks'))).toEqual({
statements: ['SELECT user_id, name FROM public.users'],
}); // Comma added
expect(parseStreamingArgs(validateArrayAccess(chunks, 10, 'test chunks'))).toEqual({
statements: ['SELECT user_id, name FROM public.users'],
}); // Second statement starting
expect(parseStreamingArgs(validateArrayAccess(chunks, 11, 'test chunks'))).toEqual({
statements: ['SELECT user_id, name FROM public.users', 'SELECT COUNT(*)'],
}); // Second statement complete
// Final complete chunk should be parsed as complete JSON
const finalResult = parseStreamingArgs(validateArrayAccess(chunks, 12, 'test chunks'));
expect(finalResult).toEqual({
statements: ['SELECT user_id, name FROM public.users', 'SELECT COUNT(*)'],
});
});
test('should handle single statement', () => {
const singleStatement = '{"statements": ["SELECT * FROM public.users"]}';
const result = parseStreamingArgs(singleStatement);
expect(result).toEqual({
statements: ['SELECT * FROM public.users'],
});
});
test('should handle escaped quotes in SQL statements', () => {
const withEscapedQuotes =
'{"statements": ["SELECT name FROM users WHERE status = \\"active\\""]}';
const result = parseStreamingArgs(withEscapedQuotes);
expect(result).toEqual({
statements: ['SELECT name FROM users WHERE status = "active"'],
});
});
test('should handle complex SQL with newlines and special characters', () => {
const complexSql = JSON.stringify({
statements: [
'SELECT \n u.user_id,\n u.name,\n COUNT(o.order_id) as order_count\nFROM public.users u\nLEFT JOIN public.orders o ON u.user_id = o.user_id\nGROUP BY u.user_id, u.name',
],
});
const result = parseStreamingArgs(complexSql);
expect(result).toEqual({
statements: [
'SELECT \n u.user_id,\n u.name,\n COUNT(o.order_id) as order_count\nFROM public.users u\nLEFT JOIN public.orders o ON u.user_id = o.user_id\nGROUP BY u.user_id, u.name',
],
});
});
test('should handle multiple statements being built incrementally', () => {
const partialMultiple = '{"statements": ["SELECT user_id FROM users", "SELECT';
const result = parseStreamingArgs(partialMultiple);
// Should extract the complete first statement only
expect(result).toEqual({
statements: ['SELECT user_id FROM users'],
});
});
test('should handle whitespace variations', () => {
const withWhitespace = '{ "statements" : [ "SELECT * FROM table" , "SELECT COUNT(*)" ]';
const result = parseStreamingArgs(withWhitespace);
expect(result).toEqual({
statements: ['SELECT * FROM table', 'SELECT COUNT(*)'],
});
});
test('should handle empty statements array', () => {
const emptyArray = '{"statements": []}';
const result = parseStreamingArgs(emptyArray);
expect(result).toEqual({
statements: [],
});
});
test('should handle statements with date literals and special characters', () => {
const withDates = JSON.stringify({
statements: [
"SELECT * FROM orders WHERE created_at >= '2024-01-01'",
'SELECT COUNT(*) FROM products WHERE price > 100.50',
],
});
const result = parseStreamingArgs(withDates);
expect(result).toEqual({
statements: [
"SELECT * FROM orders WHERE created_at >= '2024-01-01'",
'SELECT COUNT(*) FROM products WHERE price > 100.50',
],
});
});
test('should return undefined for statements if field is not present', () => {
const withoutStatements = '{"other_field": "value"}';
const result = parseStreamingArgs(withoutStatements);
expect(result).toEqual({
statements: undefined,
});
});
test('should handle incomplete array with partial second statement', () => {
const partialSecond = '{"statements": ["SELECT user_id FROM users", "SELECT COUNT(*) FROM';
const result = parseStreamingArgs(partialSecond);
// Should only return the complete first statement
expect(result).toEqual({
statements: ['SELECT user_id FROM users'],
});
});
test('should handle statements with schema qualifiers', () => {
const withSchema = JSON.stringify({
statements: [
'SELECT analytics.users.user_id FROM analytics.users',
'SELECT public.orders.order_id FROM public.orders WHERE public.orders.total > 100',
],
});
const result = parseStreamingArgs(withSchema);
expect(result).toEqual({
statements: [
'SELECT analytics.users.user_id FROM analytics.users',
'SELECT public.orders.order_id FROM public.orders WHERE public.orders.total > 100',
],
});
});
});

View File

@@ -5,7 +5,6 @@ import { z } from 'zod';
import type { AnalystAgentOptions } from '../../agents/analyst-agent/analyst-agent';
import { getWorkflowDataSourceManager } from '../../utils/data-source-manager';
import { createPermissionErrorMessage, validateSqlPermissions } from '../../utils/sql-permissions';
import type { AnalystRuntimeContext } from '../../workflows/analyst-workflow';
const executeSqlStatementInputSchema = z.object({
statements: z.array(z.string()).describe(
@@ -20,6 +19,14 @@ const executeSqlStatementInputSchema = z.object({
),
});
const executeSqlContextSchema = z.object({
dataSourceId: z.string().describe('ID of the data source to execute SQL against'),
userId: z.string().describe('ID of the user executing the SQL'),
dataSourceSyntax: z.string().describe('SQL syntax variant for the data source'),
});
type ExecuteSqlContext = z.infer<typeof executeSqlContextSchema>;
/**
* Processes a single column value for truncation
*/
@@ -89,7 +96,7 @@ const executeSqlStatementOutputSchema = z.object({
const executeSqlStatement = wrapTraced(
async (
params: z.infer<typeof executeSqlStatementInputSchema>,
context: AnalystAgentOptions
context: ExecuteSqlContext
): Promise<z.infer<typeof executeSqlStatementOutputSchema>> => {
let { statements } = params;
@@ -271,7 +278,7 @@ const executeSqlStatement = wrapTraced(
async function executeSingleStatement(
sqlStatement: string,
dataSource: DataSource,
runtimeContext: AnalystAgentOptions
context: ExecuteSqlContext
): Promise<{
success: boolean;
data?: Record<string, unknown>[];
@@ -287,12 +294,12 @@ async function executeSingleStatement(
}
// Validate permissions before execution
const userId = runtimeContext.userId;
const userId = context.userId;
if (!userId) {
return { success: false, error: 'User authentication required for SQL execution' };
}
const dataSourceSyntax = runtimeContext.dataSourceSyntax;
const dataSourceSyntax = context.dataSourceSyntax;
const permissionResult = await validateSqlPermissions(sqlStatement, userId, dataSourceSyntax);
if (!permissionResult.isAuthorized) {
return {
@@ -394,7 +401,15 @@ export const executeSql = tool({
inputSchema: executeSqlStatementInputSchema,
outputSchema: executeSqlStatementOutputSchema,
execute: async (input, { experimental_context: context }) => {
return await executeSqlStatement(input, context as AnalystAgentOptions);
const rawContext = context as AnalystAgentOptions;
const executeSqlContext = executeSqlContextSchema.parse({
dataSourceId: rawContext.dataSourceId,
userId: rawContext.userId,
dataSourceSyntax: rawContext.dataSourceSyntax,
});
return await executeSqlStatement(input, executeSqlContext);
},
});

View File

@@ -2,12 +2,11 @@ import { describe, expect, test } from 'vitest';
import { parseStreamingArgs as parseIdleArgs } from '../../tools/communication-tools/idle-tool';
// Note: Some tools have been converted to AI SDK v5 and no longer have parseStreamingArgs
// Only test tools that still have the parseStreamingArgs function
import { parseStreamingArgs as parseExecuteSqlDocsAgentArgs } from '../../tools/database-tools/execute-sql-docs-agent';
// execute-sql-docs-agent has been converted to AI SDK v5 and no longer has parseStreamingArgs
describe('Streaming Parser Error Handling', () => {
const parsers = [
// Only test tools that still have parseStreamingArgs function
{ name: 'execute-sql-docs-agent', parser: parseExecuteSqlDocsAgentArgs },
{ name: 'idle-tool', parser: parseIdleArgs },
];
@@ -73,12 +72,6 @@ describe('Streaming Parser Error Handling', () => {
});
describe('Successful Parsing (Should Work)', () => {
test('execute-sql-docs-agent should parse valid complete JSON', () => {
const validJson = '{"statements": ["SELECT * FROM test"]}';
const result = parseExecuteSqlDocsAgentArgs(validJson);
expect(result).toEqual({ statements: ['SELECT * FROM test'] });
});
test('idle-tool should parse valid complete JSON', () => {
const validJson = '{"final_response": "Test response"}';
const result = parseIdleArgs(validJson);