mirror of https://github.com/buster-so/buster.git
Refactor tool names and enhance context handling in AnalystAgent
- Updated tool name constants for visualization and communication tools to improve clarity and consistency. - Modified the AnalystAgent to utilize these constants, ensuring better context handling during tool calls. - Enhanced the STOP_CONDITIONS to reference the DONE_TOOL_NAME constant, improving maintainability. These changes streamline the integration of tools within the AnalystAgent, enhancing overall code organization and readability.
This commit is contained in:
parent
0f0a5ed7d1
commit
6c989b7c25
|
@ -12,6 +12,13 @@ import {
|
||||||
createModifyMetricsTool,
|
createModifyMetricsTool,
|
||||||
createModifyReportsTool,
|
createModifyReportsTool,
|
||||||
} from '../../tools';
|
} from '../../tools';
|
||||||
|
import { DONE_TOOL_NAME } from '../../tools/communication-tools/done-tool/done-tool';
|
||||||
|
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
|
||||||
|
import { MODIFY_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/modify-dashboards-tool/modify-dashboards-tool';
|
||||||
|
import { CREATE_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/create-metrics-tool/create-metrics-tool';
|
||||||
|
import { MODIFY_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-tool';
|
||||||
|
import { CREATE_REPORTS_TOOL_NAME } from '../../tools/visualization-tools/reports/create-reports-tool/create-reports-tool';
|
||||||
|
import { MODIFY_REPORTS_TOOL_NAME } from '../../tools/visualization-tools/reports/modify-reports-tool/modify-reports-tool';
|
||||||
import { healToolWithLlm } from '../../utils/tool-call-repair';
|
import { healToolWithLlm } from '../../utils/tool-call-repair';
|
||||||
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
|
import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
|
||||||
|
|
||||||
|
@ -23,7 +30,7 @@ const DEFAULT_CACHE_OPTIONS = {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall('doneTool')];
|
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall(DONE_TOOL_NAME)];
|
||||||
|
|
||||||
export const AnalystAgentOptionsSchema = z.object({
|
export const AnalystAgentOptionsSchema = z.object({
|
||||||
userId: z.string(),
|
userId: z.string(),
|
||||||
|
@ -72,14 +79,12 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
||||||
let attempt = 0;
|
let attempt = 0;
|
||||||
const currentMessages = [...messages];
|
const currentMessages = [...messages];
|
||||||
|
|
||||||
// Create tool instances - all visualization tools now accept context directly
|
|
||||||
const createMetrics = createCreateMetricsTool(analystAgentOptions);
|
const createMetrics = createCreateMetricsTool(analystAgentOptions);
|
||||||
const modifyMetrics = createModifyMetricsTool(analystAgentOptions);
|
const modifyMetrics = createModifyMetricsTool(analystAgentOptions);
|
||||||
const createDashboards = createCreateDashboardsTool(analystAgentOptions);
|
const createDashboards = createCreateDashboardsTool(analystAgentOptions);
|
||||||
const modifyDashboards = createModifyDashboardsTool(analystAgentOptions);
|
const modifyDashboards = createModifyDashboardsTool(analystAgentOptions);
|
||||||
const createReports = createCreateReportsTool(analystAgentOptions);
|
const createReports = createCreateReportsTool(analystAgentOptions);
|
||||||
const modifyReports = createModifyReportsTool(analystAgentOptions);
|
const modifyReports = createModifyReportsTool(analystAgentOptions);
|
||||||
// Done tool now accepts context directly
|
|
||||||
const doneTool = createDoneTool(analystAgentOptions);
|
const doneTool = createDoneTool(analystAgentOptions);
|
||||||
|
|
||||||
while (attempt <= maxRetries) {
|
while (attempt <= maxRetries) {
|
||||||
|
@ -89,13 +94,13 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
||||||
streamText({
|
streamText({
|
||||||
model: Sonnet4,
|
model: Sonnet4,
|
||||||
tools: {
|
tools: {
|
||||||
createMetrics,
|
[CREATE_METRICS_TOOL_NAME]: createMetrics,
|
||||||
modifyMetrics,
|
[MODIFY_METRICS_TOOL_NAME]: modifyMetrics,
|
||||||
createDashboards,
|
[CREATE_DASHBOARDS_TOOL_NAME]: createDashboards,
|
||||||
modifyDashboards,
|
[MODIFY_DASHBOARDS_TOOL_NAME]: modifyDashboards,
|
||||||
createReports,
|
[CREATE_REPORTS_TOOL_NAME]: createReports,
|
||||||
modifyReports,
|
[MODIFY_REPORTS_TOOL_NAME]: modifyReports,
|
||||||
doneTool,
|
[DONE_TOOL_NAME]: doneTool,
|
||||||
},
|
},
|
||||||
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
messages: [systemMessage, datasetsSystemMessage, ...currentMessages],
|
||||||
stopWhen: STOP_CONDITIONS,
|
stopWhen: STOP_CONDITIONS,
|
||||||
|
|
|
@ -5,6 +5,8 @@ import { createDoneToolExecute } from './done-tool-execute';
|
||||||
import { createDoneToolFinish } from './done-tool-finish';
|
import { createDoneToolFinish } from './done-tool-finish';
|
||||||
import { createDoneToolStart } from './done-tool-start';
|
import { createDoneToolStart } from './done-tool-start';
|
||||||
|
|
||||||
|
export const DONE_TOOL_NAME = 'doneTool';
|
||||||
|
|
||||||
export const DoneToolInputSchema = z.object({
|
export const DoneToolInputSchema = z.object({
|
||||||
final_response: z
|
final_response: z
|
||||||
.string()
|
.string()
|
||||||
|
|
|
@ -7,7 +7,7 @@ import { createCreateDashboardsExecute } from './create-dashboards-execute';
|
||||||
import { createCreateDashboardsFinish } from './create-dashboards-finish';
|
import { createCreateDashboardsFinish } from './create-dashboards-finish';
|
||||||
import { createDashboardsStart } from './create-dashboards-start';
|
import { createDashboardsStart } from './create-dashboards-start';
|
||||||
|
|
||||||
export const TOOL_NAME = 'createDashboards';
|
export const CREATE_DASHBOARDS_TOOL_NAME = 'createDashboards';
|
||||||
|
|
||||||
const CreateDashboardsInputFileSchema = z.object({
|
const CreateDashboardsInputFileSchema = z.object({
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
||||||
ChatMessageReasoningMessage_File,
|
ChatMessageReasoningMessage_File,
|
||||||
} from '@buster/server-shared/chats';
|
} from '@buster/server-shared/chats';
|
||||||
import type { ModelMessage } from 'ai';
|
import type { ModelMessage } from 'ai';
|
||||||
import { type CreateDashboardsState, TOOL_NAME } from '../create-dashboards-tool';
|
import { CREATE_DASHBOARD_TOOL_NAME, type CreateDashboardsState } from '../create-dashboards-tool';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a reasoning entry for create-dashboards tool
|
* Create a reasoning entry for create-dashboards tool
|
||||||
|
@ -67,7 +67,7 @@ export function createCreateDashboardsRawLlmMessageEntry(
|
||||||
{
|
{
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName: TOOL_NAME,
|
toolName: CREATE_DASHBOARD_TOOL_NAME,
|
||||||
input: {
|
input: {
|
||||||
files: state.files
|
files: state.files
|
||||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||||
|
|
|
@ -84,6 +84,8 @@ export type ModifyDashboardsOutputFailedFile = z.infer<
|
||||||
export type ModifyDashboardsState = z.infer<typeof ModifyDashboardsStateSchema>;
|
export type ModifyDashboardsState = z.infer<typeof ModifyDashboardsStateSchema>;
|
||||||
export type ModifyDashboardStateFile = z.infer<typeof ModifyDashboardStateFileSchema>;
|
export type ModifyDashboardStateFile = z.infer<typeof ModifyDashboardStateFileSchema>;
|
||||||
|
|
||||||
|
export const MODIFY_DASHBOARDS_TOOL_NAME = 'modifyDashboards';
|
||||||
|
|
||||||
// Factory function that accepts agent context and maps to tool context
|
// Factory function that accepts agent context and maps to tool context
|
||||||
export function createModifyDashboardsTool(context: ModifyDashboardsContext) {
|
export function createModifyDashboardsTool(context: ModifyDashboardsContext) {
|
||||||
// Initialize state for streaming
|
// Initialize state for streaming
|
||||||
|
|
|
@ -7,7 +7,7 @@ import { createCreateMetricsExecute } from './create-metrics-execute';
|
||||||
import { createCreateMetricsFinish } from './create-metrics-finish';
|
import { createCreateMetricsFinish } from './create-metrics-finish';
|
||||||
import { createCreateMetricsStart } from './create-metrics-start';
|
import { createCreateMetricsStart } from './create-metrics-start';
|
||||||
|
|
||||||
export const TOOL_NAME = 'createMetrics';
|
export const CREATE_METRICS_TOOL_NAME = 'createMetrics';
|
||||||
|
|
||||||
const CreateMetricsInputFileSchema = z.object({
|
const CreateMetricsInputFileSchema = z.object({
|
||||||
name: z
|
name: z
|
||||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
||||||
ChatMessageReasoningMessage_File,
|
ChatMessageReasoningMessage_File,
|
||||||
} from '@buster/server-shared/chats';
|
} from '@buster/server-shared/chats';
|
||||||
import type { ModelMessage } from 'ai';
|
import type { ModelMessage } from 'ai';
|
||||||
import { type CreateMetricsState, TOOL_NAME } from '../create-metrics-tool';
|
import { CREATE_METRICS_TOOL_NAME, type CreateMetricsState } from '../create-metrics-tool';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a reasoning entry for create-metrics tool
|
* Create a reasoning entry for create-metrics tool
|
||||||
|
@ -67,7 +67,7 @@ export function createCreateMetricsRawLlmMessageEntry(
|
||||||
{
|
{
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName: TOOL_NAME,
|
toolName: CREATE_METRICS_TOOL_NAME,
|
||||||
input: {
|
input: {
|
||||||
files: state.files
|
files: state.files
|
||||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
||||||
ChatMessageReasoningMessage_File,
|
ChatMessageReasoningMessage_File,
|
||||||
} from '@buster/server-shared/chats';
|
} from '@buster/server-shared/chats';
|
||||||
import type { ModelMessage } from 'ai';
|
import type { ModelMessage } from 'ai';
|
||||||
import { type ModifyMetricsState, TOOL_NAME } from '../modify-metrics-tool';
|
import { MODIFY_METRICS_TOOL_NAME, type ModifyMetricsState } from '../modify-metrics-tool';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a reasoning entry for modify-metrics tool
|
* Create a reasoning entry for modify-metrics tool
|
||||||
|
@ -84,7 +84,7 @@ export function createModifyMetricsRawLlmMessageEntry(
|
||||||
{
|
{
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName: TOOL_NAME,
|
toolName: MODIFY_METRICS_TOOL_NAME,
|
||||||
input: {
|
input: {
|
||||||
files: state.files
|
files: state.files
|
||||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||||
|
|
|
@ -7,7 +7,7 @@ import { createModifyMetricsExecute } from './modify-metrics-execute';
|
||||||
import { createModifyMetricsFinish } from './modify-metrics-finish';
|
import { createModifyMetricsFinish } from './modify-metrics-finish';
|
||||||
import { createModifyMetricsStart } from './modify-metrics-start';
|
import { createModifyMetricsStart } from './modify-metrics-start';
|
||||||
|
|
||||||
export const TOOL_NAME = 'modifyMetrics';
|
export const MODIFY_METRICS_TOOL_NAME = 'modifyMetrics';
|
||||||
|
|
||||||
const ModifyMetricsInputFileSchema = z.object({
|
const ModifyMetricsInputFileSchema = z.object({
|
||||||
id: z.string().describe('The UUID of the metric file to modify'),
|
id: z.string().describe('The UUID of the metric file to modify'),
|
||||||
|
|
|
@ -6,7 +6,7 @@ import { createCreateReportsExecute } from './create-reports-execute';
|
||||||
import { createCreateReportsFinish } from './create-reports-finish';
|
import { createCreateReportsFinish } from './create-reports-finish';
|
||||||
import { createReportsStart } from './create-reports-start';
|
import { createReportsStart } from './create-reports-start';
|
||||||
|
|
||||||
export const TOOL_NAME = 'createReports';
|
export const CREATE_REPORTS_TOOL_NAME = 'createReports';
|
||||||
|
|
||||||
const CreateReportsInputFileSchema = z.object({
|
const CreateReportsInputFileSchema = z.object({
|
||||||
name: z
|
name: z
|
||||||
|
|
|
@ -4,9 +4,9 @@ import type {
|
||||||
} from '@buster/server-shared/chats';
|
} from '@buster/server-shared/chats';
|
||||||
import type { ModelMessage } from 'ai';
|
import type { ModelMessage } from 'ai';
|
||||||
import {
|
import {
|
||||||
|
CREATE_REPORTS_TOOL_NAME,
|
||||||
type CreateReportStateFile,
|
type CreateReportStateFile,
|
||||||
type CreateReportsState,
|
type CreateReportsState,
|
||||||
TOOL_NAME,
|
|
||||||
} from '../create-reports-tool';
|
} from '../create-reports-tool';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -74,7 +74,7 @@ export function createCreateReportsRawLlmMessageEntry(
|
||||||
{
|
{
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName: TOOL_NAME,
|
toolName: CREATE_REPORTS_TOOL_NAME,
|
||||||
input: {
|
input: {
|
||||||
files: state.files
|
files: state.files
|
||||||
.filter((file) => file != null) // Filter out null/undefined entries first
|
.filter((file) => file != null) // Filter out null/undefined entries first
|
||||||
|
|
|
@ -3,7 +3,7 @@ import type {
|
||||||
ChatMessageReasoningMessage_File,
|
ChatMessageReasoningMessage_File,
|
||||||
} from '@buster/server-shared/chats';
|
} from '@buster/server-shared/chats';
|
||||||
import type { ModelMessage } from 'ai';
|
import type { ModelMessage } from 'ai';
|
||||||
import { type ModifyReportsState, TOOL_NAME } from '../modify-reports-tool';
|
import { MODIFY_REPORTS_TOOL_NAME, type ModifyReportsState } from '../modify-reports-tool';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a reasoning entry for modify-reports tool
|
* Create a reasoning entry for modify-reports tool
|
||||||
|
@ -66,7 +66,7 @@ export function createModifyReportsRawLlmMessageEntry(
|
||||||
{
|
{
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName: TOOL_NAME,
|
toolName: MODIFY_REPORTS_TOOL_NAME,
|
||||||
input: {
|
input: {
|
||||||
id: state.reportId,
|
id: state.reportId,
|
||||||
name: state.reportName ?? 'Untitled Report',
|
name: state.reportName ?? 'Untitled Report',
|
||||||
|
|
|
@ -6,7 +6,7 @@ import { createModifyReportsExecute } from './modify-reports-execute';
|
||||||
import { createModifyReportsFinish } from './modify-reports-finish';
|
import { createModifyReportsFinish } from './modify-reports-finish';
|
||||||
import { modifyReportsStart } from './modify-reports-start';
|
import { modifyReportsStart } from './modify-reports-start';
|
||||||
|
|
||||||
export const TOOL_NAME = 'modifyReports';
|
export const MODIFY_REPORTS_TOOL_NAME = 'modifyReports';
|
||||||
|
|
||||||
const ModifyReportsEditSchema = z.object({
|
const ModifyReportsEditSchema = z.object({
|
||||||
code_to_replace: z
|
code_to_replace: z
|
||||||
|
|
Loading…
Reference in New Issue