Merge pull request #1247 from buster-so/staging

Staging
This commit is contained in:
dal 2025-10-02 10:08:04 -06:00 committed by GitHub
commit 86417ff057
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 3442 additions and 248 deletions

View File

@ -1,22 +1,16 @@
import type { GetDashboardResponse } from '@buster/server-shared/dashboards';
import { type QueryClient, useQueryClient } from '@tanstack/react-query';
import last from 'lodash/last';
import { dashboardQueryKeys } from '@/api/query_keys/dashboard';
import { metricsQueryKeys } from '@/api/query_keys/metric';
import { useBusterNotifications } from '@/context/BusterNotifications';
import { setOriginalDashboard } from '@/context/Dashboards/useOriginalDashboardStore';
import { setOriginalMetric } from '@/context/Metrics/useOriginalMetricStore';
import { useMemoizedFn } from '@/hooks/useMemoizedFn';
import { upgradeMetricToIMetric } from '@/lib/metrics/upgradeToIMetric';
import { prefetchGetMetricDataClient } from '../metrics/queryRequests';
import { initializeMetrics } from '../metrics/metricQueryHelpers';
import { getDashboardById } from './requests';
export const useEnsureDashboardConfig = (params?: { prefetchData?: boolean }) => {
const { prefetchData = true } = params || {};
const queryClient = useQueryClient();
const prefetchDashboard = useGetDashboardAndInitializeMetrics({
prefetchData,
});
const { openErrorMessage } = useBusterNotifications();
const method = useMemoizedFn(
@ -24,11 +18,13 @@ export const useEnsureDashboardConfig = (params?: { prefetchData?: boolean }) =>
const options = dashboardQueryKeys.dashboardGetDashboard(dashboardId, 'LATEST');
let dashboardResponse = queryClient.getQueryData(options.queryKey);
if (!dashboardResponse) {
const res = await prefetchDashboard({
const res = await getDashboardAndInitializeMetrics({
id: dashboardId,
version_number: 'LATEST',
shouldInitializeMetrics: initializeMetrics,
password,
queryClient,
shouldInitializeMetrics: initializeMetrics,
prefetchMetricsData: prefetchData,
}).catch(() => {
openErrorMessage('Failed to save metrics to dashboard. Dashboard not found');
});
@ -48,63 +44,6 @@ export const useEnsureDashboardConfig = (params?: { prefetchData?: boolean }) =>
return method;
};
export const initializeMetrics = (
metrics: GetDashboardResponse['metrics'],
queryClient: QueryClient,
prefetchData: boolean
) => {
for (const metric of Object.values(metrics)) {
const upgradedMetric = upgradeMetricToIMetric(metric, null);
queryClient.setQueryData(
metricsQueryKeys.metricsGetMetric(metric.id, metric.version_number).queryKey,
upgradedMetric
);
const isLatestVersion = metric.version_number === last(metric.versions)?.version_number;
if (isLatestVersion) {
setOriginalMetric(upgradedMetric);
queryClient.setQueryData(
metricsQueryKeys.metricsGetMetric(metric.id, 'LATEST').queryKey,
upgradedMetric
);
}
if (prefetchData) {
prefetchGetMetricDataClient(
{ id: metric.id, version_number: metric.version_number },
queryClient
);
}
}
};
export const useGetDashboardAndInitializeMetrics = (params?: { prefetchData?: boolean }) => {
const { prefetchData = true } = params || {};
const queryClient = useQueryClient();
return useMemoizedFn(
async ({
id,
version_number,
shouldInitializeMetrics = true,
password,
}: {
id: string;
version_number: number | 'LATEST';
shouldInitializeMetrics?: boolean;
password: string | undefined;
}) => {
return getDashboardAndInitializeMetrics({
id,
version_number,
password,
queryClient,
shouldInitializeMetrics,
prefetchMetricsData: prefetchData,
});
}
);
};
//Can use this in server side
export const getDashboardAndInitializeMetrics = async ({
id,
version_number,

View File

@ -3,7 +3,7 @@ import { useNavigate } from '@tanstack/react-router';
import last from 'lodash/last';
import { dashboardQueryKeys } from '@/api/query_keys/dashboard';
import { setOriginalDashboard } from '@/context/Dashboards/useOriginalDashboardStore';
import { initializeMetrics } from '../dashboardQueryHelpers';
import { initializeMetrics } from '../../metrics/metricQueryHelpers';
import { dashboardsUpdateDashboard } from '../requests';
/**

View File

@ -15,10 +15,7 @@ import {
import { useMemoizedFn } from '@/hooks/useMemoizedFn';
import { isQueryStale } from '@/lib/query';
import { hasOrganizationId } from '../../users/userQueryHelpers';
import {
getDashboardAndInitializeMetrics,
useGetDashboardAndInitializeMetrics,
} from '../dashboardQueryHelpers';
import { getDashboardAndInitializeMetrics } from '../dashboardQueryHelpers';
import { useGetDashboardVersionNumber } from '../dashboardVersionNumber';
import { dashboardsGetList } from '../requests';
@ -36,14 +33,22 @@ export const useGetDashboard = <TData = GetDashboardResponse>(
) => {
const id = idProp || '';
const password = useProtectedAssetPassword(id);
const queryFn = useGetDashboardAndInitializeMetrics();
const queryClient = useQueryClient();
// const queryFn = useGetDashboardAndInitializeMetrics();
const { selectedVersionNumber } = useGetDashboardVersionNumber(id, versionNumberProp);
const { isFetched: isFetchedInitial, isError: isErrorInitial } = useQuery({
...dashboardQueryKeys.dashboardGetDashboard(id, 'LATEST'),
queryFn: () =>
queryFn({ id, version_number: 'LATEST', shouldInitializeMetrics: true, password }),
getDashboardAndInitializeMetrics({
id,
version_number: 'LATEST',
password,
queryClient,
shouldInitializeMetrics: true,
prefetchMetricsData: true,
}),
enabled: true,
retry(_failureCount, error) {
if (error?.message !== undefined) {
@ -61,11 +66,13 @@ export const useGetDashboard = <TData = GetDashboardResponse>(
return useQuery({
...dashboardQueryKeys.dashboardGetDashboard(id, selectedVersionNumber),
queryFn: () =>
queryFn({
getDashboardAndInitializeMetrics({
id,
version_number: selectedVersionNumber,
shouldInitializeMetrics: true,
password,
queryClient,
shouldInitializeMetrics: true,
prefetchMetricsData: true,
}),
enabled: isFetchedInitial && !isErrorInitial,
select: params?.select,
@ -80,7 +87,6 @@ export const usePrefetchGetDashboardClient = <TData = GetDashboardResponse>(
params?: Omit<UseQueryOptions<GetDashboardResponse, ApiError, TData>, 'queryKey' | 'queryFn'>
) => {
const queryClient = useQueryClient();
const queryFn = useGetDashboardAndInitializeMetrics({ prefetchData: false });
return useMemoizedFn((id: string, versionNumber: number | 'LATEST') => {
const getDashboardQueryKey = dashboardQueryKeys.dashboardGetDashboard(id, versionNumber);
@ -89,11 +95,13 @@ export const usePrefetchGetDashboardClient = <TData = GetDashboardResponse>(
return queryClient.prefetchQuery({
...dashboardQueryKeys.dashboardGetDashboard(id, versionNumber),
queryFn: () =>
queryFn({
getDashboardAndInitializeMetrics({
id,
version_number: versionNumber,
shouldInitializeMetrics: true,
version_number: 'LATEST',
password: undefined,
queryClient,
shouldInitializeMetrics: true,
prefetchMetricsData: false,
}),
...params,
});

View File

@ -1,8 +1,14 @@
import { useQueryClient } from '@tanstack/react-query';
import type { GetDashboardResponse } from '@buster/server-shared/dashboards';
import type { GetReportResponse } from '@buster/server-shared/reports';
import { type QueryClient, useQueryClient } from '@tanstack/react-query';
import last from 'lodash/last';
import { setOriginalMetric } from '@/context/Metrics/useOriginalMetricStore';
import { useMemoizedFn } from '@/hooks/useMemoizedFn';
import { upgradeMetricToIMetric } from '@/lib/metrics/upgradeToIMetric';
import type { BusterMetricDataExtended } from '../../asset_interfaces/metric';
import type { BusterMetric } from '../../asset_interfaces/metric/interfaces';
import { metricsQueryKeys } from '../../query_keys/metric';
import { prefetchGetMetricDataClient } from './getMetricQueryRequests';
export const useGetMetricMemoized = () => {
const queryClient = useQueryClient();
@ -37,3 +43,31 @@ export const useGetMetricDataMemoized = () => {
);
return getMetricDataMemoized;
};
export const initializeMetrics = (
metrics: GetDashboardResponse['metrics'] | GetReportResponse['metrics'],
queryClient: QueryClient,
prefetchData: boolean
) => {
for (const metric of Object.values(metrics)) {
const upgradedMetric = upgradeMetricToIMetric(metric, null);
queryClient.setQueryData(
metricsQueryKeys.metricsGetMetric(metric.id, metric.version_number).queryKey,
upgradedMetric
);
const isLatestVersion = metric.version_number === last(metric.versions)?.version_number;
if (isLatestVersion) {
setOriginalMetric(upgradedMetric);
queryClient.setQueryData(
metricsQueryKeys.metricsGetMetric(metric.id, 'LATEST').queryKey,
upgradedMetric
);
}
if (prefetchData) {
prefetchGetMetricDataClient(
{ id: metric.id, version_number: metric.version_number },
queryClient
);
}
}
};

View File

@ -17,7 +17,8 @@ import {
useRemoveAssetFromCollection,
} from '../collections/queryRequests';
import { useGetUserFavorites } from '../users/favorites';
import { getReportById, getReportsList, updateReport } from './requests';
import { getReportAndInitializeMetrics } from './reportQueryHelpers';
import { getReportsList, updateReport } from './requests';
/**
* Hook to get a list of reports
@ -80,9 +81,13 @@ export const prefetchGetReport = async (
await queryClient.prefetchQuery({
...reportsQueryKeys.reportsGetReport(reportId, version_number || 'LATEST'),
queryFn: () =>
getReportById({
getReportAndInitializeMetrics({
id: reportId,
version_number: typeof version_number === 'number' ? version_number : undefined,
password: undefined,
queryClient,
shouldInitializeMetrics: true,
prefetchMetricsData: false,
}),
retry: silenceAssetErrors,
});
@ -106,17 +111,20 @@ export const useGetReport = <T = GetReportResponse>(
options?: Omit<UseQueryOptions<GetReportResponse, ApiError, T>, 'queryKey' | 'queryFn'>
) => {
const password = useProtectedAssetPassword(id || '');
const queryFn = () => {
return getReportById({
id: id ?? '',
version_number: typeof versionNumber === 'number' ? versionNumber : undefined,
password,
});
};
const queryClient = useQueryClient();
return useQuery({
...reportsQueryKeys.reportsGetReport(id ?? '', versionNumber || 'LATEST'),
queryFn,
queryFn: () => {
return getReportAndInitializeMetrics({
id: id ?? '',
version_number: typeof versionNumber === 'number' ? versionNumber : undefined,
password,
queryClient,
shouldInitializeMetrics: true,
prefetchMetricsData: true,
});
},
enabled: !!id,
select: options?.select,
...options,

View File

@ -0,0 +1,48 @@
import type { QueryClient } from '@tanstack/react-query';
import last from 'lodash/last';
import { reportsQueryKeys } from '@/api/query_keys/reports';
import { initializeMetrics } from '../metrics/metricQueryHelpers';
import { getReportById } from './requests';
export const getReportAndInitializeMetrics = async ({
id,
version_number,
password,
queryClient,
shouldInitializeMetrics = true,
prefetchMetricsData = false,
}: {
id: string;
version_number: number | 'LATEST' | undefined;
password: string | undefined;
queryClient: QueryClient;
shouldInitializeMetrics: boolean;
prefetchMetricsData: boolean;
}) => {
const chosenVersionNumber = version_number === 'LATEST' ? undefined : version_number;
return getReportById({
id,
version_number: chosenVersionNumber,
password,
}).then((data) => {
const latestVersion = last(data.versions)?.version_number || 1;
const isLatestVersion = data.version_number === latestVersion;
if (isLatestVersion) {
// set the original report?
}
if (data.version_number) {
queryClient.setQueryData(
reportsQueryKeys.reportsGetReport(data.id, data.version_number).queryKey,
data
);
}
if (shouldInitializeMetrics || prefetchMetricsData) {
initializeMetrics(data.metrics, queryClient, !!prefetchMetricsData);
}
return data;
});
};

View File

@ -203,7 +203,6 @@ export const useXAxis = ({
return formatLabel(value, firstXColumnLabelFormat);
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- I had a devil of a time trying to type this... This is a hack to get the type to work
return DEFAULT_X_AXIS_TICK_CALLBACK.call(
this,
value,

View File

@ -1,13 +1,9 @@
import {
type ColumnLabelFormat,
DEFAULT_COLUMN_LABEL_FORMAT,
DEFAULT_COLUMN_SETTINGS,
} from '@buster/server-shared/metrics';
import { type ColumnLabelFormat, DEFAULT_COLUMN_SETTINGS } from '@buster/server-shared/metrics';
import { describe, expect, it } from 'vitest';
import type { BusterChartProps } from '../../../BusterChart.types';
import type { DatasetOption } from '../../../chartHooks';
import type { DatasetOptionsWithTicks } from '../../../chartHooks/useDatasetOptions/interfaces';
import { barSeriesBuilder } from './barSeriesBuilder';
import { barSeriesBuilder, barSeriesBuilder_labels } from './barSeriesBuilder';
import type { SeriesBuilderProps } from './interfaces';
describe('barSeriesBuilder', () => {
@ -313,3 +309,210 @@ describe('percentage mode logic', () => {
expect(result).toBe(false);
});
});
describe('barSeriesBuilder_labels - dateTicks', () => {
it('should return date ticks when columnLabelFormat has date style and single xAxis', () => {
// Arrange
const mockDatasetOptions: DatasetOptionsWithTicks = {
datasets: [],
ticks: [['2024-01-01'], ['2024-02-01'], ['2024-03-01']],
ticksKey: [{ key: 'date', value: '' }],
};
const columnLabelFormats = {
date: {
columnType: 'date',
style: 'date',
minimumFractionDigits: 0,
maximumFractionDigits: 0,
multiplier: 1,
prefix: '',
suffix: '',
replaceMissingDataWith: 0,
makeLabelHumanReadable: true,
} as ColumnLabelFormat,
};
// Act
const result = barSeriesBuilder_labels({
datasetOptions: mockDatasetOptions,
columnLabelFormats,
xAxisKeys: ['date'],
});
// Assert
expect(result).toHaveLength(3);
expect(result[0]).toBeInstanceOf(Date);
expect(result[1]).toBeInstanceOf(Date);
expect(result[2]).toBeInstanceOf(Date);
});
it('should return quarter ticks when columnLabelFormat has quarter convertNumberTo and single xAxis', () => {
// Arrange
const mockDatasetOptions: DatasetOptionsWithTicks = {
datasets: [],
ticks: [[1], [2], [3], [4]],
ticksKey: [{ key: 'quarter', value: '' }],
};
const columnLabelFormats = {
quarter: {
columnType: 'number',
style: 'date',
convertNumberTo: 'quarter',
minimumFractionDigits: 0,
maximumFractionDigits: 0,
multiplier: 1,
prefix: '',
suffix: '',
replaceMissingDataWith: 0,
makeLabelHumanReadable: true,
} as ColumnLabelFormat,
};
// Act
const result = barSeriesBuilder_labels({
datasetOptions: mockDatasetOptions,
columnLabelFormats,
xAxisKeys: ['quarter'],
});
// Assert
expect(result).toHaveLength(4);
expect(result[0]).toBe('Q1');
expect(result[1]).toBe('Q2');
expect(result[2]).toBe('Q3');
expect(result[3]).toBe('Q4');
});
it('should return quarter ticks with double xAxis when one is quarter and one is number', () => {
// Arrange
const mockDatasetOptions: DatasetOptionsWithTicks = {
datasets: [],
ticks: [
[1, 2023],
[2, 2023],
[3, 2023],
],
ticksKey: [
{ key: 'quarter', value: '' },
{ key: 'year', value: '' },
],
};
const columnLabelFormats = {
quarter: {
columnType: 'number',
style: 'date',
convertNumberTo: 'quarter',
minimumFractionDigits: 0,
maximumFractionDigits: 0,
multiplier: 1,
prefix: '',
suffix: '',
replaceMissingDataWith: 0,
makeLabelHumanReadable: true,
} as ColumnLabelFormat,
year: {
columnType: 'number',
style: 'number',
minimumFractionDigits: 0,
maximumFractionDigits: 0,
multiplier: 1,
prefix: '',
suffix: '',
numberSeparatorStyle: null,
replaceMissingDataWith: 0,
makeLabelHumanReadable: true,
} as ColumnLabelFormat,
};
// Act
const result = barSeriesBuilder_labels({
datasetOptions: mockDatasetOptions,
columnLabelFormats,
xAxisKeys: ['quarter', 'year'],
});
console.log('result', result[0]);
// Assert
expect(result).toHaveLength(3);
expect(result[0]).toContain('Q1');
expect(result[0]).toContain('2023');
expect(result[1]).toContain('Q2');
expect(result[2]).toContain('Q3');
});
it('should return null dateTicks when columnLabelFormat does not have date style', () => {
// Arrange
const mockDatasetOptions: DatasetOptionsWithTicks = {
datasets: [],
ticks: [['Product A'], ['Product B'], ['Product C']],
ticksKey: [{ key: 'product', value: '' }],
};
const columnLabelFormats = {
product: {
columnType: 'string',
style: 'string',
minimumFractionDigits: 0,
maximumFractionDigits: 0,
multiplier: 1,
prefix: '',
suffix: '',
replaceMissingDataWith: 0,
makeLabelHumanReadable: true,
} as unknown as ColumnLabelFormat,
};
// Act
const result = barSeriesBuilder_labels({
datasetOptions: mockDatasetOptions,
columnLabelFormats,
xAxisKeys: ['product'],
});
// Assert
expect(result).toHaveLength(3);
expect(result[0]).toBe('Product A');
expect(result[1]).toBe('Product B');
expect(result[2]).toBe('Product C');
});
it('should return date ticks when columnType is number but style is date', () => {
// Arrange
const mockDatasetOptions: DatasetOptionsWithTicks = {
datasets: [],
ticks: [['2024-01-15'], ['2024-01-16'], ['2024-01-17']],
ticksKey: [{ key: 'timestamp', value: '' }],
};
const columnLabelFormats = {
timestamp: {
columnType: 'number',
style: 'date',
minimumFractionDigits: 0,
maximumFractionDigits: 0,
multiplier: 1,
prefix: '',
suffix: '',
replaceMissingDataWith: 0,
makeLabelHumanReadable: true,
} as ColumnLabelFormat,
};
// Act
const result = barSeriesBuilder_labels({
datasetOptions: mockDatasetOptions,
columnLabelFormats,
xAxisKeys: ['timestamp'],
});
// Assert
expect(result).toHaveLength(3);
expect(result[0]).toBeInstanceOf(Date);
expect(result[1]).toBeInstanceOf(Date);
expect(result[2]).toBeInstanceOf(Date);
});
});

View File

@ -11,6 +11,7 @@ import { DEFAULT_CHART_LAYOUT } from '../../ChartJSTheme';
import type { ChartProps } from '../../core';
import { dataLabelFontColorContrast, formatBarAndLineDataLabel } from '../../helpers';
import { defaultLabelOptionConfig } from '../useChartSpecificOptions/labelOptionConfig';
import { createTickDates } from './createTickDate';
import { createTrendlineOnSeries } from './createTrendlines';
import type { SeriesBuilderProps } from './interfaces';
import type { LabelBuilderProps } from './useSeriesOptions';
@ -352,8 +353,12 @@ const getFormattedValueAndSetBarDataLabels = (
export const barSeriesBuilder_labels = ({
datasetOptions,
columnLabelFormats,
}: Pick<LabelBuilderProps, 'datasetOptions' | 'columnLabelFormats'>) => {
const ticksKey = datasetOptions.ticksKey;
xAxisKeys,
}: Pick<LabelBuilderProps, 'datasetOptions' | 'columnLabelFormats' | 'xAxisKeys'>) => {
const dateTicks = createTickDates(datasetOptions.ticks, xAxisKeys, columnLabelFormats);
if (dateTicks) {
return dateTicks;
}
const containsADateStyle = datasetOptions.ticksKey.some((tick) => {
const selectedColumnLabelFormat = columnLabelFormats[tick.key];
@ -364,7 +369,7 @@ export const barSeriesBuilder_labels = ({
const labels = datasetOptions.ticks.flatMap((item) => {
return item
.map<string>((item, index) => {
const key = ticksKey[index]?.key || '';
const key = datasetOptions.ticksKey[index]?.key || '';
const columnLabelFormat = columnLabelFormats[key];
return formatLabel(item, columnLabelFormat);
})

View File

@ -45,6 +45,7 @@ export const createTickDates = (
}
const isDoubleXAxis = xAxisKeys.length === 2;
console.log('isDoubleXAxis', isDoubleXAxis);
if (isDoubleXAxis) {
const oneIsAQuarter = xColumnLabelFormats.findIndex(
(format) => format.convertNumberTo === 'quarter' && format.columnType === 'number'
@ -52,6 +53,8 @@ export const createTickDates = (
const oneIsANumber = xColumnLabelFormats.findIndex(
(format) => format.columnType === 'number' && format.style === 'number'
);
console.log('oneIsAQuarter', oneIsAQuarter);
console.log('oneIsANumber', oneIsANumber);
if (oneIsAQuarter !== -1 && oneIsANumber !== -1) {
return createQuarterTickDates(ticks, xColumnLabelFormats, oneIsAQuarter);
}

View File

@ -330,6 +330,7 @@ describe('lineSeriesBuilder', () => {
const labels = lineSeriesBuilder_labels(props);
expect(labels).toHaveLength(3);
console.log('labels', labels[0]);
expect(labels[0]).toBe('formatted-2023-01-01 formatted-A');
expect(labels[1]).toBe('formatted-2023-01-02 formatted-B');
expect(labels[2]).toBe('formatted-2023-01-03 formatted-A');

View File

@ -216,13 +216,9 @@ export const lineSeriesBuilder_labels = ({
xAxisKeys,
columnLabelFormats,
}: LabelBuilderProps): (string | Date)[] => {
const dateTicks = createTickDates(datasetOptions.ticks, xAxisKeys, columnLabelFormats);
if (dateTicks) {
return dateTicks;
}
return barSeriesBuilder_labels({
datasetOptions,
columnLabelFormats,
xAxisKeys,
});
};

View File

@ -1657,3 +1657,243 @@ export const BarChartWithSortedDayOfWeek: Story = {
],
},
};
export const BarWithProblemQuarters: Story = {
args: {
colors: [
'#B399FD',
'#FC8497',
'#FBBC30',
'#279EFF',
'#E83562',
'#41F8FF',
'#F3864F',
'#C82184',
'#31FCB4',
'#E83562',
],
barLayout: 'vertical',
barSortBy: [],
goalLines: [],
gridLines: true,
pieSortBy: 'value',
showLegend: null,
trendlines: [],
scatterAxis: {
x: [],
y: [],
size: [],
tooltip: null,
category: [],
},
barGroupType: 'stack',
metricHeader: null,
pieChartAxis: {
x: [],
y: [],
tooltip: null,
},
lineGroupType: null,
pieDonutWidth: 40,
xAxisDataZoom: false,
barAndLineAxis: {
x: ['quarter'],
y: ['product_count'],
colorBy: [],
tooltip: null,
category: ['metric_seasoncategory'],
},
columnSettings: {},
comboChartAxis: {
x: [],
y: [],
y2: [],
colorBy: [],
tooltip: null,
category: [],
},
disableTooltip: false,
metricColumnId: '',
scatterDotSize: [3, 15],
xAxisAxisTitle: null,
yAxisAxisTitle: null,
yAxisScaleType: 'linear',
metricSubHeader: null,
y2AxisAxisTitle: null,
y2AxisScaleType: 'linear',
metricValueLabel: null,
pieLabelPosition: 'none',
tableColumnOrder: null,
barShowTotalAtTop: false,
categoryAxisTitle: null,
pieDisplayLabelAs: 'number',
pieShowInnerLabel: true,
selectedChartType: 'bar',
tableColumnWidths: null,
xAxisTimeInterval: null,
columnLabelFormats: {
quarter: {
isUTC: false,
style: 'date',
prefix: '',
suffix: '',
currency: 'USD',
columnType: 'number',
dateFormat: 'auto',
multiplier: 1,
displayName: '',
compactNumbers: false,
convertNumberTo: 'quarter',
useRelativeTime: false,
numberSeparatorStyle: null,
maximumFractionDigits: 2,
minimumFractionDigits: 0,
makeLabelHumanReadable: true,
replaceMissingDataWith: 0,
},
product_count: {
isUTC: false,
style: 'number',
prefix: '',
suffix: '',
currency: 'USD',
columnType: 'number',
dateFormat: 'auto',
multiplier: 1,
displayName: 'Product Count',
compactNumbers: false,
convertNumberTo: null,
useRelativeTime: false,
numberSeparatorStyle: ',',
maximumFractionDigits: 0,
minimumFractionDigits: 0,
makeLabelHumanReadable: true,
replaceMissingDataWith: 0,
},
metric_seasoncategory: {
isUTC: false,
style: 'string',
prefix: '',
suffix: '',
currency: 'USD',
columnType: 'text',
dateFormat: 'auto',
multiplier: 1,
displayName: 'Season Category',
compactNumbers: false,
convertNumberTo: null,
useRelativeTime: false,
numberSeparatorStyle: null,
maximumFractionDigits: 2,
minimumFractionDigits: 0,
makeLabelHumanReadable: true,
replaceMissingDataWith: null,
},
},
pieInnerLabelTitle: null,
showLegendHeadline: false,
xAxisLabelRotation: 'auto',
xAxisShowAxisLabel: true,
xAxisShowAxisTitle: true,
yAxisShowAxisLabel: true,
yAxisShowAxisTitle: true,
y2AxisShowAxisLabel: true,
y2AxisShowAxisTitle: true,
metricValueAggregate: 'sum',
tableColumnFontColor: null,
tableHeaderFontColor: null,
yAxisStartAxisAtZero: null,
y2AxisStartAxisAtZero: true,
pieInnerLabelAggregate: 'sum',
pieMinimumSlicePercentage: 0,
tableHeaderBackgroundColor: null,
data: [
{
quarter: 1,
metric_seasoncategory: 'High Season',
product_count: 55,
},
{
quarter: 1,
metric_seasoncategory: 'Low Season',
product_count: 18,
},
{
quarter: 1,
metric_seasoncategory: 'Regular Season',
product_count: 174,
},
{
quarter: 2,
metric_seasoncategory: 'High Season',
product_count: 27,
},
{
quarter: 2,
metric_seasoncategory: 'Low Season',
product_count: 60,
},
{
quarter: 2,
metric_seasoncategory: 'Regular Season',
product_count: 156,
},
{
quarter: 3,
metric_seasoncategory: 'High Season',
product_count: 2,
},
{
quarter: 3,
metric_seasoncategory: 'Low Season',
product_count: 181,
},
{
quarter: 3,
metric_seasoncategory: 'Regular Season',
product_count: 82,
},
{
quarter: 4,
metric_seasoncategory: 'High Season',
product_count: 136,
},
{
quarter: 4,
metric_seasoncategory: 'Low Season',
product_count: 22,
},
{
quarter: 4,
metric_seasoncategory: 'Regular Season',
product_count: 108,
},
],
columnMetadata: [
{
name: 'quarter',
min_value: 1,
max_value: 4,
unique_values: 4,
simple_type: 'number',
type: 'numeric',
},
{
name: 'metric_seasoncategory',
min_value: 'High Season',
max_value: 'Regular Season',
unique_values: 3,
simple_type: 'text',
type: 'text',
},
{
name: 'product_count',
min_value: 2,
max_value: 181,
unique_values: 12,
simple_type: 'number',
type: 'int8',
},
],
},
};

View File

@ -218,7 +218,7 @@ export const MentionInputSuggestions = forwardRef<
className={inputClassName}
onBlur={onBlur}
/>
{children && <div className="mt-4.5">{children}</div>}
{children && <div className="mt-6">{children}</div>}
</MentionInputSuggestionsContainer>
<SuggestionsSeperator />
<MentionInputSuggestionsList

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -10,11 +10,27 @@ import {
createCreateMetricsTool,
createCreateReportsTool,
createDoneTool,
createExecuteSqlTool,
createModifyDashboardsTool,
createModifyMetricsTool,
createModifyReportsTool,
createSequentialThinkingTool,
} from '../../tools';
import { DONE_TOOL_NAME } from '../../tools/communication-tools/done-tool/done-tool';
import {
MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME,
createMessageUserClarifyingQuestionTool,
} from '../../tools/communication-tools/message-user-clarifying-question/message-user-clarifying-question';
import {
RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME,
createRespondWithoutAssetCreationTool,
} from '../../tools/communication-tools/respond-without-asset-creation/respond-without-asset-creation-tool';
import {
SUBMIT_THOUGHTS_TOOL_NAME,
createSubmitThoughtsTool,
} from '../../tools/communication-tools/submit-thoughts-tool/submit-thoughts-tool';
import { EXECUTE_SQL_TOOL_NAME } from '../../tools/database-tools/execute-sql/execute-sql';
import { SEQUENTIAL_THINKING_TOOL_NAME } from '../../tools/planning-thinking-tools/sequential-thinking-tool/sequential-thinking-tool';
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
import { MODIFY_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/modify-dashboards-tool/modify-dashboards-tool';
import { CREATE_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/create-metrics-tool/create-metrics-tool';
@ -28,13 +44,22 @@ import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
export const ANALYST_AGENT_NAME = 'analystAgent';
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall(DONE_TOOL_NAME)];
const STOP_CONDITIONS = [
stepCountIs(25),
hasToolCall(DONE_TOOL_NAME),
hasToolCall(RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME),
hasToolCall(MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME),
];
export const AnalystAgentOptionsSchema = z.object({
userId: z.string(),
chatId: z.string(),
dataSourceId: z.string(),
dataSourceSyntax: z.string(),
sql_dialect_guidance: z
.string()
.describe('The SQL dialect guidance for the analyst agent.')
.optional(),
organizationId: z.string(),
messageId: z.string(),
datasets: z.array(z.custom<PermissionedDataset>()),
@ -72,7 +97,10 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
const systemMessage = {
role: 'system',
content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax),
content: getAnalystAgentSystemPrompt(
analystAgentOptions.dataSourceSyntax,
analystAgentOptions.analysisMode || 'standard'
),
providerOptions: DEFAULT_ANTHROPIC_OPTIONS,
} as ModelMessage;
@ -106,6 +134,26 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
: null;
async function stream({ messages }: AnalystStreamOptions) {
// Think-and-prep tools
const sequentialThinking = createSequentialThinkingTool({
messageId: analystAgentOptions.messageId,
});
const executeSqlTool = createExecuteSqlTool({
messageId: analystAgentOptions.messageId,
dataSourceId: analystAgentOptions.dataSourceId,
dataSourceSyntax: analystAgentOptions.dataSourceSyntax,
userId: analystAgentOptions.userId,
});
const respondWithoutAssetCreation = createRespondWithoutAssetCreationTool({
messageId: analystAgentOptions.messageId,
workflowStartTime: analystAgentOptions.workflowStartTime,
});
const messageUserClarifyingQuestion = createMessageUserClarifyingQuestionTool({
messageId: analystAgentOptions.messageId,
workflowStartTime: analystAgentOptions.workflowStartTime,
});
// Visualization tools
const createMetrics = createCreateMetricsTool(analystAgentOptions);
const modifyMetrics = createModifyMetricsTool(analystAgentOptions);
const createDashboards = createCreateDashboardsTool(analystAgentOptions);
@ -118,6 +166,10 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
const doneTool = createDoneTool(analystAgentOptions);
const availableTools = [
SEQUENTIAL_THINKING_TOOL_NAME,
EXECUTE_SQL_TOOL_NAME,
RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME,
MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME,
CREATE_METRICS_TOOL_NAME,
MODIFY_METRICS_TOOL_NAME,
CREATE_DASHBOARDS_TOOL_NAME,
@ -160,6 +212,10 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
anthropic_beta: 'fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07',
},
tools: {
[SEQUENTIAL_THINKING_TOOL_NAME]: sequentialThinking,
[EXECUTE_SQL_TOOL_NAME]: executeSqlTool,
[RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME]: respondWithoutAssetCreation,
[MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME]: messageUserClarifyingQuestion,
[CREATE_METRICS_TOOL_NAME]: createMetrics,
[MODIFY_METRICS_TOOL_NAME]: modifyMetrics,
[CREATE_DASHBOARDS_TOOL_NAME]: createDashboards,

View File

@ -63,7 +63,7 @@ describe('Analyst Agent Instructions', () => {
expect(result).toContain('<intro>');
expect(result).toContain('<sql_best_practices>');
expect(result).toContain('<visualization_and_charting_guidelines>');
expect(result).toContain('You are a Buster');
expect(result).toContain('You are an agent');
});
it('should throw an error for empty SQL dialect guidance', () => {

View File

@ -1,4 +1,6 @@
import analystAgentPrompt from './analyst-agent-prompt.txt';
import type { AnalysisMode } from '../../types/analysis-mode.types';
import analystAgentInvestigationPrompt from './analyst-agent-investigation-prompt.txt';
import analystAgentStandardPrompt from './analyst-agent-standard-prompt.txt';
/**
* Template parameters for the analyst agent prompt
@ -8,11 +10,24 @@ export interface AnalystTemplateParams {
date: string;
}
/**
* Type-safe mapping of analysis modes to prompt content
*/
const PROMPTS: Record<AnalysisMode, string> = {
standard: analystAgentStandardPrompt,
investigation: analystAgentInvestigationPrompt,
} as const;
/**
* Loads the analyst agent prompt template and replaces variables
*/
function loadAndProcessPrompt(params: AnalystTemplateParams): string {
return analystAgentPrompt
function loadAndProcessPrompt(
params: AnalystTemplateParams,
analysisMode: AnalysisMode = 'standard'
): string {
const content = PROMPTS[analysisMode];
return content
.replace(/\{\{sql_dialect_guidance\}\}/g, params.dataSourceSyntax)
.replace(/\{\{date\}\}/g, params.date);
}
@ -20,15 +35,21 @@ function loadAndProcessPrompt(params: AnalystTemplateParams): string {
/**
* Export the template function for use in step files
*/
export const getAnalystAgentSystemPrompt = (dataSourceSyntax: string): string => {
export const getAnalystAgentSystemPrompt = (
dataSourceSyntax: string,
analysisMode: AnalysisMode = 'standard'
): string => {
if (!dataSourceSyntax.trim()) {
throw new Error('SQL dialect guidance is required');
}
const currentDate = new Date().toISOString();
return loadAndProcessPrompt({
dataSourceSyntax,
date: currentDate,
});
return loadAndProcessPrompt(
{
dataSourceSyntax,
date: currentDate,
},
analysisMode
);
};

View File

@ -180,11 +180,15 @@ export function createDoneToolDelta(context: DoneToolContext, doneToolState: Don
}
}
// Store ALL assets (including reports) for chat update later
if (newAssets.length > 0) {
// Store assets for chat update later (excluding reasoning which is not an asset type)
const assetsForChatUpdate = newAssets.filter(
(a): a is typeof a & { assetType: Exclude<typeof a.assetType, 'reasoning'> } =>
a.assetType !== 'reasoning'
);
if (assetsForChatUpdate.length > 0) {
doneToolState.addedAssets = [
...(doneToolState.addedAssets || []),
...newAssets.map((a) => ({
...assetsForChatUpdate.map((a) => ({
assetId: a.assetId,
assetType: a.assetType,
versionNumber: a.versionNumber,

View File

@ -63,8 +63,10 @@ async function processDone(
return {
output,
sequenceNumber: updateResult.sequenceNumber,
skipped: updateResult.skipped,
...(updateResult.sequenceNumber !== undefined && {
sequenceNumber: updateResult.sequenceNumber,
}),
...(updateResult.skipped !== undefined && { skipped: updateResult.skipped }),
};
} catch (error) {
console.error('[done-tool] Error updating message entries:', error);

View File

@ -1,12 +1,4 @@
import { randomUUID } from 'node:crypto';
import {
getAssetLatestVersion,
isMessageUpdateQueueClosed,
updateChat,
updateMessage,
updateMessageEntries,
waitForPendingUpdates,
} from '@buster/database/queries';
import type { ModelMessage, ToolCallOptions } from 'ai';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
@ -17,6 +9,7 @@ import type { DoneToolContext, DoneToolState } from './done-tool';
import { createDoneToolDelta } from './done-tool-delta';
import { createDoneToolStart } from './done-tool-start';
// Mock database queries
vi.mock('@buster/database/queries', () => ({
updateChat: vi.fn(),
updateMessage: vi.fn(),
@ -30,6 +23,21 @@ vi.mock('@buster/database/queries', () => ({
getAssetLatestVersion: vi.fn().mockResolvedValue(1),
}));
// Import mocked functions after the mock definition
import {
getAssetLatestVersion,
isMessageUpdateQueueClosed,
updateChat,
updateMessage,
updateMessageEntries,
waitForPendingUpdates,
} from '@buster/database/queries';
// Type assertion for mocked functions
const mockedIsMessageUpdateQueueClosed = vi.mocked(isMessageUpdateQueueClosed);
const mockedWaitForPendingUpdates = vi.mocked(waitForPendingUpdates);
const mockedGetAssetLatestVersion = vi.mocked(getAssetLatestVersion);
describe('done-tool-start', () => {
const mockContext: DoneToolContext = {
chatId: 'chat-123',
@ -43,11 +51,20 @@ describe('done-tool-start', () => {
finalResponse: undefined,
};
// Helper to create mock ToolCallOptions
const createMockToolCallOptions = (
overrides: Partial<ToolCallOptions> = {}
): ToolCallOptions => ({
messages: [],
toolCallId: 'done-call',
...overrides,
});
beforeEach(() => {
vi.clearAllMocks();
isMessageUpdateQueueClosed.mockReturnValue(false);
waitForPendingUpdates.mockResolvedValue(undefined);
getAssetLatestVersion.mockResolvedValue(1);
mockedIsMessageUpdateQueueClosed.mockReturnValue(false);
mockedWaitForPendingUpdates.mockResolvedValue(undefined);
mockedGetAssetLatestVersion.mockResolvedValue(1);
});
describe('mostRecentFile selection', () => {
@ -142,10 +159,12 @@ describe('done-tool-start', () => {
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
// Start phase - initializes state
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - streams in the assets and final response
const deltaInput = JSON.stringify({
@ -161,8 +180,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
expect(updateChat).toHaveBeenCalledWith('chat-123', {
mostRecentFileId: reportId,
@ -248,10 +267,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the first metric as the asset to return
const deltaInput = JSON.stringify({
@ -267,8 +288,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Should select the first metric (first in extractedFiles)
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -359,10 +380,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the report and standalone metric as assets to return
const deltaInput = JSON.stringify({
@ -383,8 +406,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Report should be selected as mostRecentFile
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -459,10 +482,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the first metric
const deltaInput = JSON.stringify({
@ -478,8 +503,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Should select the first metric
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -526,10 +551,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the first dashboard
const deltaInput = JSON.stringify({
@ -545,8 +572,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Should select the first dashboard
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -612,10 +639,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the dashboard
const deltaInput = JSON.stringify({
@ -631,8 +660,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Should select the dashboard (first in extractedFiles)
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -644,10 +673,12 @@ describe('done-tool-start', () => {
it('should handle empty file lists gracefully', async () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: [],
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: [],
})
);
// Should not call updateChat when no files exist
expect(updateChat).not.toHaveBeenCalled();
@ -697,10 +728,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the report
const deltaInput = JSON.stringify({
@ -716,8 +749,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Report should still be selected as mostRecentFile
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -788,10 +821,12 @@ describe('done-tool-start', () => {
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Delta phase - stream in the dashboard (metrics are embedded)
const deltaInput = JSON.stringify({
@ -807,8 +842,8 @@ describe('done-tool-start', () => {
await doneToolDelta({
inputTextDelta: deltaInput,
toolCallId: 'done-call',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'done-call' }),
});
// Should select the dashboard since that's what we're returning
expect(updateChat).toHaveBeenCalledWith('chat-123', {
@ -852,10 +887,12 @@ describe('done-tool-start', () => {
];
const doneToolStart = createDoneToolStart(contextWithEmptyChatId, mockDoneToolState);
await doneToolStart({
toolCallId: 'done-call',
messages: mockMessages,
} as ToolCallOptions);
await doneToolStart(
createMockToolCallOptions({
toolCallId: 'done-call',
messages: mockMessages,
})
);
// Should not call updateChat when chatId is missing
expect(updateChat).not.toHaveBeenCalled();

View File

@ -71,6 +71,15 @@ describe('Done Tool Streaming Tests', () => {
workflowStartTime: Date.now(),
};
// Helper to create mock ToolCallOptions
const createMockToolCallOptions = (
overrides: Partial<ToolCallOptions> = {}
): ToolCallOptions => ({
messages: [],
toolCallId: 'test-call-id',
...overrides,
});
describe('createDoneToolStart', () => {
test('should initialize state with entry_id on start', async () => {
const state: DoneToolState = {
@ -252,8 +261,8 @@ describe('Done Tool Streaming Tests', () => {
});
await deltaHandler({
inputTextDelta: deltaInput,
toolCallId: 'call-1',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'call-1' }),
});
const queries = await import('@buster/database/queries');
@ -375,8 +384,8 @@ describe('Done Tool Streaming Tests', () => {
});
await deltaHandler({
inputTextDelta: deltaInput,
toolCallId: 'call-2',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'call-2' }),
});
const queries = await import('@buster/database/queries');
@ -492,8 +501,8 @@ describe('Done Tool Streaming Tests', () => {
});
await deltaHandler({
inputTextDelta: deltaInput,
toolCallId: 'call-3',
} as ToolCallOptions);
...createMockToolCallOptions({ toolCallId: 'call-3' }),
});
const queries = await import('@buster/database/queries');
const updateArgs = ((queries.updateChat as unknown as { mock: { calls: unknown[][] } }).mock
@ -688,6 +697,7 @@ describe('Done Tool Streaming Tests', () => {
const finishHandler = createDoneToolFinish(mockContext, state);
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: 'This is the final response message',
};
@ -710,6 +720,7 @@ describe('Done Tool Streaming Tests', () => {
const finishHandler = createDoneToolFinish(mockContext, state);
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: 'Response without prior start',
};
@ -746,6 +757,7 @@ The following items were processed:
`;
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: markdownResponse,
};
@ -810,6 +822,7 @@ The following items were processed:
expect(state.finalResponse).toBeTypeOf('string');
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: 'Final test',
};
await finishHandler({ input, toolCallId: 'test-123', messages: [] });
@ -859,6 +872,7 @@ The following items were processed:
);
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: 'This is a streaming response that comes in multiple chunks',
};
await finishHandler({ input, toolCallId, messages: [] });

View File

@ -108,6 +108,7 @@ describe('Done Tool Integration Tests', () => {
await startHandler({ toolCallId, messages: [] });
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: 'This is the complete final response',
};
@ -168,6 +169,7 @@ All operations completed successfully.`;
expect(state.finalResponse).toBe(expectedResponse);
const input: DoneToolInput = {
assetsToReturn: [],
finalResponse: expectedResponse,
};
@ -207,14 +209,14 @@ All operations completed successfully.`;
await startHandler1({ toolCallId: toolCallId1, messages: [] });
await finishHandler1({
input: { finalResponse: 'First response' },
input: { assetsToReturn: [], finalResponse: 'First response' },
toolCallId: toolCallId1,
messages: [],
});
await startHandler2({ toolCallId: toolCallId2, messages: [] });
await finishHandler2({
input: { finalResponse: 'Second response' },
input: { assetsToReturn: [], finalResponse: 'Second response' },
toolCallId: toolCallId2,
messages: [],
});

View File

@ -83,7 +83,47 @@ export async function runAnalystWorkflow(
// Add all messages from create-todos step (tool call, result, and user message)
messages.push(...todos.messages);
const thinkAndPrepAgentStepResults = await runThinkAndPrepAgentStep({
// const thinkAndPrepAgentStepResults = await runThinkAndPrepAgentStep({
// options: {
// messageId: input.messageId,
// chatId: input.chatId,
// organizationId: input.organizationId,
// dataSourceId: input.dataSourceId,
// dataSourceSyntax: input.dataSourceSyntax,
// userId: input.userId,
// sql_dialect_guidance: input.dataSourceSyntax,
// datasets: input.datasets,
// workflowStartTime,
// analysisMode,
// analystInstructions,
// organizationDocs,
// userPersonalizationMessageContent,
// },
// streamOptions: {
// messages,
// },
// });
// console.info('[runAnalystWorkflow] DEBUG: Think-and-prep results', {
// workflowId,
// messageId: input.messageId,
// earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
// messageCount: thinkAndPrepAgentStepResults.messages.length,
// });
// messages.push(...thinkAndPrepAgentStepResults.messages);
// // Check if think-and-prep agent terminated early (clarifying question or direct response)
let analystAgentStepResults = { messages: [] as ModelMessage[] };
// if (!thinkAndPrepAgentStepResults.earlyTermination) {
// console.info('[runAnalystWorkflow] Running analyst agent step (early termination = false)', {
// workflowId,
// messageId: input.messageId,
// earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
// });
analystAgentStepResults = await runAnalystAgentStep({
options: {
messageId: input.messageId,
chatId: input.chatId,
@ -91,7 +131,6 @@ export async function runAnalystWorkflow(
dataSourceId: input.dataSourceId,
dataSourceSyntax: input.dataSourceSyntax,
userId: input.userId,
sql_dialect_guidance: input.dataSourceSyntax,
datasets: input.datasets,
workflowStartTime,
analysisMode,
@ -104,53 +143,14 @@ export async function runAnalystWorkflow(
},
});
console.info('[runAnalystWorkflow] DEBUG: Think-and-prep results', {
workflowId,
messageId: input.messageId,
earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
messageCount: thinkAndPrepAgentStepResults.messages.length,
});
messages.push(...thinkAndPrepAgentStepResults.messages);
// Check if think-and-prep agent terminated early (clarifying question or direct response)
let analystAgentStepResults = { messages: [] as ModelMessage[] };
if (!thinkAndPrepAgentStepResults.earlyTermination) {
console.info('[runAnalystWorkflow] Running analyst agent step (early termination = false)', {
workflowId,
messageId: input.messageId,
earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
});
analystAgentStepResults = await runAnalystAgentStep({
options: {
messageId: input.messageId,
chatId: input.chatId,
organizationId: input.organizationId,
dataSourceId: input.dataSourceId,
dataSourceSyntax: input.dataSourceSyntax,
userId: input.userId,
datasets: input.datasets,
workflowStartTime,
analysisMode,
analystInstructions,
organizationDocs,
userPersonalizationMessageContent,
},
streamOptions: {
messages,
},
});
messages.push(...analystAgentStepResults.messages);
} else {
console.info('[runAnalystWorkflow] DEBUG: SKIPPING analyst agent due to early termination', {
workflowId,
messageId: input.messageId,
earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
});
}
messages.push(...analystAgentStepResults.messages);
// } else {
// console.info('[runAnalystWorkflow] DEBUG: SKIPPING analyst agent due to early termination', {
// workflowId,
// messageId: input.messageId,
// earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
// });
// }
// Extract all tool calls from messages
const allToolCalls = extractToolCallsFromMessages(messages);

View File

@ -0,0 +1,338 @@
import { randomUUID } from 'node:crypto';
import type { ModelMessage } from 'ai';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { getChatConversationHistory } from './chatConversationHistory';
// Mock the database connection and queries
vi.mock('../../connection', () => ({
db: {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
limit: vi.fn().mockReturnThis(),
orderBy: vi.fn(),
},
}));
describe('getChatConversationHistory - Orphaned Tool Call Cleanup', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('should remove orphaned tool calls (tool calls without matching results)', async () => {
// Mock database to return messages with orphaned tool calls
const mockMessages: ModelMessage[] = [
{
role: 'user',
content: 'test question',
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'orphaned-call-123',
toolName: 'sequentialThinking',
input: { thought: 'test thought', nextThoughtNeeded: false },
},
],
},
// No tool result for orphaned-call-123
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'valid-call-456',
toolName: 'executeSql',
input: { statements: ['SELECT 1'] },
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'valid-call-456',
toolName: 'executeSql',
output: { type: 'json', value: '{"results":[]}' },
},
],
},
];
// Mock the database query functions
const { db } = await import('../../connection');
vi.mocked(db.orderBy).mockResolvedValue([
{
id: 'msg-1',
rawLlmMessages: mockMessages,
createdAt: '2025-01-01T00:00:00Z',
isCompleted: true,
},
]);
vi.mocked(db.limit).mockResolvedValue([
{
chatId: 'chat-123',
createdAt: '2025-01-01T00:00:00Z',
},
]);
const result = await getChatConversationHistory({
messageId: randomUUID(),
});
// Should have removed the orphaned tool call but kept the valid one
expect(result).toHaveLength(3); // user, assistant (with valid tool call), tool result
// Find the assistant message
const assistantMessages = result.filter((m) => m.role === 'assistant');
expect(assistantMessages).toHaveLength(1);
// The remaining assistant message should only have the valid tool call
const assistantContent = assistantMessages[0]?.content;
expect(Array.isArray(assistantContent)).toBe(true);
if (Array.isArray(assistantContent)) {
expect(assistantContent).toHaveLength(1);
expect(assistantContent[0]).toMatchObject({
type: 'tool-call',
toolCallId: 'valid-call-456',
});
}
});
it('should keep assistant messages with valid tool calls', async () => {
const mockMessages: ModelMessage[] = [
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'call-123',
toolName: 'testTool',
input: {},
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'call-123',
toolName: 'testTool',
output: { type: 'text', value: 'result' },
},
],
},
];
const { db } = await import('../../connection');
vi.mocked(db.orderBy).mockResolvedValue([
{
id: 'msg-1',
rawLlmMessages: mockMessages,
createdAt: '2025-01-01T00:00:00Z',
isCompleted: true,
},
]);
vi.mocked(db.limit).mockResolvedValue([
{
chatId: 'chat-123',
createdAt: '2025-01-01T00:00:00Z',
},
]);
const result = await getChatConversationHistory({
messageId: randomUUID(),
});
// Should keep both messages (tool call and result)
expect(result).toHaveLength(2);
expect(result[0]?.role).toBe('assistant');
expect(result[1]?.role).toBe('tool');
});
it('should keep assistant messages that have at least one valid tool call (even if some are orphaned)', async () => {
const mockMessages: ModelMessage[] = [
{
role: 'assistant',
content: [
{
type: 'text',
text: 'Let me analyze this',
},
{
type: 'tool-call',
toolCallId: 'orphaned-123',
toolName: 'orphanedTool',
input: {},
},
{
type: 'tool-call',
toolCallId: 'valid-456',
toolName: 'validTool',
input: {},
},
],
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolCallId: 'valid-456',
toolName: 'validTool',
output: { type: 'text', value: 'success' },
},
],
},
];
const { db } = await import('../../connection');
vi.mocked(db.orderBy).mockResolvedValue([
{
id: 'msg-1',
rawLlmMessages: mockMessages,
createdAt: '2025-01-01T00:00:00Z',
isCompleted: true,
},
]);
vi.mocked(db.limit).mockResolvedValue([
{
chatId: 'chat-123',
createdAt: '2025-01-01T00:00:00Z',
},
]);
const result = await getChatConversationHistory({
messageId: randomUUID(),
});
// Should keep the assistant message because it has at least one valid tool call
// Note: We keep the entire message including the orphaned tool call
expect(result).toHaveLength(2);
const assistantMessage = result.find((m) => m.role === 'assistant');
expect(assistantMessage).toBeDefined();
const content = assistantMessage?.content;
expect(Array.isArray(content)).toBe(true);
if (Array.isArray(content)) {
// Should have all content including the orphaned tool call (we don't modify the message)
expect(content).toHaveLength(3);
}
});
it('should remove assistant messages that only contain orphaned tool calls', async () => {
const mockMessages: ModelMessage[] = [
{
role: 'user',
content: 'test',
},
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'orphaned-only',
toolName: 'orphanedTool',
input: {},
},
],
},
// No tool result for orphaned-only
{
role: 'assistant',
content: 'This is a text response',
},
];
const { db } = await import('../../connection');
vi.mocked(db.orderBy).mockResolvedValue([
{
id: 'msg-1',
rawLlmMessages: mockMessages,
createdAt: '2025-01-01T00:00:00Z',
isCompleted: true,
},
]);
vi.mocked(db.limit).mockResolvedValue([
{
chatId: 'chat-123',
createdAt: '2025-01-01T00:00:00Z',
},
]);
const result = await getChatConversationHistory({
messageId: randomUUID(),
});
// Should have removed the assistant message with only orphaned tool call
expect(result).toHaveLength(2); // user + assistant text response
expect(result[0]?.role).toBe('user');
expect(result[1]?.role).toBe('assistant');
expect(result[1]?.content).toBe('This is a text response');
});
it('should handle empty message arrays', async () => {
const { db } = await import('../../connection');
vi.mocked(db.orderBy).mockResolvedValue([
{
id: 'msg-1',
rawLlmMessages: [],
createdAt: '2025-01-01T00:00:00Z',
isCompleted: true,
},
]);
vi.mocked(db.limit).mockResolvedValue([
{
chatId: 'chat-123',
createdAt: '2025-01-01T00:00:00Z',
},
]);
const result = await getChatConversationHistory({
messageId: randomUUID(),
});
expect(result).toEqual([]);
});
it('should handle messages with no tool calls', async () => {
const mockMessages: ModelMessage[] = [
{
role: 'user',
content: 'Hello',
},
{
role: 'assistant',
content: 'Hi there!',
},
];
const { db } = await import('../../connection');
vi.mocked(db.orderBy).mockResolvedValue([
{
id: 'msg-1',
rawLlmMessages: mockMessages,
createdAt: '2025-01-01T00:00:00Z',
isCompleted: true,
},
]);
vi.mocked(db.limit).mockResolvedValue([
{
chatId: 'chat-123',
createdAt: '2025-01-01T00:00:00Z',
},
]);
const result = await getChatConversationHistory({
messageId: randomUUID(),
});
expect(result).toHaveLength(2);
expect(result).toEqual(mockMessages);
});
});

View File

@ -324,6 +324,80 @@ export const ChatConversationHistoryOutputSchema = z.array(z.custom<ModelMessage
export type ChatConversationHistoryInput = z.infer<typeof ChatConversationHistoryInputSchema>;
export type ChatConversationHistoryOutput = z.infer<typeof ChatConversationHistoryOutputSchema>;
/**
* Removes orphaned assistant messages with tool calls that have no matching tool results
* An orphaned message is an assistant message where ALL its tool calls lack corresponding tool results
*/
function removeOrphanedToolCalls(messages: ModelMessage[]): ModelMessage[] {
// Build a Set of all tool call IDs that have results (from tool role messages)
const toolCallIdsWithResults = new Set<string>();
for (const message of messages) {
if (message.role === 'tool' && Array.isArray(message.content)) {
for (const part of message.content) {
if (
typeof part === 'object' &&
part !== null &&
'type' in part &&
part.type === 'tool-result' &&
'toolCallId' in part &&
typeof part.toolCallId === 'string'
) {
toolCallIdsWithResults.add(part.toolCallId);
}
}
}
}
// Filter out assistant messages where ALL tool calls are orphaned
const filteredMessages: ModelMessage[] = [];
for (const message of messages) {
// Only check assistant messages with array content
if (message.role === 'assistant' && Array.isArray(message.content)) {
// Extract tool call IDs from this message
const toolCallIds: string[] = [];
for (const part of message.content) {
if (
typeof part === 'object' &&
part !== null &&
'type' in part &&
part.type === 'tool-call' &&
'toolCallId' in part &&
typeof part.toolCallId === 'string'
) {
toolCallIds.push(part.toolCallId);
}
}
// If this message has no tool calls, keep it
if (toolCallIds.length === 0) {
filteredMessages.push(message);
continue;
}
// Check if ANY of the tool calls have results
const hasAnyResults = toolCallIds.some((id) => toolCallIdsWithResults.has(id));
if (hasAnyResults) {
// At least one tool call has a result, keep the message
filteredMessages.push(message);
} else {
// ALL tool calls are orphaned, skip this message entirely
console.warn('[chatConversationHistory] Removing orphaned assistant message:', {
orphanedToolCallIds: toolCallIds,
});
}
} else {
// Keep all non-assistant messages (including tool role messages)
filteredMessages.push(message);
}
}
return filteredMessages;
}
/**
* Get conversation history for a chat up to and including a specific message
* Finds the chat from the given messageId, then merges and deduplicates all rawLlmMessages
@ -389,9 +463,12 @@ export async function getChatConversationHistory(
// Since we're merging from multiple messages, we should preserve the order they appear
const deduplicatedMessages = Array.from(uniqueMessagesMap.values());
// Remove orphaned tool calls (tool calls without matching tool results)
const cleanedMessages = removeOrphanedToolCalls(deduplicatedMessages);
// Validate output
try {
return ChatConversationHistoryOutputSchema.parse(deduplicatedMessages);
return ChatConversationHistoryOutputSchema.parse(cleanedMessages);
} catch (validationError) {
throw new Error(
`Output validation failed: ${validationError instanceof Error ? validationError.message : 'Invalid output format'}`