tools updating properly

This commit is contained in:
dal 2025-08-15 15:24:05 -06:00
parent e2757c1ad0
commit c476aebd47
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
10 changed files with 115 additions and 30 deletions

View File

@ -33,7 +33,7 @@ const DocsAgentOptionsSchema = z.object({
chatId: z.string(), chatId: z.string(),
dataSourceId: z.string(), dataSourceId: z.string(),
organizationId: z.string(), organizationId: z.string(),
messageId: z.string().optional(), messageId: z.string(),
sandbox: z sandbox: z
.custom<Sandbox>( .custom<Sandbox>(
(val) => { (val) => {
@ -55,8 +55,6 @@ export type DocsStreamOptions = z.infer<typeof DocsStreamOptionsSchema>;
export type DocsAgentContextWithSandbox = DocsAgentOptions & { sandbox: Sandbox }; export type DocsAgentContextWithSandbox = DocsAgentOptions & { sandbox: Sandbox };
export function createDocsAgent(docsAgentOptions: DocsAgentOptions) { export function createDocsAgent(docsAgentOptions: DocsAgentOptions) {
const steps: never[] = [];
const systemMessage = { const systemMessage = {
role: 'system', role: 'system',
content: getDocsAgentSystemPrompt(docsAgentOptions.folder_structure), content: getDocsAgentSystemPrompt(docsAgentOptions.folder_structure),
@ -159,12 +157,7 @@ export function createDocsAgent(docsAgentOptions: DocsAgentOptions) {
)(); )();
} }
async function getSteps() {
return steps;
}
return { return {
stream, stream,
getSteps,
}; };
} }

View File

@ -9,6 +9,7 @@ export const DocsAgentStepInputSchema = z.object({
todos: z.string().describe('The todos string'), todos: z.string().describe('The todos string'),
todoList: z.string().describe('The TODO list'), todoList: z.string().describe('The TODO list'),
message: z.string().describe('The user message'), message: z.string().describe('The user message'),
messageId: z.string().describe('The user message'),
organizationId: z.string().describe('The organization ID'), organizationId: z.string().describe('The organization ID'),
context: DocsAgentContextSchema.describe('The docs agent context'), context: DocsAgentContextSchema.describe('The docs agent context'),
repositoryTree: z.string().describe('The tree structure of the repository'), repositoryTree: z.string().describe('The tree structure of the repository'),
@ -73,7 +74,7 @@ export async function runDocsAgentStep(params: DocsAgentStepInput): Promise<void
chatId: Date.now().toString(), // Using current timestamp as chatId chatId: Date.now().toString(), // Using current timestamp as chatId
dataSourceId: dataSourceId || '', dataSourceId: dataSourceId || '',
organizationId: validatedParams.organizationId || '', organizationId: validatedParams.organizationId || '',
messageId: undefined, // Optional field messageId: validatedParams.messageId, // Optional field
sandbox: sandbox, // Pass sandbox for file tools sandbox: sandbox, // Pass sandbox for file tools
}); });

View File

@ -16,12 +16,14 @@ import {
type DashboardYml, type DashboardYml,
DashboardYmlSchema, DashboardYmlSchema,
} from '../../../../../../server-shared/src/dashboards/dashboard.types'; } from '../../../../../../server-shared/src/dashboards/dashboard.types';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper'; import { trackFileAssociations } from '../../file-tracking-helper';
import type { import {
CreateDashboardsContext, CREATE_DASHBOARDS_TOOL_NAME,
CreateDashboardsInput, type CreateDashboardsContext,
CreateDashboardsOutput, type CreateDashboardsInput,
CreateDashboardsState, type CreateDashboardsOutput,
type CreateDashboardsState,
} from './create-dashboards-tool'; } from './create-dashboards-tool';
import { import {
createCreateDashboardsRawLlmMessageEntry, createCreateDashboardsRawLlmMessageEntry,
@ -564,6 +566,13 @@ export function createCreateDashboardsExecute(
const reasoningEntry = createCreateDashboardsReasoningEntry(state, toolCallId); const reasoningEntry = createCreateDashboardsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateDashboardsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createCreateDashboardsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_DASHBOARDS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -574,7 +583,7 @@ export function createCreateDashboardsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {

View File

@ -10,6 +10,7 @@ import {
type DashboardYml, type DashboardYml,
DashboardYmlSchema, DashboardYmlSchema,
} from '../../../../../../server-shared/src/dashboards/dashboard.types'; } from '../../../../../../server-shared/src/dashboards/dashboard.types';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper'; import { trackFileAssociations } from '../../file-tracking-helper';
import { import {
createModifyDashboardsRawLlmMessageEntry, createModifyDashboardsRawLlmMessageEntry,
@ -21,6 +22,7 @@ import type {
ModifyDashboardsOutput, ModifyDashboardsOutput,
ModifyDashboardsState, ModifyDashboardsState,
} from './modify-dashboards-tool'; } from './modify-dashboards-tool';
import { MODIFY_DASHBOARDS_TOOL_NAME } from './modify-dashboards-tool';
// Type definitions // Type definitions
type DashboardWithMetadata = DashboardYml; type DashboardWithMetadata = DashboardYml;
@ -529,6 +531,13 @@ export function createModifyDashboardsExecute(
const reasoningEntry = createModifyDashboardsReasoningEntry(state, toolCallId); const reasoningEntry = createModifyDashboardsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyDashboardsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createModifyDashboardsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_DASHBOARDS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -539,7 +548,7 @@ export function createModifyDashboardsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {
@ -586,6 +595,13 @@ export function createModifyDashboardsExecute(
const reasoningEntry = createModifyDashboardsReasoningEntry(state, toolCallId); const reasoningEntry = createModifyDashboardsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyDashboardsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createModifyDashboardsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_DASHBOARDS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -596,7 +612,7 @@ export function createModifyDashboardsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {

View File

@ -15,6 +15,7 @@ import {
createPermissionErrorMessage, createPermissionErrorMessage,
validateSqlPermissions, validateSqlPermissions,
} from '../../../../utils/sql-permissions'; } from '../../../../utils/sql-permissions';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper'; import { trackFileAssociations } from '../../file-tracking-helper';
import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator'; import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator';
import { ensureTimeFrameQuoted } from '../helpers/time-frame-helper'; import { ensureTimeFrameQuoted } from '../helpers/time-frame-helper';
@ -24,6 +25,7 @@ import type {
CreateMetricsOutput, CreateMetricsOutput,
CreateMetricsState, CreateMetricsState,
} from './create-metrics-tool'; } from './create-metrics-tool';
import { CREATE_METRICS_TOOL_NAME } from './create-metrics-tool';
import { import {
createCreateMetricsRawLlmMessageEntry, createCreateMetricsRawLlmMessageEntry,
createCreateMetricsReasoningEntry, createCreateMetricsReasoningEntry,
@ -739,6 +741,13 @@ export function createCreateMetricsExecute(
const reasoningEntry = createCreateMetricsReasoningEntry(state, toolCallId); const reasoningEntry = createCreateMetricsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateMetricsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createCreateMetricsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -749,7 +758,7 @@ export function createCreateMetricsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {
@ -796,6 +805,13 @@ export function createCreateMetricsExecute(
const reasoningEntry = createCreateMetricsReasoningEntry(state, toolCallId); const reasoningEntry = createCreateMetricsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateMetricsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createCreateMetricsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -806,7 +822,7 @@ export function createCreateMetricsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {

View File

@ -16,6 +16,7 @@ import {
createPermissionErrorMessage, createPermissionErrorMessage,
validateSqlPermissions, validateSqlPermissions,
} from '../../../../utils/sql-permissions'; } from '../../../../utils/sql-permissions';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper'; import { trackFileAssociations } from '../../file-tracking-helper';
import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator'; import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator';
import { ensureTimeFrameQuoted } from '../helpers/time-frame-helper'; import { ensureTimeFrameQuoted } from '../helpers/time-frame-helper';
@ -29,6 +30,7 @@ import type {
ModifyMetricsOutput, ModifyMetricsOutput,
ModifyMetricsState, ModifyMetricsState,
} from './modify-metrics-tool'; } from './modify-metrics-tool';
import { MODIFY_METRICS_TOOL_NAME } from './modify-metrics-tool';
interface FileWithId { interface FileWithId {
id: string; id: string;
@ -811,6 +813,13 @@ export function createModifyMetricsExecute(
const reasoningEntry = createModifyMetricsReasoningEntry(state, toolCallId); const reasoningEntry = createModifyMetricsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyMetricsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createModifyMetricsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -821,7 +830,7 @@ export function createModifyMetricsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {
@ -864,6 +873,13 @@ export function createModifyMetricsExecute(
try { try {
const reasoningEntry = createModifyMetricsReasoningEntry(state, state.toolCallId); const reasoningEntry = createModifyMetricsReasoningEntry(state, state.toolCallId);
const rawLlmMessage = createModifyMetricsRawLlmMessageEntry(state, state.toolCallId); const rawLlmMessage = createModifyMetricsRawLlmMessageEntry(state, state.toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
state.toolCallId,
MODIFY_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -874,7 +890,7 @@ export function createModifyMetricsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {

View File

@ -2,6 +2,7 @@ import { randomUUID } from 'node:crypto';
import { db, updateMessageEntries } from '@buster/database'; import { db, updateMessageEntries } from '@buster/database';
import { assetPermissions, reportFiles } from '@buster/database'; import { assetPermissions, reportFiles } from '@buster/database';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper'; import { trackFileAssociations } from '../../file-tracking-helper';
import type { import type {
CreateReportsContext, CreateReportsContext,
@ -9,6 +10,7 @@ import type {
CreateReportsOutput, CreateReportsOutput,
CreateReportsState, CreateReportsState,
} from './create-reports-tool'; } from './create-reports-tool';
import { CREATE_REPORTS_TOOL_NAME } from './create-reports-tool';
import { import {
createCreateReportsRawLlmMessageEntry, createCreateReportsRawLlmMessageEntry,
createCreateReportsReasoningEntry, createCreateReportsReasoningEntry,
@ -400,6 +402,13 @@ export function createCreateReportsExecute(
const reasoningEntry = createCreateReportsReasoningEntry(state, toolCallId); const reasoningEntry = createCreateReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateReportsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createCreateReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_REPORTS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -410,7 +419,7 @@ export function createCreateReportsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {
@ -457,6 +466,13 @@ export function createCreateReportsExecute(
const reasoningEntry = createCreateReportsReasoningEntry(state, toolCallId); const reasoningEntry = createCreateReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateReportsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createCreateReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_REPORTS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -467,7 +483,7 @@ export function createCreateReportsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {

View File

@ -1,5 +1,6 @@
import { appendReportContent, replaceReportContent, updateMessageEntries } from '@buster/database'; import { appendReportContent, replaceReportContent, updateMessageEntries } from '@buster/database';
import { wrapTraced } from 'braintrust'; import { wrapTraced } from 'braintrust';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper'; import { trackFileAssociations } from '../../file-tracking-helper';
import { import {
createModifyReportsRawLlmMessageEntry, createModifyReportsRawLlmMessageEntry,
@ -11,6 +12,7 @@ import type {
ModifyReportsOutput, ModifyReportsOutput,
ModifyReportsState, ModifyReportsState,
} from './modify-reports-tool'; } from './modify-reports-tool';
import { MODIFY_REPORTS_TOOL_NAME } from './modify-reports-tool';
// Process a single edit operation // Process a single edit operation
async function processEditOperation( async function processEditOperation(
@ -283,6 +285,13 @@ export function createModifyReportsExecute(
const reasoningEntry = createModifyReportsReasoningEntry(state, toolCallId); const reasoningEntry = createModifyReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyReportsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createModifyReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_REPORTS_TOOL_NAME,
{
edits: state.edits,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -293,7 +302,7 @@ export function createModifyReportsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {
@ -343,6 +352,13 @@ export function createModifyReportsExecute(
const reasoningEntry = createModifyReportsReasoningEntry(state, toolCallId); const reasoningEntry = createModifyReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyReportsRawLlmMessageEntry(state, toolCallId); const rawLlmMessage = createModifyReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_REPORTS_TOOL_NAME,
{
edits: state.edits,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = { const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId, messageId: context.messageId,
@ -353,7 +369,7 @@ export function createModifyReportsExecute(
} }
if (rawLlmMessage) { if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage]; updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
} }
if (reasoningEntry || rawLlmMessage) { if (reasoningEntry || rawLlmMessage) {

View File

@ -1,3 +1,4 @@
import { randomUUID } from 'node:crypto';
import type { ModelMessage } from 'ai'; import type { ModelMessage } from 'ai';
import { z } from 'zod'; import { z } from 'zod';
import { DocsAgentContextSchema } from '../../agents/docs-agent/docs-agent-context'; import { DocsAgentContextSchema } from '../../agents/docs-agent/docs-agent-context';
@ -83,11 +84,15 @@ export async function runDocsAgentWorkflow(input: DocsAgentWorkflowInput): Promi
repositoryTree: treeResult.repositoryTree, repositoryTree: treeResult.repositoryTree,
}); });
// TODO: This is a temporary solution to get a messageId
const messageId = randomUUID();
// Step 4: Execute the docs agent with all the prepared data // Step 4: Execute the docs agent with all the prepared data
const _agentResult = await runDocsAgentStep({ const _agentResult = await runDocsAgentStep({
todos: todosResult.todos, todos: todosResult.todos,
todoList: todosResult.todos, // Using todos as todoList todoList: todosResult.todos, // Using todos as todoList
message: treeResult.message, message: treeResult.message,
messageId: messageId,
organizationId: treeResult.organizationId, organizationId: treeResult.organizationId,
context: treeResult.context, context: treeResult.context,
repositoryTree: treeResult.repositoryTree, repositoryTree: treeResult.repositoryTree,

View File

@ -1,9 +1,6 @@
import { describe, expect, it } from 'vitest'; import { describe, expect, it } from 'vitest';
import { import { ChatMessageSchema } from './chat-message.types';
ChatMessageSchema, import { ReasoningMessageSchema, ResponseMessageSchema } from '@buster/database';
ReasoningMessageSchema,
ResponseMessageSchema,
} from './chat-message.types';
describe('ChatMessageSchema', () => { describe('ChatMessageSchema', () => {
it('should parse a valid complete chat message', () => { it('should parse a valid complete chat message', () => {