Refactor tool names and enhance context handling in AnalystAgent

- Updated tool name constants for visualization and communication tools to improve clarity and consistency.
- Modified the AnalystAgent to utilize these constants, ensuring better context handling during tool calls.
- Enhanced the STOP_CONDITIONS to reference the DONE_TOOL_NAME constant, improving maintainability.

These changes streamline the integration of tools within the AnalystAgent, enhancing overall code organization and readability.
This commit is contained in:
dal 2025-08-12 23:05:45 -06:00
parent 0f0a5ed7d1
commit 6c989b7c25
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
13 changed files with 34 additions and 25 deletions

View File

@ -12,6 +12,13 @@ import {
createModifyMetricsTool,
createModifyReportsTool,
} from '../../tools';
import { DONE_TOOL_NAME } from '../../tools/communication-tools/done-tool/done-tool';
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
import { MODIFY_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/modify-dashboards-tool/modify-dashboards-tool';
import { CREATE_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/create-metrics-tool/create-metrics-tool';
import { MODIFY_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-tool';
import { CREATE_REPORTS_TOOL_NAME } from '../../tools/visualization-tools/reports/create-reports-tool/create-reports-tool';
import { MODIFY_REPORTS_TOOL_NAME } from '../../tools/visualization-tools/reports/modify-reports-tool/modify-reports-tool';
import { healToolWithLlm } from '../../utils/tool-call-repair';
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
@ -23,7 +30,7 @@ const DEFAULT_CACHE_OPTIONS = {
},
};
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall('doneTool')];
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall(DONE_TOOL_NAME)];
export const AnalystAgentOptionsSchema = z.object({
userId: z.string(),
@ -72,14 +79,12 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
let attempt = 0;
const currentMessages = [...messages];
// Create tool instances - all visualization tools now accept context directly
const createMetrics = createCreateMetricsTool(analystAgentOptions);
const modifyMetrics = createModifyMetricsTool(analystAgentOptions);
const createDashboards = createCreateDashboardsTool(analystAgentOptions);
const modifyDashboards = createModifyDashboardsTool(analystAgentOptions);
const createReports = createCreateReportsTool(analystAgentOptions);
const modifyReports = createModifyReportsTool(analystAgentOptions);
// Done tool now accepts context directly
const doneTool = createDoneTool(analystAgentOptions);
while (attempt <= maxRetries) {
@ -89,13 +94,13 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
streamText({
model: Sonnet4,
tools: {
createMetrics,
modifyMetrics,
createDashboards,
modifyDashboards,
createReports,
modifyReports,
doneTool,
[CREATE_METRICS_TOOL_NAME]: createMetrics,
[MODIFY_METRICS_TOOL_NAME]: modifyMetrics,
[CREATE_DASHBOARDS_TOOL_NAME]: createDashboards,
[MODIFY_DASHBOARDS_TOOL_NAME]: modifyDashboards,
[CREATE_REPORTS_TOOL_NAME]: createReports,
[MODIFY_REPORTS_TOOL_NAME]: modifyReports,
[DONE_TOOL_NAME]: doneTool,
},
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
stopWhen: STOP_CONDITIONS,

View File

@ -5,6 +5,8 @@ import { createDoneToolExecute } from './done-tool-execute';
import { createDoneToolFinish } from './done-tool-finish';
import { createDoneToolStart } from './done-tool-start';
export const DONE_TOOL_NAME = 'doneTool';
export const DoneToolInputSchema = z.object({
final_response: z
.string()

View File

@ -7,7 +7,7 @@ import { createCreateDashboardsExecute } from './create-dashboards-execute';
import { createCreateDashboardsFinish } from './create-dashboards-finish';
import { createDashboardsStart } from './create-dashboards-start';
export const TOOL_NAME = 'createDashboards';
export const CREATE_DASHBOARDS_TOOL_NAME = 'createDashboards';
const CreateDashboardsInputFileSchema = z.object({
name: z.string(),

View File

@ -3,7 +3,7 @@ import type {
ChatMessageReasoningMessage_File,
} from '@buster/server-shared/chats';
import type { ModelMessage } from 'ai';
import { type CreateDashboardsState, TOOL_NAME } from '../create-dashboards-tool';
import { CREATE_DASHBOARD_TOOL_NAME, type CreateDashboardsState } from '../create-dashboards-tool';
/**
* Create a reasoning entry for create-dashboards tool
@ -67,7 +67,7 @@ export function createCreateDashboardsRawLlmMessageEntry(
{
type: 'tool-call',
toolCallId,
toolName: TOOL_NAME,
toolName: CREATE_DASHBOARD_TOOL_NAME,
input: {
files: state.files
.filter((file) => file != null) // Filter out null/undefined entries first

View File

@ -84,6 +84,8 @@ export type ModifyDashboardsOutputFailedFile = z.infer<
export type ModifyDashboardsState = z.infer<typeof ModifyDashboardsStateSchema>;
export type ModifyDashboardStateFile = z.infer<typeof ModifyDashboardStateFileSchema>;
export const MODIFY_DASHBOARDS_TOOL_NAME = 'modifyDashboards';
// Factory function that accepts agent context and maps to tool context
export function createModifyDashboardsTool(context: ModifyDashboardsContext) {
// Initialize state for streaming

View File

@ -7,7 +7,7 @@ import { createCreateMetricsExecute } from './create-metrics-execute';
import { createCreateMetricsFinish } from './create-metrics-finish';
import { createCreateMetricsStart } from './create-metrics-start';
export const TOOL_NAME = 'createMetrics';
export const CREATE_METRICS_TOOL_NAME = 'createMetrics';
const CreateMetricsInputFileSchema = z.object({
name: z

View File

@ -3,7 +3,7 @@ import type {
ChatMessageReasoningMessage_File,
} from '@buster/server-shared/chats';
import type { ModelMessage } from 'ai';
import { type CreateMetricsState, TOOL_NAME } from '../create-metrics-tool';
import { CREATE_METRICS_TOOL_NAME, type CreateMetricsState } from '../create-metrics-tool';
/**
* Create a reasoning entry for create-metrics tool
@ -67,7 +67,7 @@ export function createCreateMetricsRawLlmMessageEntry(
{
type: 'tool-call',
toolCallId,
toolName: TOOL_NAME,
toolName: CREATE_METRICS_TOOL_NAME,
input: {
files: state.files
.filter((file) => file != null) // Filter out null/undefined entries first

View File

@ -3,7 +3,7 @@ import type {
ChatMessageReasoningMessage_File,
} from '@buster/server-shared/chats';
import type { ModelMessage } from 'ai';
import { type ModifyMetricsState, TOOL_NAME } from '../modify-metrics-tool';
import { MODIFY_METRICS_TOOL_NAME, type ModifyMetricsState } from '../modify-metrics-tool';
/**
* Create a reasoning entry for modify-metrics tool
@ -84,7 +84,7 @@ export function createModifyMetricsRawLlmMessageEntry(
{
type: 'tool-call',
toolCallId,
toolName: TOOL_NAME,
toolName: MODIFY_METRICS_TOOL_NAME,
input: {
files: state.files
.filter((file) => file != null) // Filter out null/undefined entries first

View File

@ -7,7 +7,7 @@ import { createModifyMetricsExecute } from './modify-metrics-execute';
import { createModifyMetricsFinish } from './modify-metrics-finish';
import { createModifyMetricsStart } from './modify-metrics-start';
export const TOOL_NAME = 'modifyMetrics';
export const MODIFY_METRICS_TOOL_NAME = 'modifyMetrics';
const ModifyMetricsInputFileSchema = z.object({
id: z.string().describe('The UUID of the metric file to modify'),

View File

@ -6,7 +6,7 @@ import { createCreateReportsExecute } from './create-reports-execute';
import { createCreateReportsFinish } from './create-reports-finish';
import { createReportsStart } from './create-reports-start';
export const TOOL_NAME = 'createReports';
export const CREATE_REPORTS_TOOL_NAME = 'createReports';
const CreateReportsInputFileSchema = z.object({
name: z

View File

@ -4,9 +4,9 @@ import type {
} from '@buster/server-shared/chats';
import type { ModelMessage } from 'ai';
import {
CREATE_REPORTS_TOOL_NAME,
type CreateReportStateFile,
type CreateReportsState,
TOOL_NAME,
} from '../create-reports-tool';
/**
@ -74,7 +74,7 @@ export function createCreateReportsRawLlmMessageEntry(
{
type: 'tool-call',
toolCallId,
toolName: TOOL_NAME,
toolName: CREATE_REPORTS_TOOL_NAME,
input: {
files: state.files
.filter((file) => file != null) // Filter out null/undefined entries first

View File

@ -3,7 +3,7 @@ import type {
ChatMessageReasoningMessage_File,
} from '@buster/server-shared/chats';
import type { ModelMessage } from 'ai';
import { type ModifyReportsState, TOOL_NAME } from '../modify-reports-tool';
import { MODIFY_REPORTS_TOOL_NAME, type ModifyReportsState } from '../modify-reports-tool';
/**
* Create a reasoning entry for modify-reports tool
@ -66,7 +66,7 @@ export function createModifyReportsRawLlmMessageEntry(
{
type: 'tool-call',
toolCallId,
toolName: TOOL_NAME,
toolName: MODIFY_REPORTS_TOOL_NAME,
input: {
id: state.reportId,
name: state.reportName ?? 'Untitled Report',

View File

@ -6,7 +6,7 @@ import { createModifyReportsExecute } from './modify-reports-execute';
import { createModifyReportsFinish } from './modify-reports-finish';
import { modifyReportsStart } from './modify-reports-start';
export const TOOL_NAME = 'modifyReports';
export const MODIFY_REPORTS_TOOL_NAME = 'modifyReports';
const ModifyReportsEditSchema = z.object({
code_to_replace: z