mirror of https://github.com/buster-so/buster.git
Refactor tool names and enhance context handling in AnalystAgent
- Updated tool name constants for visualization and communication tools to improve clarity and consistency. - Modified the AnalystAgent to utilize these constants, ensuring better context handling during tool calls. - Enhanced the STOP_CONDITIONS to reference the DONE_TOOL_NAME constant, improving maintainability. These changes streamline the integration of tools within the AnalystAgent, enhancing overall code organization and readability.
This commit is contained in:
parent
0f0a5ed7d1
commit
6c989b7c25
|
@ -12,6 +12,13 @@ import {
|
|||
createModifyMetricsTool,
|
||||
createModifyReportsTool,
|
||||
} from '../../tools';
|
||||
import { DONE_TOOL_NAME } from '../../tools/communication-tools/done-tool/done-tool';
|
||||
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
|
||||
import { MODIFY_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/modify-dashboards-tool/modify-dashboards-tool';
|
||||
import { CREATE_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/create-metrics-tool/create-metrics-tool';
|
||||
import { MODIFY_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-tool';
|
||||
import { CREATE_REPORTS_TOOL_NAME } from '../../tools/visualization-tools/reports/create-reports-tool/create-reports-tool';
|
||||
import { MODIFY_REPORTS_TOOL_NAME } from '../../tools/visualization-tools/reports/modify-reports-tool/modify-reports-tool';
|
||||
import { healToolWithLlm } from '../../utils/tool-call-repair';
|
||||
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
|
||||
|
||||
|
@ -23,7 +30,7 @@ const DEFAULT_CACHE_OPTIONS = {
|
|||
},
|
||||
};
|
||||
|
||||
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall('doneTool')];
|
||||
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall(DONE_TOOL_NAME)];
|
||||
|
||||
export const AnalystAgentOptionsSchema = z.object({
|
||||
userId: z.string(),
|
||||
|
@ -72,14 +79,12 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
let attempt = 0;
|
||||
const currentMessages = [...messages];
|
||||
|
||||
// Create tool instances - all visualization tools now accept context directly
|
||||
const createMetrics = createCreateMetricsTool(analystAgentOptions);
|
||||
const modifyMetrics = createModifyMetricsTool(analystAgentOptions);
|
||||
const createDashboards = createCreateDashboardsTool(analystAgentOptions);
|
||||
const modifyDashboards = createModifyDashboardsTool(analystAgentOptions);
|
||||
const createReports = createCreateReportsTool(analystAgentOptions);
|
||||
const modifyReports = createModifyReportsTool(analystAgentOptions);
|
||||
// Done tool now accepts context directly
|
||||
const doneTool = createDoneTool(analystAgentOptions);
|
||||
|
||||
while (attempt <= maxRetries) {
|
||||
|
@ -89,13 +94,13 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
streamText({
|
||||
model: Sonnet4,
|
||||
tools: {
|
||||
createMetrics,
|
||||
modifyMetrics,
|
||||
createDashboards,
|
||||
modifyDashboards,
|
||||
createReports,
|
||||
modifyReports,
|
||||
doneTool,
|
||||
[CREATE_METRICS_TOOL_NAME]: createMetrics,
|
||||
[MODIFY_METRICS_TOOL_NAME]: modifyMetrics,
|
||||
[CREATE_DASHBOARDS_TOOL_NAME]: createDashboards,
|
||||
[MODIFY_DASHBOARDS_TOOL_NAME]: modifyDashboards,
|
||||
[CREATE_REPORTS_TOOL_NAME]: createReports,
|
||||
[MODIFY_REPORTS_TOOL_NAME]: modifyReports,
|
||||
[DONE_TOOL_NAME]: doneTool,
|
||||
},
|
||||
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
||||
stopWhen: STOP_CONDITIONS,
|
||||
|
|
|
@ -5,6 +5,8 @@ import { createDoneToolExecute } from './done-tool-execute';
|
|||
import { createDoneToolFinish } from './done-tool-finish';
|
||||
import { createDoneToolStart } from './done-tool-start';
|
||||
|
||||
export const DONE_TOOL_NAME = 'doneTool';
|
||||
|
||||
export const DoneToolInputSchema = z.object({
|
||||
final_response: z
|
||||
.string()
|
||||
|
|
|
@ -7,7 +7,7 @@ import { createCreateDashboardsExecute } from './create-dashboards-execute';
|
|||
import { createCreateDashboardsFinish } from './create-dashboards-finish';
|
||||
import { createDashboardsStart } from './create-dashboards-start';
|
||||
|
||||
export const TOOL_NAME = 'createDashboards';
|
||||
export const CREATE_DASHBOARDS_TOOL_NAME = 'createDashboards';
|
||||
|
||||
const CreateDashboardsInputFileSchema = z.object({
|
||||
name: z.string(),
|
||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
|||
ChatMessageReasoningMessage_File,
|
||||
} from '@buster/server-shared/chats';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { type CreateDashboardsState, TOOL_NAME } from '../create-dashboards-tool';
|
||||
import { CREATE_DASHBOARD_TOOL_NAME, type CreateDashboardsState } from '../create-dashboards-tool';
|
||||
|
||||
/**
|
||||
* Create a reasoning entry for create-dashboards tool
|
||||
|
@ -67,7 +67,7 @@ export function createCreateDashboardsRawLlmMessageEntry(
|
|||
{
|
||||
type: 'tool-call',
|
||||
toolCallId,
|
||||
toolName: TOOL_NAME,
|
||||
toolName: CREATE_DASHBOARD_TOOL_NAME,
|
||||
input: {
|
||||
files: state.files
|
||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||
|
|
|
@ -84,6 +84,8 @@ export type ModifyDashboardsOutputFailedFile = z.infer<
|
|||
export type ModifyDashboardsState = z.infer<typeof ModifyDashboardsStateSchema>;
|
||||
export type ModifyDashboardStateFile = z.infer<typeof ModifyDashboardStateFileSchema>;
|
||||
|
||||
export const MODIFY_DASHBOARDS_TOOL_NAME = 'modifyDashboards';
|
||||
|
||||
// Factory function that accepts agent context and maps to tool context
|
||||
export function createModifyDashboardsTool(context: ModifyDashboardsContext) {
|
||||
// Initialize state for streaming
|
||||
|
|
|
@ -7,7 +7,7 @@ import { createCreateMetricsExecute } from './create-metrics-execute';
|
|||
import { createCreateMetricsFinish } from './create-metrics-finish';
|
||||
import { createCreateMetricsStart } from './create-metrics-start';
|
||||
|
||||
export const TOOL_NAME = 'createMetrics';
|
||||
export const CREATE_METRICS_TOOL_NAME = 'createMetrics';
|
||||
|
||||
const CreateMetricsInputFileSchema = z.object({
|
||||
name: z
|
||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
|||
ChatMessageReasoningMessage_File,
|
||||
} from '@buster/server-shared/chats';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { type CreateMetricsState, TOOL_NAME } from '../create-metrics-tool';
|
||||
import { CREATE_METRICS_TOOL_NAME, type CreateMetricsState } from '../create-metrics-tool';
|
||||
|
||||
/**
|
||||
* Create a reasoning entry for create-metrics tool
|
||||
|
@ -67,7 +67,7 @@ export function createCreateMetricsRawLlmMessageEntry(
|
|||
{
|
||||
type: 'tool-call',
|
||||
toolCallId,
|
||||
toolName: TOOL_NAME,
|
||||
toolName: CREATE_METRICS_TOOL_NAME,
|
||||
input: {
|
||||
files: state.files
|
||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
|||
ChatMessageReasoningMessage_File,
|
||||
} from '@buster/server-shared/chats';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { type ModifyMetricsState, TOOL_NAME } from '../modify-metrics-tool';
|
||||
import { MODIFY_METRICS_TOOL_NAME, type ModifyMetricsState } from '../modify-metrics-tool';
|
||||
|
||||
/**
|
||||
* Create a reasoning entry for modify-metrics tool
|
||||
|
@ -84,7 +84,7 @@ export function createModifyMetricsRawLlmMessageEntry(
|
|||
{
|
||||
type: 'tool-call',
|
||||
toolCallId,
|
||||
toolName: TOOL_NAME,
|
||||
toolName: MODIFY_METRICS_TOOL_NAME,
|
||||
input: {
|
||||
files: state.files
|
||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||
|
|
|
@ -7,7 +7,7 @@ import { createModifyMetricsExecute } from './modify-metrics-execute';
|
|||
import { createModifyMetricsFinish } from './modify-metrics-finish';
|
||||
import { createModifyMetricsStart } from './modify-metrics-start';
|
||||
|
||||
export const TOOL_NAME = 'modifyMetrics';
|
||||
export const MODIFY_METRICS_TOOL_NAME = 'modifyMetrics';
|
||||
|
||||
const ModifyMetricsInputFileSchema = z.object({
|
||||
id: z.string().describe('The UUID of the metric file to modify'),
|
||||
|
|
|
@ -6,7 +6,7 @@ import { createCreateReportsExecute } from './create-reports-execute';
|
|||
import { createCreateReportsFinish } from './create-reports-finish';
|
||||
import { createReportsStart } from './create-reports-start';
|
||||
|
||||
export const TOOL_NAME = 'createReports';
|
||||
export const CREATE_REPORTS_TOOL_NAME = 'createReports';
|
||||
|
||||
const CreateReportsInputFileSchema = z.object({
|
||||
name: z
|
||||
|
|
|
@ -4,9 +4,9 @@ import type {
|
|||
} from '@buster/server-shared/chats';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import {
|
||||
CREATE_REPORTS_TOOL_NAME,
|
||||
type CreateReportStateFile,
|
||||
type CreateReportsState,
|
||||
TOOL_NAME,
|
||||
} from '../create-reports-tool';
|
||||
|
||||
/**
|
||||
|
@ -74,7 +74,7 @@ export function createCreateReportsRawLlmMessageEntry(
|
|||
{
|
||||
type: 'tool-call',
|
||||
toolCallId,
|
||||
toolName: TOOL_NAME,
|
||||
toolName: CREATE_REPORTS_TOOL_NAME,
|
||||
input: {
|
||||
files: state.files
|
||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
|||
ChatMessageReasoningMessage_File,
|
||||
} from '@buster/server-shared/chats';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { type ModifyReportsState, TOOL_NAME } from '../modify-reports-tool';
|
||||
import { MODIFY_REPORTS_TOOL_NAME, type ModifyReportsState } from '../modify-reports-tool';
|
||||
|
||||
/**
|
||||
* Create a reasoning entry for modify-reports tool
|
||||
|
@ -66,7 +66,7 @@ export function createModifyReportsRawLlmMessageEntry(
|
|||
{
|
||||
type: 'tool-call',
|
||||
toolCallId,
|
||||
toolName: TOOL_NAME,
|
||||
toolName: MODIFY_REPORTS_TOOL_NAME,
|
||||
input: {
|
||||
id: state.reportId,
|
||||
name: state.reportName ?? 'Untitled Report',
|
||||
|
|
|
@ -6,7 +6,7 @@ import { createModifyReportsExecute } from './modify-reports-execute';
|
|||
import { createModifyReportsFinish } from './modify-reports-finish';
|
||||
import { modifyReportsStart } from './modify-reports-start';
|
||||
|
||||
export const TOOL_NAME = 'modifyReports';
|
||||
export const MODIFY_REPORTS_TOOL_NAME = 'modifyReports';
|
||||
|
||||
const ModifyReportsEditSchema = z.object({
|
||||
code_to_replace: z
|
||||
|
|
Loading…
Reference in New Issue