diff --git a/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts b/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts index 1c6976ad2..42f5e0e57 100644 --- a/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts +++ b/packages/ai/src/agents/analyst-agent/analyst-agent.int.test.ts @@ -8,7 +8,12 @@ describe('Analyst Agent Integration Tests', () => { try { const analystAgent = createAnalystAgent({ - sql_dialect_guidance: 'postgresql', + userId: '123', + chatId: '123', + dataSourceId: '123', + dataSourceSyntax: 'postgresql', + organizationId: '123', + messageId: '123', }); const streamResult = await analystAgent.stream({ diff --git a/packages/ai/src/agents/analyst-agent/analyst-agent.ts b/packages/ai/src/agents/analyst-agent/analyst-agent.ts index 352be3545..585db9633 100644 --- a/packages/ai/src/agents/analyst-agent/analyst-agent.ts +++ b/packages/ai/src/agents/analyst-agent/analyst-agent.ts @@ -1,9 +1,14 @@ import { type ModelMessage, hasToolCall, stepCountIs, streamText } from 'ai'; import { wrapTraced } from 'braintrust'; import z from 'zod'; -import { createAnalystTools } from '../../tools/tool-factories'; +import { + createDashboards, + createMetrics, + doneTool, + modifyDashboards, + modifyMetrics, +} from '../../tools'; import { Sonnet4 } from '../../utils/models/sonnet-4'; -import { injectAgentContext } from '../helpers/context/agent-context-injection'; import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt'; const DEFAULT_CACHE_OPTIONS = { @@ -39,15 +44,18 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) { providerOptions: DEFAULT_CACHE_OPTIONS, } as ModelMessage; - // Create tools with session context baked in - const tools = createAnalystTools(analystAgentOptions); - async function stream({ messages }: AnalystStreamOptions) { return wrapTraced( () => streamText({ model: Sonnet4, - tools, + tools: { + createMetrics, + modifyMetrics, + createDashboards, + modifyDashboards, + doneTool, + }, messages: [systemMessage, ...messages], stopWhen: STOP_CONDITIONS, toolChoice: 'required', diff --git a/packages/ai/src/agents/helpers/context/agent-context-injection.ts b/packages/ai/src/agents/helpers/context/agent-context-injection.ts deleted file mode 100644 index dc7b0f520..000000000 --- a/packages/ai/src/agents/helpers/context/agent-context-injection.ts +++ /dev/null @@ -1,11 +0,0 @@ -/** - * Creates a context injection function for any agent prepareStep. - * This function returns the agent options as context that can be accessed by tools. - */ -export function injectAgentContext(agentOptions: T) { - return async () => { - return { - context: agentOptions, - }; - }; -} diff --git a/packages/ai/src/tools/database-tools/execute-sql.ts b/packages/ai/src/tools/database-tools/execute-sql.ts index 86c286597..97c7bfc2f 100644 --- a/packages/ai/src/tools/database-tools/execute-sql.ts +++ b/packages/ai/src/tools/database-tools/execute-sql.ts @@ -193,12 +193,9 @@ const executeSqlStatement = wrapTraced( } const dataSourceId = context.dataSourceId; - const workflowStartTime = context.get('workflowStartTime') as number | undefined; // Generate a unique workflow ID using start time and data source - const workflowId = workflowStartTime - ? `workflow-${workflowStartTime}-${dataSourceId}` - : `workflow-${Date.now()}-${dataSourceId}`; + const workflowId = `workflow-${Date.now()}-${dataSourceId}`; // Get data source from workflow manager (reuses existing connections) const manager = getWorkflowDataSourceManager(workflowId); @@ -211,7 +208,7 @@ const executeSqlStatement = wrapTraced( withRateLimit( 'sql-execution', async () => { - const result = await executeSingleStatement(sqlStatement, dataSource, runtimeContext); + const result = await executeSingleStatement(sqlStatement, dataSource, context); return { sql: sqlStatement, result }; }, { @@ -274,7 +271,7 @@ const executeSqlStatement = wrapTraced( async function executeSingleStatement( sqlStatement: string, dataSource: DataSource, - runtimeContext: RuntimeContext + runtimeContext: AnalystAgentOptions ): Promise<{ success: boolean; data?: Record[]; @@ -290,12 +287,12 @@ async function executeSingleStatement( } // Validate permissions before execution - const userId = runtimeContext.get('userId'); + const userId = runtimeContext.userId; if (!userId) { return { success: false, error: 'User authentication required for SQL execution' }; } - const dataSourceSyntax = runtimeContext.get('dataSourceSyntax'); + const dataSourceSyntax = runtimeContext.dataSourceSyntax; const permissionResult = await validateSqlPermissions(sqlStatement, userId, dataSourceSyntax); if (!permissionResult.isAuthorized) { return {