tools updating properly

This commit is contained in:
dal 2025-08-15 15:24:05 -06:00
parent e2757c1ad0
commit c476aebd47
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
10 changed files with 115 additions and 30 deletions

View File

@ -33,7 +33,7 @@ const DocsAgentOptionsSchema = z.object({
chatId: z.string(),
dataSourceId: z.string(),
organizationId: z.string(),
messageId: z.string().optional(),
messageId: z.string(),
sandbox: z
.custom<Sandbox>(
(val) => {
@ -55,8 +55,6 @@ export type DocsStreamOptions = z.infer<typeof DocsStreamOptionsSchema>;
export type DocsAgentContextWithSandbox = DocsAgentOptions & { sandbox: Sandbox };
export function createDocsAgent(docsAgentOptions: DocsAgentOptions) {
const steps: never[] = [];
const systemMessage = {
role: 'system',
content: getDocsAgentSystemPrompt(docsAgentOptions.folder_structure),
@ -159,12 +157,7 @@ export function createDocsAgent(docsAgentOptions: DocsAgentOptions) {
)();
}
async function getSteps() {
return steps;
}
return {
stream,
getSteps,
};
}

View File

@ -9,6 +9,7 @@ export const DocsAgentStepInputSchema = z.object({
todos: z.string().describe('The todos string'),
todoList: z.string().describe('The TODO list'),
message: z.string().describe('The user message'),
messageId: z.string().describe('The user message'),
organizationId: z.string().describe('The organization ID'),
context: DocsAgentContextSchema.describe('The docs agent context'),
repositoryTree: z.string().describe('The tree structure of the repository'),
@ -73,7 +74,7 @@ export async function runDocsAgentStep(params: DocsAgentStepInput): Promise<void
chatId: Date.now().toString(), // Using current timestamp as chatId
dataSourceId: dataSourceId || '',
organizationId: validatedParams.organizationId || '',
messageId: undefined, // Optional field
messageId: validatedParams.messageId, // Optional field
sandbox: sandbox, // Pass sandbox for file tools
});

View File

@ -16,12 +16,14 @@ import {
type DashboardYml,
DashboardYmlSchema,
} from '../../../../../../server-shared/src/dashboards/dashboard.types';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper';
import type {
CreateDashboardsContext,
CreateDashboardsInput,
CreateDashboardsOutput,
CreateDashboardsState,
import {
CREATE_DASHBOARDS_TOOL_NAME,
type CreateDashboardsContext,
type CreateDashboardsInput,
type CreateDashboardsOutput,
type CreateDashboardsState,
} from './create-dashboards-tool';
import {
createCreateDashboardsRawLlmMessageEntry,
@ -564,6 +566,13 @@ export function createCreateDashboardsExecute(
const reasoningEntry = createCreateDashboardsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateDashboardsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_DASHBOARDS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -574,7 +583,7 @@ export function createCreateDashboardsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {

View File

@ -10,6 +10,7 @@ import {
type DashboardYml,
DashboardYmlSchema,
} from '../../../../../../server-shared/src/dashboards/dashboard.types';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper';
import {
createModifyDashboardsRawLlmMessageEntry,
@ -21,6 +22,7 @@ import type {
ModifyDashboardsOutput,
ModifyDashboardsState,
} from './modify-dashboards-tool';
import { MODIFY_DASHBOARDS_TOOL_NAME } from './modify-dashboards-tool';
// Type definitions
type DashboardWithMetadata = DashboardYml;
@ -529,6 +531,13 @@ export function createModifyDashboardsExecute(
const reasoningEntry = createModifyDashboardsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyDashboardsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_DASHBOARDS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -539,7 +548,7 @@ export function createModifyDashboardsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {
@ -586,6 +595,13 @@ export function createModifyDashboardsExecute(
const reasoningEntry = createModifyDashboardsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyDashboardsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_DASHBOARDS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -596,7 +612,7 @@ export function createModifyDashboardsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {

View File

@ -15,6 +15,7 @@ import {
createPermissionErrorMessage,
validateSqlPermissions,
} from '../../../../utils/sql-permissions';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper';
import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator';
import { ensureTimeFrameQuoted } from '../helpers/time-frame-helper';
@ -24,6 +25,7 @@ import type {
CreateMetricsOutput,
CreateMetricsState,
} from './create-metrics-tool';
import { CREATE_METRICS_TOOL_NAME } from './create-metrics-tool';
import {
createCreateMetricsRawLlmMessageEntry,
createCreateMetricsReasoningEntry,
@ -739,6 +741,13 @@ export function createCreateMetricsExecute(
const reasoningEntry = createCreateMetricsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateMetricsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -749,7 +758,7 @@ export function createCreateMetricsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {
@ -796,6 +805,13 @@ export function createCreateMetricsExecute(
const reasoningEntry = createCreateMetricsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateMetricsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -806,7 +822,7 @@ export function createCreateMetricsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {

View File

@ -16,6 +16,7 @@ import {
createPermissionErrorMessage,
validateSqlPermissions,
} from '../../../../utils/sql-permissions';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper';
import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator';
import { ensureTimeFrameQuoted } from '../helpers/time-frame-helper';
@ -29,6 +30,7 @@ import type {
ModifyMetricsOutput,
ModifyMetricsState,
} from './modify-metrics-tool';
import { MODIFY_METRICS_TOOL_NAME } from './modify-metrics-tool';
interface FileWithId {
id: string;
@ -811,6 +813,13 @@ export function createModifyMetricsExecute(
const reasoningEntry = createModifyMetricsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyMetricsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -821,7 +830,7 @@ export function createModifyMetricsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {
@ -864,6 +873,13 @@ export function createModifyMetricsExecute(
try {
const reasoningEntry = createModifyMetricsReasoningEntry(state, state.toolCallId);
const rawLlmMessage = createModifyMetricsRawLlmMessageEntry(state, state.toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
state.toolCallId,
MODIFY_METRICS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -874,7 +890,7 @@ export function createModifyMetricsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {

View File

@ -2,6 +2,7 @@ import { randomUUID } from 'node:crypto';
import { db, updateMessageEntries } from '@buster/database';
import { assetPermissions, reportFiles } from '@buster/database';
import { wrapTraced } from 'braintrust';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper';
import type {
CreateReportsContext,
@ -9,6 +10,7 @@ import type {
CreateReportsOutput,
CreateReportsState,
} from './create-reports-tool';
import { CREATE_REPORTS_TOOL_NAME } from './create-reports-tool';
import {
createCreateReportsRawLlmMessageEntry,
createCreateReportsReasoningEntry,
@ -400,6 +402,13 @@ export function createCreateReportsExecute(
const reasoningEntry = createCreateReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_REPORTS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -410,7 +419,7 @@ export function createCreateReportsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {
@ -457,6 +466,13 @@ export function createCreateReportsExecute(
const reasoningEntry = createCreateReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createCreateReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
CREATE_REPORTS_TOOL_NAME,
{
files: state.files,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -467,7 +483,7 @@ export function createCreateReportsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {

View File

@ -1,5 +1,6 @@
import { appendReportContent, replaceReportContent, updateMessageEntries } from '@buster/database';
import { wrapTraced } from 'braintrust';
import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
import { trackFileAssociations } from '../../file-tracking-helper';
import {
createModifyReportsRawLlmMessageEntry,
@ -11,6 +12,7 @@ import type {
ModifyReportsOutput,
ModifyReportsState,
} from './modify-reports-tool';
import { MODIFY_REPORTS_TOOL_NAME } from './modify-reports-tool';
// Process a single edit operation
async function processEditOperation(
@ -283,6 +285,13 @@ export function createModifyReportsExecute(
const reasoningEntry = createModifyReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_REPORTS_TOOL_NAME,
{
edits: state.edits,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -293,7 +302,7 @@ export function createModifyReportsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {
@ -343,6 +352,13 @@ export function createModifyReportsExecute(
const reasoningEntry = createModifyReportsReasoningEntry(state, toolCallId);
const rawLlmMessage = createModifyReportsRawLlmMessageEntry(state, toolCallId);
const rawLlmResultEntry = createRawToolResultEntry(
toolCallId,
MODIFY_REPORTS_TOOL_NAME,
{
edits: state.edits,
}
);
const updates: Parameters<typeof updateMessageEntries>[0] = {
messageId: context.messageId,
@ -353,7 +369,7 @@ export function createModifyReportsExecute(
}
if (rawLlmMessage) {
updates.rawLlmMessages = [rawLlmMessage];
updates.rawLlmMessages = [rawLlmMessage, rawLlmResultEntry];
}
if (reasoningEntry || rawLlmMessage) {

View File

@ -1,3 +1,4 @@
import { randomUUID } from 'node:crypto';
import type { ModelMessage } from 'ai';
import { z } from 'zod';
import { DocsAgentContextSchema } from '../../agents/docs-agent/docs-agent-context';
@ -83,11 +84,15 @@ export async function runDocsAgentWorkflow(input: DocsAgentWorkflowInput): Promi
repositoryTree: treeResult.repositoryTree,
});
// TODO: This is a temporary solution to get a messageId
const messageId = randomUUID();
// Step 4: Execute the docs agent with all the prepared data
const _agentResult = await runDocsAgentStep({
todos: todosResult.todos,
todoList: todosResult.todos, // Using todos as todoList
message: treeResult.message,
messageId: messageId,
organizationId: treeResult.organizationId,
context: treeResult.context,
repositoryTree: treeResult.repositoryTree,

View File

@ -1,9 +1,6 @@
import { describe, expect, it } from 'vitest';
import {
ChatMessageSchema,
ReasoningMessageSchema,
ResponseMessageSchema,
} from './chat-message.types';
import { ChatMessageSchema } from './chat-message.types';
import { ReasoningMessageSchema, ResponseMessageSchema } from '@buster/database';
describe('ChatMessageSchema', () => {
it('should parse a valid complete chat message', () => {