mirror of https://github.com/buster-so/buster.git
commit
86417ff057
|
@ -1,22 +1,16 @@
|
|||
import type { GetDashboardResponse } from '@buster/server-shared/dashboards';
|
||||
import { type QueryClient, useQueryClient } from '@tanstack/react-query';
|
||||
import last from 'lodash/last';
|
||||
import { dashboardQueryKeys } from '@/api/query_keys/dashboard';
|
||||
import { metricsQueryKeys } from '@/api/query_keys/metric';
|
||||
import { useBusterNotifications } from '@/context/BusterNotifications';
|
||||
import { setOriginalDashboard } from '@/context/Dashboards/useOriginalDashboardStore';
|
||||
import { setOriginalMetric } from '@/context/Metrics/useOriginalMetricStore';
|
||||
import { useMemoizedFn } from '@/hooks/useMemoizedFn';
|
||||
import { upgradeMetricToIMetric } from '@/lib/metrics/upgradeToIMetric';
|
||||
import { prefetchGetMetricDataClient } from '../metrics/queryRequests';
|
||||
import { initializeMetrics } from '../metrics/metricQueryHelpers';
|
||||
import { getDashboardById } from './requests';
|
||||
|
||||
export const useEnsureDashboardConfig = (params?: { prefetchData?: boolean }) => {
|
||||
const { prefetchData = true } = params || {};
|
||||
const queryClient = useQueryClient();
|
||||
const prefetchDashboard = useGetDashboardAndInitializeMetrics({
|
||||
prefetchData,
|
||||
});
|
||||
|
||||
const { openErrorMessage } = useBusterNotifications();
|
||||
|
||||
const method = useMemoizedFn(
|
||||
|
@ -24,11 +18,13 @@ export const useEnsureDashboardConfig = (params?: { prefetchData?: boolean }) =>
|
|||
const options = dashboardQueryKeys.dashboardGetDashboard(dashboardId, 'LATEST');
|
||||
let dashboardResponse = queryClient.getQueryData(options.queryKey);
|
||||
if (!dashboardResponse) {
|
||||
const res = await prefetchDashboard({
|
||||
const res = await getDashboardAndInitializeMetrics({
|
||||
id: dashboardId,
|
||||
version_number: 'LATEST',
|
||||
shouldInitializeMetrics: initializeMetrics,
|
||||
password,
|
||||
queryClient,
|
||||
shouldInitializeMetrics: initializeMetrics,
|
||||
prefetchMetricsData: prefetchData,
|
||||
}).catch(() => {
|
||||
openErrorMessage('Failed to save metrics to dashboard. Dashboard not found');
|
||||
});
|
||||
|
@ -48,63 +44,6 @@ export const useEnsureDashboardConfig = (params?: { prefetchData?: boolean }) =>
|
|||
return method;
|
||||
};
|
||||
|
||||
export const initializeMetrics = (
|
||||
metrics: GetDashboardResponse['metrics'],
|
||||
queryClient: QueryClient,
|
||||
prefetchData: boolean
|
||||
) => {
|
||||
for (const metric of Object.values(metrics)) {
|
||||
const upgradedMetric = upgradeMetricToIMetric(metric, null);
|
||||
queryClient.setQueryData(
|
||||
metricsQueryKeys.metricsGetMetric(metric.id, metric.version_number).queryKey,
|
||||
upgradedMetric
|
||||
);
|
||||
const isLatestVersion = metric.version_number === last(metric.versions)?.version_number;
|
||||
if (isLatestVersion) {
|
||||
setOriginalMetric(upgradedMetric);
|
||||
queryClient.setQueryData(
|
||||
metricsQueryKeys.metricsGetMetric(metric.id, 'LATEST').queryKey,
|
||||
upgradedMetric
|
||||
);
|
||||
}
|
||||
if (prefetchData) {
|
||||
prefetchGetMetricDataClient(
|
||||
{ id: metric.id, version_number: metric.version_number },
|
||||
queryClient
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const useGetDashboardAndInitializeMetrics = (params?: { prefetchData?: boolean }) => {
|
||||
const { prefetchData = true } = params || {};
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMemoizedFn(
|
||||
async ({
|
||||
id,
|
||||
version_number,
|
||||
shouldInitializeMetrics = true,
|
||||
password,
|
||||
}: {
|
||||
id: string;
|
||||
version_number: number | 'LATEST';
|
||||
shouldInitializeMetrics?: boolean;
|
||||
password: string | undefined;
|
||||
}) => {
|
||||
return getDashboardAndInitializeMetrics({
|
||||
id,
|
||||
version_number,
|
||||
password,
|
||||
queryClient,
|
||||
shouldInitializeMetrics,
|
||||
prefetchMetricsData: prefetchData,
|
||||
});
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
//Can use this in server side
|
||||
export const getDashboardAndInitializeMetrics = async ({
|
||||
id,
|
||||
version_number,
|
||||
|
|
|
@ -3,7 +3,7 @@ import { useNavigate } from '@tanstack/react-router';
|
|||
import last from 'lodash/last';
|
||||
import { dashboardQueryKeys } from '@/api/query_keys/dashboard';
|
||||
import { setOriginalDashboard } from '@/context/Dashboards/useOriginalDashboardStore';
|
||||
import { initializeMetrics } from '../dashboardQueryHelpers';
|
||||
import { initializeMetrics } from '../../metrics/metricQueryHelpers';
|
||||
import { dashboardsUpdateDashboard } from '../requests';
|
||||
|
||||
/**
|
||||
|
|
|
@ -15,10 +15,7 @@ import {
|
|||
import { useMemoizedFn } from '@/hooks/useMemoizedFn';
|
||||
import { isQueryStale } from '@/lib/query';
|
||||
import { hasOrganizationId } from '../../users/userQueryHelpers';
|
||||
import {
|
||||
getDashboardAndInitializeMetrics,
|
||||
useGetDashboardAndInitializeMetrics,
|
||||
} from '../dashboardQueryHelpers';
|
||||
import { getDashboardAndInitializeMetrics } from '../dashboardQueryHelpers';
|
||||
import { useGetDashboardVersionNumber } from '../dashboardVersionNumber';
|
||||
import { dashboardsGetList } from '../requests';
|
||||
|
||||
|
@ -36,14 +33,22 @@ export const useGetDashboard = <TData = GetDashboardResponse>(
|
|||
) => {
|
||||
const id = idProp || '';
|
||||
const password = useProtectedAssetPassword(id);
|
||||
const queryFn = useGetDashboardAndInitializeMetrics();
|
||||
const queryClient = useQueryClient();
|
||||
// const queryFn = useGetDashboardAndInitializeMetrics();
|
||||
|
||||
const { selectedVersionNumber } = useGetDashboardVersionNumber(id, versionNumberProp);
|
||||
|
||||
const { isFetched: isFetchedInitial, isError: isErrorInitial } = useQuery({
|
||||
...dashboardQueryKeys.dashboardGetDashboard(id, 'LATEST'),
|
||||
queryFn: () =>
|
||||
queryFn({ id, version_number: 'LATEST', shouldInitializeMetrics: true, password }),
|
||||
getDashboardAndInitializeMetrics({
|
||||
id,
|
||||
version_number: 'LATEST',
|
||||
password,
|
||||
queryClient,
|
||||
shouldInitializeMetrics: true,
|
||||
prefetchMetricsData: true,
|
||||
}),
|
||||
enabled: true,
|
||||
retry(_failureCount, error) {
|
||||
if (error?.message !== undefined) {
|
||||
|
@ -61,11 +66,13 @@ export const useGetDashboard = <TData = GetDashboardResponse>(
|
|||
return useQuery({
|
||||
...dashboardQueryKeys.dashboardGetDashboard(id, selectedVersionNumber),
|
||||
queryFn: () =>
|
||||
queryFn({
|
||||
getDashboardAndInitializeMetrics({
|
||||
id,
|
||||
version_number: selectedVersionNumber,
|
||||
shouldInitializeMetrics: true,
|
||||
password,
|
||||
queryClient,
|
||||
shouldInitializeMetrics: true,
|
||||
prefetchMetricsData: true,
|
||||
}),
|
||||
enabled: isFetchedInitial && !isErrorInitial,
|
||||
select: params?.select,
|
||||
|
@ -80,7 +87,6 @@ export const usePrefetchGetDashboardClient = <TData = GetDashboardResponse>(
|
|||
params?: Omit<UseQueryOptions<GetDashboardResponse, ApiError, TData>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
const queryFn = useGetDashboardAndInitializeMetrics({ prefetchData: false });
|
||||
|
||||
return useMemoizedFn((id: string, versionNumber: number | 'LATEST') => {
|
||||
const getDashboardQueryKey = dashboardQueryKeys.dashboardGetDashboard(id, versionNumber);
|
||||
|
@ -89,11 +95,13 @@ export const usePrefetchGetDashboardClient = <TData = GetDashboardResponse>(
|
|||
return queryClient.prefetchQuery({
|
||||
...dashboardQueryKeys.dashboardGetDashboard(id, versionNumber),
|
||||
queryFn: () =>
|
||||
queryFn({
|
||||
getDashboardAndInitializeMetrics({
|
||||
id,
|
||||
version_number: versionNumber,
|
||||
shouldInitializeMetrics: true,
|
||||
version_number: 'LATEST',
|
||||
password: undefined,
|
||||
queryClient,
|
||||
shouldInitializeMetrics: true,
|
||||
prefetchMetricsData: false,
|
||||
}),
|
||||
...params,
|
||||
});
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import type { GetDashboardResponse } from '@buster/server-shared/dashboards';
|
||||
import type { GetReportResponse } from '@buster/server-shared/reports';
|
||||
import { type QueryClient, useQueryClient } from '@tanstack/react-query';
|
||||
import last from 'lodash/last';
|
||||
import { setOriginalMetric } from '@/context/Metrics/useOriginalMetricStore';
|
||||
import { useMemoizedFn } from '@/hooks/useMemoizedFn';
|
||||
import { upgradeMetricToIMetric } from '@/lib/metrics/upgradeToIMetric';
|
||||
import type { BusterMetricDataExtended } from '../../asset_interfaces/metric';
|
||||
import type { BusterMetric } from '../../asset_interfaces/metric/interfaces';
|
||||
import { metricsQueryKeys } from '../../query_keys/metric';
|
||||
import { prefetchGetMetricDataClient } from './getMetricQueryRequests';
|
||||
|
||||
export const useGetMetricMemoized = () => {
|
||||
const queryClient = useQueryClient();
|
||||
|
@ -37,3 +43,31 @@ export const useGetMetricDataMemoized = () => {
|
|||
);
|
||||
return getMetricDataMemoized;
|
||||
};
|
||||
|
||||
export const initializeMetrics = (
|
||||
metrics: GetDashboardResponse['metrics'] | GetReportResponse['metrics'],
|
||||
queryClient: QueryClient,
|
||||
prefetchData: boolean
|
||||
) => {
|
||||
for (const metric of Object.values(metrics)) {
|
||||
const upgradedMetric = upgradeMetricToIMetric(metric, null);
|
||||
queryClient.setQueryData(
|
||||
metricsQueryKeys.metricsGetMetric(metric.id, metric.version_number).queryKey,
|
||||
upgradedMetric
|
||||
);
|
||||
const isLatestVersion = metric.version_number === last(metric.versions)?.version_number;
|
||||
if (isLatestVersion) {
|
||||
setOriginalMetric(upgradedMetric);
|
||||
queryClient.setQueryData(
|
||||
metricsQueryKeys.metricsGetMetric(metric.id, 'LATEST').queryKey,
|
||||
upgradedMetric
|
||||
);
|
||||
}
|
||||
if (prefetchData) {
|
||||
prefetchGetMetricDataClient(
|
||||
{ id: metric.id, version_number: metric.version_number },
|
||||
queryClient
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -17,7 +17,8 @@ import {
|
|||
useRemoveAssetFromCollection,
|
||||
} from '../collections/queryRequests';
|
||||
import { useGetUserFavorites } from '../users/favorites';
|
||||
import { getReportById, getReportsList, updateReport } from './requests';
|
||||
import { getReportAndInitializeMetrics } from './reportQueryHelpers';
|
||||
import { getReportsList, updateReport } from './requests';
|
||||
|
||||
/**
|
||||
* Hook to get a list of reports
|
||||
|
@ -80,9 +81,13 @@ export const prefetchGetReport = async (
|
|||
await queryClient.prefetchQuery({
|
||||
...reportsQueryKeys.reportsGetReport(reportId, version_number || 'LATEST'),
|
||||
queryFn: () =>
|
||||
getReportById({
|
||||
getReportAndInitializeMetrics({
|
||||
id: reportId,
|
||||
version_number: typeof version_number === 'number' ? version_number : undefined,
|
||||
password: undefined,
|
||||
queryClient,
|
||||
shouldInitializeMetrics: true,
|
||||
prefetchMetricsData: false,
|
||||
}),
|
||||
retry: silenceAssetErrors,
|
||||
});
|
||||
|
@ -106,17 +111,20 @@ export const useGetReport = <T = GetReportResponse>(
|
|||
options?: Omit<UseQueryOptions<GetReportResponse, ApiError, T>, 'queryKey' | 'queryFn'>
|
||||
) => {
|
||||
const password = useProtectedAssetPassword(id || '');
|
||||
const queryFn = () => {
|
||||
return getReportById({
|
||||
id: id ?? '',
|
||||
version_number: typeof versionNumber === 'number' ? versionNumber : undefined,
|
||||
password,
|
||||
});
|
||||
};
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useQuery({
|
||||
...reportsQueryKeys.reportsGetReport(id ?? '', versionNumber || 'LATEST'),
|
||||
queryFn,
|
||||
queryFn: () => {
|
||||
return getReportAndInitializeMetrics({
|
||||
id: id ?? '',
|
||||
version_number: typeof versionNumber === 'number' ? versionNumber : undefined,
|
||||
password,
|
||||
queryClient,
|
||||
shouldInitializeMetrics: true,
|
||||
prefetchMetricsData: true,
|
||||
});
|
||||
},
|
||||
enabled: !!id,
|
||||
select: options?.select,
|
||||
...options,
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
import type { QueryClient } from '@tanstack/react-query';
|
||||
import last from 'lodash/last';
|
||||
import { reportsQueryKeys } from '@/api/query_keys/reports';
|
||||
import { initializeMetrics } from '../metrics/metricQueryHelpers';
|
||||
import { getReportById } from './requests';
|
||||
|
||||
export const getReportAndInitializeMetrics = async ({
|
||||
id,
|
||||
version_number,
|
||||
password,
|
||||
queryClient,
|
||||
shouldInitializeMetrics = true,
|
||||
prefetchMetricsData = false,
|
||||
}: {
|
||||
id: string;
|
||||
version_number: number | 'LATEST' | undefined;
|
||||
password: string | undefined;
|
||||
queryClient: QueryClient;
|
||||
shouldInitializeMetrics: boolean;
|
||||
prefetchMetricsData: boolean;
|
||||
}) => {
|
||||
const chosenVersionNumber = version_number === 'LATEST' ? undefined : version_number;
|
||||
return getReportById({
|
||||
id,
|
||||
version_number: chosenVersionNumber,
|
||||
password,
|
||||
}).then((data) => {
|
||||
const latestVersion = last(data.versions)?.version_number || 1;
|
||||
const isLatestVersion = data.version_number === latestVersion;
|
||||
|
||||
if (isLatestVersion) {
|
||||
// set the original report?
|
||||
}
|
||||
|
||||
if (data.version_number) {
|
||||
queryClient.setQueryData(
|
||||
reportsQueryKeys.reportsGetReport(data.id, data.version_number).queryKey,
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
if (shouldInitializeMetrics || prefetchMetricsData) {
|
||||
initializeMetrics(data.metrics, queryClient, !!prefetchMetricsData);
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
};
|
|
@ -203,7 +203,6 @@ export const useXAxis = ({
|
|||
return formatLabel(value, firstXColumnLabelFormat);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- I had a devil of a time trying to type this... This is a hack to get the type to work
|
||||
return DEFAULT_X_AXIS_TICK_CALLBACK.call(
|
||||
this,
|
||||
value,
|
||||
|
|
|
@ -1,13 +1,9 @@
|
|||
import {
|
||||
type ColumnLabelFormat,
|
||||
DEFAULT_COLUMN_LABEL_FORMAT,
|
||||
DEFAULT_COLUMN_SETTINGS,
|
||||
} from '@buster/server-shared/metrics';
|
||||
import { type ColumnLabelFormat, DEFAULT_COLUMN_SETTINGS } from '@buster/server-shared/metrics';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import type { BusterChartProps } from '../../../BusterChart.types';
|
||||
import type { DatasetOption } from '../../../chartHooks';
|
||||
import type { DatasetOptionsWithTicks } from '../../../chartHooks/useDatasetOptions/interfaces';
|
||||
import { barSeriesBuilder } from './barSeriesBuilder';
|
||||
import { barSeriesBuilder, barSeriesBuilder_labels } from './barSeriesBuilder';
|
||||
import type { SeriesBuilderProps } from './interfaces';
|
||||
|
||||
describe('barSeriesBuilder', () => {
|
||||
|
@ -313,3 +309,210 @@ describe('percentage mode logic', () => {
|
|||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('barSeriesBuilder_labels - dateTicks', () => {
|
||||
it('should return date ticks when columnLabelFormat has date style and single xAxis', () => {
|
||||
// Arrange
|
||||
const mockDatasetOptions: DatasetOptionsWithTicks = {
|
||||
datasets: [],
|
||||
ticks: [['2024-01-01'], ['2024-02-01'], ['2024-03-01']],
|
||||
ticksKey: [{ key: 'date', value: '' }],
|
||||
};
|
||||
|
||||
const columnLabelFormats = {
|
||||
date: {
|
||||
columnType: 'date',
|
||||
style: 'date',
|
||||
minimumFractionDigits: 0,
|
||||
maximumFractionDigits: 0,
|
||||
multiplier: 1,
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
replaceMissingDataWith: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
} as ColumnLabelFormat,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = barSeriesBuilder_labels({
|
||||
datasetOptions: mockDatasetOptions,
|
||||
columnLabelFormats,
|
||||
xAxisKeys: ['date'],
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0]).toBeInstanceOf(Date);
|
||||
expect(result[1]).toBeInstanceOf(Date);
|
||||
expect(result[2]).toBeInstanceOf(Date);
|
||||
});
|
||||
|
||||
it('should return quarter ticks when columnLabelFormat has quarter convertNumberTo and single xAxis', () => {
|
||||
// Arrange
|
||||
const mockDatasetOptions: DatasetOptionsWithTicks = {
|
||||
datasets: [],
|
||||
ticks: [[1], [2], [3], [4]],
|
||||
ticksKey: [{ key: 'quarter', value: '' }],
|
||||
};
|
||||
|
||||
const columnLabelFormats = {
|
||||
quarter: {
|
||||
columnType: 'number',
|
||||
style: 'date',
|
||||
convertNumberTo: 'quarter',
|
||||
minimumFractionDigits: 0,
|
||||
maximumFractionDigits: 0,
|
||||
multiplier: 1,
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
replaceMissingDataWith: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
} as ColumnLabelFormat,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = barSeriesBuilder_labels({
|
||||
datasetOptions: mockDatasetOptions,
|
||||
columnLabelFormats,
|
||||
xAxisKeys: ['quarter'],
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toHaveLength(4);
|
||||
expect(result[0]).toBe('Q1');
|
||||
expect(result[1]).toBe('Q2');
|
||||
expect(result[2]).toBe('Q3');
|
||||
expect(result[3]).toBe('Q4');
|
||||
});
|
||||
|
||||
it('should return quarter ticks with double xAxis when one is quarter and one is number', () => {
|
||||
// Arrange
|
||||
const mockDatasetOptions: DatasetOptionsWithTicks = {
|
||||
datasets: [],
|
||||
ticks: [
|
||||
[1, 2023],
|
||||
[2, 2023],
|
||||
[3, 2023],
|
||||
],
|
||||
ticksKey: [
|
||||
{ key: 'quarter', value: '' },
|
||||
{ key: 'year', value: '' },
|
||||
],
|
||||
};
|
||||
|
||||
const columnLabelFormats = {
|
||||
quarter: {
|
||||
columnType: 'number',
|
||||
style: 'date',
|
||||
convertNumberTo: 'quarter',
|
||||
minimumFractionDigits: 0,
|
||||
maximumFractionDigits: 0,
|
||||
multiplier: 1,
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
replaceMissingDataWith: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
} as ColumnLabelFormat,
|
||||
year: {
|
||||
columnType: 'number',
|
||||
style: 'number',
|
||||
minimumFractionDigits: 0,
|
||||
maximumFractionDigits: 0,
|
||||
multiplier: 1,
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
numberSeparatorStyle: null,
|
||||
replaceMissingDataWith: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
} as ColumnLabelFormat,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = barSeriesBuilder_labels({
|
||||
datasetOptions: mockDatasetOptions,
|
||||
columnLabelFormats,
|
||||
xAxisKeys: ['quarter', 'year'],
|
||||
});
|
||||
|
||||
console.log('result', result[0]);
|
||||
|
||||
// Assert
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0]).toContain('Q1');
|
||||
expect(result[0]).toContain('2023');
|
||||
expect(result[1]).toContain('Q2');
|
||||
expect(result[2]).toContain('Q3');
|
||||
});
|
||||
|
||||
it('should return null dateTicks when columnLabelFormat does not have date style', () => {
|
||||
// Arrange
|
||||
const mockDatasetOptions: DatasetOptionsWithTicks = {
|
||||
datasets: [],
|
||||
ticks: [['Product A'], ['Product B'], ['Product C']],
|
||||
ticksKey: [{ key: 'product', value: '' }],
|
||||
};
|
||||
|
||||
const columnLabelFormats = {
|
||||
product: {
|
||||
columnType: 'string',
|
||||
style: 'string',
|
||||
minimumFractionDigits: 0,
|
||||
maximumFractionDigits: 0,
|
||||
multiplier: 1,
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
replaceMissingDataWith: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
} as unknown as ColumnLabelFormat,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = barSeriesBuilder_labels({
|
||||
datasetOptions: mockDatasetOptions,
|
||||
columnLabelFormats,
|
||||
xAxisKeys: ['product'],
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0]).toBe('Product A');
|
||||
expect(result[1]).toBe('Product B');
|
||||
expect(result[2]).toBe('Product C');
|
||||
});
|
||||
|
||||
it('should return date ticks when columnType is number but style is date', () => {
|
||||
// Arrange
|
||||
const mockDatasetOptions: DatasetOptionsWithTicks = {
|
||||
datasets: [],
|
||||
ticks: [['2024-01-15'], ['2024-01-16'], ['2024-01-17']],
|
||||
ticksKey: [{ key: 'timestamp', value: '' }],
|
||||
};
|
||||
|
||||
const columnLabelFormats = {
|
||||
timestamp: {
|
||||
columnType: 'number',
|
||||
style: 'date',
|
||||
minimumFractionDigits: 0,
|
||||
maximumFractionDigits: 0,
|
||||
multiplier: 1,
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
replaceMissingDataWith: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
} as ColumnLabelFormat,
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = barSeriesBuilder_labels({
|
||||
datasetOptions: mockDatasetOptions,
|
||||
columnLabelFormats,
|
||||
xAxisKeys: ['timestamp'],
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0]).toBeInstanceOf(Date);
|
||||
expect(result[1]).toBeInstanceOf(Date);
|
||||
expect(result[2]).toBeInstanceOf(Date);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -11,6 +11,7 @@ import { DEFAULT_CHART_LAYOUT } from '../../ChartJSTheme';
|
|||
import type { ChartProps } from '../../core';
|
||||
import { dataLabelFontColorContrast, formatBarAndLineDataLabel } from '../../helpers';
|
||||
import { defaultLabelOptionConfig } from '../useChartSpecificOptions/labelOptionConfig';
|
||||
import { createTickDates } from './createTickDate';
|
||||
import { createTrendlineOnSeries } from './createTrendlines';
|
||||
import type { SeriesBuilderProps } from './interfaces';
|
||||
import type { LabelBuilderProps } from './useSeriesOptions';
|
||||
|
@ -352,8 +353,12 @@ const getFormattedValueAndSetBarDataLabels = (
|
|||
export const barSeriesBuilder_labels = ({
|
||||
datasetOptions,
|
||||
columnLabelFormats,
|
||||
}: Pick<LabelBuilderProps, 'datasetOptions' | 'columnLabelFormats'>) => {
|
||||
const ticksKey = datasetOptions.ticksKey;
|
||||
xAxisKeys,
|
||||
}: Pick<LabelBuilderProps, 'datasetOptions' | 'columnLabelFormats' | 'xAxisKeys'>) => {
|
||||
const dateTicks = createTickDates(datasetOptions.ticks, xAxisKeys, columnLabelFormats);
|
||||
if (dateTicks) {
|
||||
return dateTicks;
|
||||
}
|
||||
|
||||
const containsADateStyle = datasetOptions.ticksKey.some((tick) => {
|
||||
const selectedColumnLabelFormat = columnLabelFormats[tick.key];
|
||||
|
@ -364,7 +369,7 @@ export const barSeriesBuilder_labels = ({
|
|||
const labels = datasetOptions.ticks.flatMap((item) => {
|
||||
return item
|
||||
.map<string>((item, index) => {
|
||||
const key = ticksKey[index]?.key || '';
|
||||
const key = datasetOptions.ticksKey[index]?.key || '';
|
||||
const columnLabelFormat = columnLabelFormats[key];
|
||||
return formatLabel(item, columnLabelFormat);
|
||||
})
|
||||
|
|
|
@ -45,6 +45,7 @@ export const createTickDates = (
|
|||
}
|
||||
|
||||
const isDoubleXAxis = xAxisKeys.length === 2;
|
||||
console.log('isDoubleXAxis', isDoubleXAxis);
|
||||
if (isDoubleXAxis) {
|
||||
const oneIsAQuarter = xColumnLabelFormats.findIndex(
|
||||
(format) => format.convertNumberTo === 'quarter' && format.columnType === 'number'
|
||||
|
@ -52,6 +53,8 @@ export const createTickDates = (
|
|||
const oneIsANumber = xColumnLabelFormats.findIndex(
|
||||
(format) => format.columnType === 'number' && format.style === 'number'
|
||||
);
|
||||
console.log('oneIsAQuarter', oneIsAQuarter);
|
||||
console.log('oneIsANumber', oneIsANumber);
|
||||
if (oneIsAQuarter !== -1 && oneIsANumber !== -1) {
|
||||
return createQuarterTickDates(ticks, xColumnLabelFormats, oneIsAQuarter);
|
||||
}
|
||||
|
|
|
@ -330,6 +330,7 @@ describe('lineSeriesBuilder', () => {
|
|||
const labels = lineSeriesBuilder_labels(props);
|
||||
|
||||
expect(labels).toHaveLength(3);
|
||||
console.log('labels', labels[0]);
|
||||
expect(labels[0]).toBe('formatted-2023-01-01 formatted-A');
|
||||
expect(labels[1]).toBe('formatted-2023-01-02 formatted-B');
|
||||
expect(labels[2]).toBe('formatted-2023-01-03 formatted-A');
|
||||
|
|
|
@ -216,13 +216,9 @@ export const lineSeriesBuilder_labels = ({
|
|||
xAxisKeys,
|
||||
columnLabelFormats,
|
||||
}: LabelBuilderProps): (string | Date)[] => {
|
||||
const dateTicks = createTickDates(datasetOptions.ticks, xAxisKeys, columnLabelFormats);
|
||||
if (dateTicks) {
|
||||
return dateTicks;
|
||||
}
|
||||
|
||||
return barSeriesBuilder_labels({
|
||||
datasetOptions,
|
||||
columnLabelFormats,
|
||||
xAxisKeys,
|
||||
});
|
||||
};
|
||||
|
|
|
@ -1657,3 +1657,243 @@ export const BarChartWithSortedDayOfWeek: Story = {
|
|||
],
|
||||
},
|
||||
};
|
||||
|
||||
export const BarWithProblemQuarters: Story = {
|
||||
args: {
|
||||
colors: [
|
||||
'#B399FD',
|
||||
'#FC8497',
|
||||
'#FBBC30',
|
||||
'#279EFF',
|
||||
'#E83562',
|
||||
'#41F8FF',
|
||||
'#F3864F',
|
||||
'#C82184',
|
||||
'#31FCB4',
|
||||
'#E83562',
|
||||
],
|
||||
barLayout: 'vertical',
|
||||
barSortBy: [],
|
||||
goalLines: [],
|
||||
gridLines: true,
|
||||
pieSortBy: 'value',
|
||||
showLegend: null,
|
||||
trendlines: [],
|
||||
scatterAxis: {
|
||||
x: [],
|
||||
y: [],
|
||||
size: [],
|
||||
tooltip: null,
|
||||
category: [],
|
||||
},
|
||||
barGroupType: 'stack',
|
||||
metricHeader: null,
|
||||
pieChartAxis: {
|
||||
x: [],
|
||||
y: [],
|
||||
tooltip: null,
|
||||
},
|
||||
lineGroupType: null,
|
||||
pieDonutWidth: 40,
|
||||
xAxisDataZoom: false,
|
||||
barAndLineAxis: {
|
||||
x: ['quarter'],
|
||||
y: ['product_count'],
|
||||
colorBy: [],
|
||||
tooltip: null,
|
||||
category: ['metric_seasoncategory'],
|
||||
},
|
||||
columnSettings: {},
|
||||
comboChartAxis: {
|
||||
x: [],
|
||||
y: [],
|
||||
y2: [],
|
||||
colorBy: [],
|
||||
tooltip: null,
|
||||
category: [],
|
||||
},
|
||||
disableTooltip: false,
|
||||
metricColumnId: '',
|
||||
scatterDotSize: [3, 15],
|
||||
xAxisAxisTitle: null,
|
||||
yAxisAxisTitle: null,
|
||||
yAxisScaleType: 'linear',
|
||||
metricSubHeader: null,
|
||||
y2AxisAxisTitle: null,
|
||||
y2AxisScaleType: 'linear',
|
||||
metricValueLabel: null,
|
||||
pieLabelPosition: 'none',
|
||||
tableColumnOrder: null,
|
||||
barShowTotalAtTop: false,
|
||||
categoryAxisTitle: null,
|
||||
pieDisplayLabelAs: 'number',
|
||||
pieShowInnerLabel: true,
|
||||
selectedChartType: 'bar',
|
||||
tableColumnWidths: null,
|
||||
xAxisTimeInterval: null,
|
||||
columnLabelFormats: {
|
||||
quarter: {
|
||||
isUTC: false,
|
||||
style: 'date',
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
currency: 'USD',
|
||||
columnType: 'number',
|
||||
dateFormat: 'auto',
|
||||
multiplier: 1,
|
||||
displayName: '',
|
||||
compactNumbers: false,
|
||||
convertNumberTo: 'quarter',
|
||||
useRelativeTime: false,
|
||||
numberSeparatorStyle: null,
|
||||
maximumFractionDigits: 2,
|
||||
minimumFractionDigits: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
replaceMissingDataWith: 0,
|
||||
},
|
||||
product_count: {
|
||||
isUTC: false,
|
||||
style: 'number',
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
currency: 'USD',
|
||||
columnType: 'number',
|
||||
dateFormat: 'auto',
|
||||
multiplier: 1,
|
||||
displayName: 'Product Count',
|
||||
compactNumbers: false,
|
||||
convertNumberTo: null,
|
||||
useRelativeTime: false,
|
||||
numberSeparatorStyle: ',',
|
||||
maximumFractionDigits: 0,
|
||||
minimumFractionDigits: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
replaceMissingDataWith: 0,
|
||||
},
|
||||
metric_seasoncategory: {
|
||||
isUTC: false,
|
||||
style: 'string',
|
||||
prefix: '',
|
||||
suffix: '',
|
||||
currency: 'USD',
|
||||
columnType: 'text',
|
||||
dateFormat: 'auto',
|
||||
multiplier: 1,
|
||||
displayName: 'Season Category',
|
||||
compactNumbers: false,
|
||||
convertNumberTo: null,
|
||||
useRelativeTime: false,
|
||||
numberSeparatorStyle: null,
|
||||
maximumFractionDigits: 2,
|
||||
minimumFractionDigits: 0,
|
||||
makeLabelHumanReadable: true,
|
||||
replaceMissingDataWith: null,
|
||||
},
|
||||
},
|
||||
pieInnerLabelTitle: null,
|
||||
showLegendHeadline: false,
|
||||
xAxisLabelRotation: 'auto',
|
||||
xAxisShowAxisLabel: true,
|
||||
xAxisShowAxisTitle: true,
|
||||
yAxisShowAxisLabel: true,
|
||||
yAxisShowAxisTitle: true,
|
||||
y2AxisShowAxisLabel: true,
|
||||
y2AxisShowAxisTitle: true,
|
||||
metricValueAggregate: 'sum',
|
||||
tableColumnFontColor: null,
|
||||
tableHeaderFontColor: null,
|
||||
yAxisStartAxisAtZero: null,
|
||||
y2AxisStartAxisAtZero: true,
|
||||
pieInnerLabelAggregate: 'sum',
|
||||
pieMinimumSlicePercentage: 0,
|
||||
tableHeaderBackgroundColor: null,
|
||||
data: [
|
||||
{
|
||||
quarter: 1,
|
||||
metric_seasoncategory: 'High Season',
|
||||
product_count: 55,
|
||||
},
|
||||
{
|
||||
quarter: 1,
|
||||
metric_seasoncategory: 'Low Season',
|
||||
product_count: 18,
|
||||
},
|
||||
{
|
||||
quarter: 1,
|
||||
metric_seasoncategory: 'Regular Season',
|
||||
product_count: 174,
|
||||
},
|
||||
{
|
||||
quarter: 2,
|
||||
metric_seasoncategory: 'High Season',
|
||||
product_count: 27,
|
||||
},
|
||||
{
|
||||
quarter: 2,
|
||||
metric_seasoncategory: 'Low Season',
|
||||
product_count: 60,
|
||||
},
|
||||
{
|
||||
quarter: 2,
|
||||
metric_seasoncategory: 'Regular Season',
|
||||
product_count: 156,
|
||||
},
|
||||
{
|
||||
quarter: 3,
|
||||
metric_seasoncategory: 'High Season',
|
||||
product_count: 2,
|
||||
},
|
||||
{
|
||||
quarter: 3,
|
||||
metric_seasoncategory: 'Low Season',
|
||||
product_count: 181,
|
||||
},
|
||||
{
|
||||
quarter: 3,
|
||||
metric_seasoncategory: 'Regular Season',
|
||||
product_count: 82,
|
||||
},
|
||||
{
|
||||
quarter: 4,
|
||||
metric_seasoncategory: 'High Season',
|
||||
product_count: 136,
|
||||
},
|
||||
{
|
||||
quarter: 4,
|
||||
metric_seasoncategory: 'Low Season',
|
||||
product_count: 22,
|
||||
},
|
||||
{
|
||||
quarter: 4,
|
||||
metric_seasoncategory: 'Regular Season',
|
||||
product_count: 108,
|
||||
},
|
||||
],
|
||||
columnMetadata: [
|
||||
{
|
||||
name: 'quarter',
|
||||
min_value: 1,
|
||||
max_value: 4,
|
||||
unique_values: 4,
|
||||
simple_type: 'number',
|
||||
type: 'numeric',
|
||||
},
|
||||
{
|
||||
name: 'metric_seasoncategory',
|
||||
min_value: 'High Season',
|
||||
max_value: 'Regular Season',
|
||||
unique_values: 3,
|
||||
simple_type: 'text',
|
||||
type: 'text',
|
||||
},
|
||||
{
|
||||
name: 'product_count',
|
||||
min_value: 2,
|
||||
max_value: 181,
|
||||
unique_values: 12,
|
||||
simple_type: 'number',
|
||||
type: 'int8',
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
|
|
|
@ -218,7 +218,7 @@ export const MentionInputSuggestions = forwardRef<
|
|||
className={inputClassName}
|
||||
onBlur={onBlur}
|
||||
/>
|
||||
{children && <div className="mt-4.5">{children}</div>}
|
||||
{children && <div className="mt-6">{children}</div>}
|
||||
</MentionInputSuggestionsContainer>
|
||||
<SuggestionsSeperator />
|
||||
<MentionInputSuggestionsList
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -10,11 +10,27 @@ import {
|
|||
createCreateMetricsTool,
|
||||
createCreateReportsTool,
|
||||
createDoneTool,
|
||||
createExecuteSqlTool,
|
||||
createModifyDashboardsTool,
|
||||
createModifyMetricsTool,
|
||||
createModifyReportsTool,
|
||||
createSequentialThinkingTool,
|
||||
} from '../../tools';
|
||||
import { DONE_TOOL_NAME } from '../../tools/communication-tools/done-tool/done-tool';
|
||||
import {
|
||||
MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME,
|
||||
createMessageUserClarifyingQuestionTool,
|
||||
} from '../../tools/communication-tools/message-user-clarifying-question/message-user-clarifying-question';
|
||||
import {
|
||||
RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME,
|
||||
createRespondWithoutAssetCreationTool,
|
||||
} from '../../tools/communication-tools/respond-without-asset-creation/respond-without-asset-creation-tool';
|
||||
import {
|
||||
SUBMIT_THOUGHTS_TOOL_NAME,
|
||||
createSubmitThoughtsTool,
|
||||
} from '../../tools/communication-tools/submit-thoughts-tool/submit-thoughts-tool';
|
||||
import { EXECUTE_SQL_TOOL_NAME } from '../../tools/database-tools/execute-sql/execute-sql';
|
||||
import { SEQUENTIAL_THINKING_TOOL_NAME } from '../../tools/planning-thinking-tools/sequential-thinking-tool/sequential-thinking-tool';
|
||||
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
|
||||
import { MODIFY_DASHBOARDS_TOOL_NAME } from '../../tools/visualization-tools/dashboards/modify-dashboards-tool/modify-dashboards-tool';
|
||||
import { CREATE_METRICS_TOOL_NAME } from '../../tools/visualization-tools/metrics/create-metrics-tool/create-metrics-tool';
|
||||
|
@ -28,13 +44,22 @@ import { getAnalystAgentSystemPrompt } from './get-analyst-agent-system-prompt';
|
|||
|
||||
export const ANALYST_AGENT_NAME = 'analystAgent';
|
||||
|
||||
const STOP_CONDITIONS = [stepCountIs(25), hasToolCall(DONE_TOOL_NAME)];
|
||||
const STOP_CONDITIONS = [
|
||||
stepCountIs(25),
|
||||
hasToolCall(DONE_TOOL_NAME),
|
||||
hasToolCall(RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME),
|
||||
hasToolCall(MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME),
|
||||
];
|
||||
|
||||
export const AnalystAgentOptionsSchema = z.object({
|
||||
userId: z.string(),
|
||||
chatId: z.string(),
|
||||
dataSourceId: z.string(),
|
||||
dataSourceSyntax: z.string(),
|
||||
sql_dialect_guidance: z
|
||||
.string()
|
||||
.describe('The SQL dialect guidance for the analyst agent.')
|
||||
.optional(),
|
||||
organizationId: z.string(),
|
||||
messageId: z.string(),
|
||||
datasets: z.array(z.custom<PermissionedDataset>()),
|
||||
|
@ -72,7 +97,10 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: getAnalystAgentSystemPrompt(analystAgentOptions.dataSourceSyntax),
|
||||
content: getAnalystAgentSystemPrompt(
|
||||
analystAgentOptions.dataSourceSyntax,
|
||||
analystAgentOptions.analysisMode || 'standard'
|
||||
),
|
||||
providerOptions: DEFAULT_ANTHROPIC_OPTIONS,
|
||||
} as ModelMessage;
|
||||
|
||||
|
@ -106,6 +134,26 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
: null;
|
||||
|
||||
async function stream({ messages }: AnalystStreamOptions) {
|
||||
// Think-and-prep tools
|
||||
const sequentialThinking = createSequentialThinkingTool({
|
||||
messageId: analystAgentOptions.messageId,
|
||||
});
|
||||
const executeSqlTool = createExecuteSqlTool({
|
||||
messageId: analystAgentOptions.messageId,
|
||||
dataSourceId: analystAgentOptions.dataSourceId,
|
||||
dataSourceSyntax: analystAgentOptions.dataSourceSyntax,
|
||||
userId: analystAgentOptions.userId,
|
||||
});
|
||||
const respondWithoutAssetCreation = createRespondWithoutAssetCreationTool({
|
||||
messageId: analystAgentOptions.messageId,
|
||||
workflowStartTime: analystAgentOptions.workflowStartTime,
|
||||
});
|
||||
const messageUserClarifyingQuestion = createMessageUserClarifyingQuestionTool({
|
||||
messageId: analystAgentOptions.messageId,
|
||||
workflowStartTime: analystAgentOptions.workflowStartTime,
|
||||
});
|
||||
|
||||
// Visualization tools
|
||||
const createMetrics = createCreateMetricsTool(analystAgentOptions);
|
||||
const modifyMetrics = createModifyMetricsTool(analystAgentOptions);
|
||||
const createDashboards = createCreateDashboardsTool(analystAgentOptions);
|
||||
|
@ -118,6 +166,10 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
const doneTool = createDoneTool(analystAgentOptions);
|
||||
|
||||
const availableTools = [
|
||||
SEQUENTIAL_THINKING_TOOL_NAME,
|
||||
EXECUTE_SQL_TOOL_NAME,
|
||||
RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME,
|
||||
MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME,
|
||||
CREATE_METRICS_TOOL_NAME,
|
||||
MODIFY_METRICS_TOOL_NAME,
|
||||
CREATE_DASHBOARDS_TOOL_NAME,
|
||||
|
@ -160,6 +212,10 @@ export function createAnalystAgent(analystAgentOptions: AnalystAgentOptions) {
|
|||
anthropic_beta: 'fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07',
|
||||
},
|
||||
tools: {
|
||||
[SEQUENTIAL_THINKING_TOOL_NAME]: sequentialThinking,
|
||||
[EXECUTE_SQL_TOOL_NAME]: executeSqlTool,
|
||||
[RESPOND_WITHOUT_ASSET_CREATION_TOOL_NAME]: respondWithoutAssetCreation,
|
||||
[MESSAGE_USER_CLARIFYING_QUESTION_TOOL_NAME]: messageUserClarifyingQuestion,
|
||||
[CREATE_METRICS_TOOL_NAME]: createMetrics,
|
||||
[MODIFY_METRICS_TOOL_NAME]: modifyMetrics,
|
||||
[CREATE_DASHBOARDS_TOOL_NAME]: createDashboards,
|
||||
|
|
|
@ -63,7 +63,7 @@ describe('Analyst Agent Instructions', () => {
|
|||
expect(result).toContain('<intro>');
|
||||
expect(result).toContain('<sql_best_practices>');
|
||||
expect(result).toContain('<visualization_and_charting_guidelines>');
|
||||
expect(result).toContain('You are a Buster');
|
||||
expect(result).toContain('You are an agent');
|
||||
});
|
||||
|
||||
it('should throw an error for empty SQL dialect guidance', () => {
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import analystAgentPrompt from './analyst-agent-prompt.txt';
|
||||
import type { AnalysisMode } from '../../types/analysis-mode.types';
|
||||
import analystAgentInvestigationPrompt from './analyst-agent-investigation-prompt.txt';
|
||||
import analystAgentStandardPrompt from './analyst-agent-standard-prompt.txt';
|
||||
|
||||
/**
|
||||
* Template parameters for the analyst agent prompt
|
||||
|
@ -8,11 +10,24 @@ export interface AnalystTemplateParams {
|
|||
date: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Type-safe mapping of analysis modes to prompt content
|
||||
*/
|
||||
const PROMPTS: Record<AnalysisMode, string> = {
|
||||
standard: analystAgentStandardPrompt,
|
||||
investigation: analystAgentInvestigationPrompt,
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* Loads the analyst agent prompt template and replaces variables
|
||||
*/
|
||||
function loadAndProcessPrompt(params: AnalystTemplateParams): string {
|
||||
return analystAgentPrompt
|
||||
function loadAndProcessPrompt(
|
||||
params: AnalystTemplateParams,
|
||||
analysisMode: AnalysisMode = 'standard'
|
||||
): string {
|
||||
const content = PROMPTS[analysisMode];
|
||||
|
||||
return content
|
||||
.replace(/\{\{sql_dialect_guidance\}\}/g, params.dataSourceSyntax)
|
||||
.replace(/\{\{date\}\}/g, params.date);
|
||||
}
|
||||
|
@ -20,15 +35,21 @@ function loadAndProcessPrompt(params: AnalystTemplateParams): string {
|
|||
/**
|
||||
* Export the template function for use in step files
|
||||
*/
|
||||
export const getAnalystAgentSystemPrompt = (dataSourceSyntax: string): string => {
|
||||
export const getAnalystAgentSystemPrompt = (
|
||||
dataSourceSyntax: string,
|
||||
analysisMode: AnalysisMode = 'standard'
|
||||
): string => {
|
||||
if (!dataSourceSyntax.trim()) {
|
||||
throw new Error('SQL dialect guidance is required');
|
||||
}
|
||||
|
||||
const currentDate = new Date().toISOString();
|
||||
|
||||
return loadAndProcessPrompt({
|
||||
dataSourceSyntax,
|
||||
date: currentDate,
|
||||
});
|
||||
return loadAndProcessPrompt(
|
||||
{
|
||||
dataSourceSyntax,
|
||||
date: currentDate,
|
||||
},
|
||||
analysisMode
|
||||
);
|
||||
};
|
||||
|
|
|
@ -180,11 +180,15 @@ export function createDoneToolDelta(context: DoneToolContext, doneToolState: Don
|
|||
}
|
||||
}
|
||||
|
||||
// Store ALL assets (including reports) for chat update later
|
||||
if (newAssets.length > 0) {
|
||||
// Store assets for chat update later (excluding reasoning which is not an asset type)
|
||||
const assetsForChatUpdate = newAssets.filter(
|
||||
(a): a is typeof a & { assetType: Exclude<typeof a.assetType, 'reasoning'> } =>
|
||||
a.assetType !== 'reasoning'
|
||||
);
|
||||
if (assetsForChatUpdate.length > 0) {
|
||||
doneToolState.addedAssets = [
|
||||
...(doneToolState.addedAssets || []),
|
||||
...newAssets.map((a) => ({
|
||||
...assetsForChatUpdate.map((a) => ({
|
||||
assetId: a.assetId,
|
||||
assetType: a.assetType,
|
||||
versionNumber: a.versionNumber,
|
||||
|
|
|
@ -63,8 +63,10 @@ async function processDone(
|
|||
|
||||
return {
|
||||
output,
|
||||
sequenceNumber: updateResult.sequenceNumber,
|
||||
skipped: updateResult.skipped,
|
||||
...(updateResult.sequenceNumber !== undefined && {
|
||||
sequenceNumber: updateResult.sequenceNumber,
|
||||
}),
|
||||
...(updateResult.skipped !== undefined && { skipped: updateResult.skipped }),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('[done-tool] Error updating message entries:', error);
|
||||
|
|
|
@ -1,12 +1,4 @@
|
|||
import { randomUUID } from 'node:crypto';
|
||||
import {
|
||||
getAssetLatestVersion,
|
||||
isMessageUpdateQueueClosed,
|
||||
updateChat,
|
||||
updateMessage,
|
||||
updateMessageEntries,
|
||||
waitForPendingUpdates,
|
||||
} from '@buster/database/queries';
|
||||
import type { ModelMessage, ToolCallOptions } from 'ai';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
|
||||
|
@ -17,6 +9,7 @@ import type { DoneToolContext, DoneToolState } from './done-tool';
|
|||
import { createDoneToolDelta } from './done-tool-delta';
|
||||
import { createDoneToolStart } from './done-tool-start';
|
||||
|
||||
// Mock database queries
|
||||
vi.mock('@buster/database/queries', () => ({
|
||||
updateChat: vi.fn(),
|
||||
updateMessage: vi.fn(),
|
||||
|
@ -30,6 +23,21 @@ vi.mock('@buster/database/queries', () => ({
|
|||
getAssetLatestVersion: vi.fn().mockResolvedValue(1),
|
||||
}));
|
||||
|
||||
// Import mocked functions after the mock definition
|
||||
import {
|
||||
getAssetLatestVersion,
|
||||
isMessageUpdateQueueClosed,
|
||||
updateChat,
|
||||
updateMessage,
|
||||
updateMessageEntries,
|
||||
waitForPendingUpdates,
|
||||
} from '@buster/database/queries';
|
||||
|
||||
// Type assertion for mocked functions
|
||||
const mockedIsMessageUpdateQueueClosed = vi.mocked(isMessageUpdateQueueClosed);
|
||||
const mockedWaitForPendingUpdates = vi.mocked(waitForPendingUpdates);
|
||||
const mockedGetAssetLatestVersion = vi.mocked(getAssetLatestVersion);
|
||||
|
||||
describe('done-tool-start', () => {
|
||||
const mockContext: DoneToolContext = {
|
||||
chatId: 'chat-123',
|
||||
|
@ -43,11 +51,20 @@ describe('done-tool-start', () => {
|
|||
finalResponse: undefined,
|
||||
};
|
||||
|
||||
// Helper to create mock ToolCallOptions
|
||||
const createMockToolCallOptions = (
|
||||
overrides: Partial<ToolCallOptions> = {}
|
||||
): ToolCallOptions => ({
|
||||
messages: [],
|
||||
toolCallId: 'done-call',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
isMessageUpdateQueueClosed.mockReturnValue(false);
|
||||
waitForPendingUpdates.mockResolvedValue(undefined);
|
||||
getAssetLatestVersion.mockResolvedValue(1);
|
||||
mockedIsMessageUpdateQueueClosed.mockReturnValue(false);
|
||||
mockedWaitForPendingUpdates.mockResolvedValue(undefined);
|
||||
mockedGetAssetLatestVersion.mockResolvedValue(1);
|
||||
});
|
||||
|
||||
describe('mostRecentFile selection', () => {
|
||||
|
@ -142,10 +159,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
// Start phase - initializes state
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - streams in the assets and final response
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -161,8 +180,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
mostRecentFileId: reportId,
|
||||
|
@ -248,10 +267,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the first metric as the asset to return
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -267,8 +288,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Should select the first metric (first in extractedFiles)
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -359,10 +380,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the report and standalone metric as assets to return
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -383,8 +406,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Report should be selected as mostRecentFile
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -459,10 +482,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the first metric
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -478,8 +503,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Should select the first metric
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -526,10 +551,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the first dashboard
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -545,8 +572,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Should select the first dashboard
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -612,10 +639,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the dashboard
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -631,8 +660,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Should select the dashboard (first in extractedFiles)
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -644,10 +673,12 @@ describe('done-tool-start', () => {
|
|||
|
||||
it('should handle empty file lists gracefully', async () => {
|
||||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: [],
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: [],
|
||||
})
|
||||
);
|
||||
|
||||
// Should not call updateChat when no files exist
|
||||
expect(updateChat).not.toHaveBeenCalled();
|
||||
|
@ -697,10 +728,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the report
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -716,8 +749,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Report should still be selected as mostRecentFile
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -788,10 +821,12 @@ describe('done-tool-start', () => {
|
|||
const doneToolStart = createDoneToolStart(mockContext, mockDoneToolState);
|
||||
const doneToolDelta = createDoneToolDelta(mockContext, mockDoneToolState);
|
||||
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Delta phase - stream in the dashboard (metrics are embedded)
|
||||
const deltaInput = JSON.stringify({
|
||||
|
@ -807,8 +842,8 @@ describe('done-tool-start', () => {
|
|||
|
||||
await doneToolDelta({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'done-call',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'done-call' }),
|
||||
});
|
||||
|
||||
// Should select the dashboard since that's what we're returning
|
||||
expect(updateChat).toHaveBeenCalledWith('chat-123', {
|
||||
|
@ -852,10 +887,12 @@ describe('done-tool-start', () => {
|
|||
];
|
||||
|
||||
const doneToolStart = createDoneToolStart(contextWithEmptyChatId, mockDoneToolState);
|
||||
await doneToolStart({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
} as ToolCallOptions);
|
||||
await doneToolStart(
|
||||
createMockToolCallOptions({
|
||||
toolCallId: 'done-call',
|
||||
messages: mockMessages,
|
||||
})
|
||||
);
|
||||
|
||||
// Should not call updateChat when chatId is missing
|
||||
expect(updateChat).not.toHaveBeenCalled();
|
||||
|
|
|
@ -71,6 +71,15 @@ describe('Done Tool Streaming Tests', () => {
|
|||
workflowStartTime: Date.now(),
|
||||
};
|
||||
|
||||
// Helper to create mock ToolCallOptions
|
||||
const createMockToolCallOptions = (
|
||||
overrides: Partial<ToolCallOptions> = {}
|
||||
): ToolCallOptions => ({
|
||||
messages: [],
|
||||
toolCallId: 'test-call-id',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
describe('createDoneToolStart', () => {
|
||||
test('should initialize state with entry_id on start', async () => {
|
||||
const state: DoneToolState = {
|
||||
|
@ -252,8 +261,8 @@ describe('Done Tool Streaming Tests', () => {
|
|||
});
|
||||
await deltaHandler({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'call-1',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'call-1' }),
|
||||
});
|
||||
|
||||
const queries = await import('@buster/database/queries');
|
||||
|
||||
|
@ -375,8 +384,8 @@ describe('Done Tool Streaming Tests', () => {
|
|||
});
|
||||
await deltaHandler({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'call-2',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'call-2' }),
|
||||
});
|
||||
|
||||
const queries = await import('@buster/database/queries');
|
||||
|
||||
|
@ -492,8 +501,8 @@ describe('Done Tool Streaming Tests', () => {
|
|||
});
|
||||
await deltaHandler({
|
||||
inputTextDelta: deltaInput,
|
||||
toolCallId: 'call-3',
|
||||
} as ToolCallOptions);
|
||||
...createMockToolCallOptions({ toolCallId: 'call-3' }),
|
||||
});
|
||||
|
||||
const queries = await import('@buster/database/queries');
|
||||
const updateArgs = ((queries.updateChat as unknown as { mock: { calls: unknown[][] } }).mock
|
||||
|
@ -688,6 +697,7 @@ describe('Done Tool Streaming Tests', () => {
|
|||
const finishHandler = createDoneToolFinish(mockContext, state);
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: 'This is the final response message',
|
||||
};
|
||||
|
||||
|
@ -710,6 +720,7 @@ describe('Done Tool Streaming Tests', () => {
|
|||
const finishHandler = createDoneToolFinish(mockContext, state);
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: 'Response without prior start',
|
||||
};
|
||||
|
||||
|
@ -746,6 +757,7 @@ The following items were processed:
|
|||
`;
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: markdownResponse,
|
||||
};
|
||||
|
||||
|
@ -810,6 +822,7 @@ The following items were processed:
|
|||
expect(state.finalResponse).toBeTypeOf('string');
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: 'Final test',
|
||||
};
|
||||
await finishHandler({ input, toolCallId: 'test-123', messages: [] });
|
||||
|
@ -859,6 +872,7 @@ The following items were processed:
|
|||
);
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: 'This is a streaming response that comes in multiple chunks',
|
||||
};
|
||||
await finishHandler({ input, toolCallId, messages: [] });
|
||||
|
|
|
@ -108,6 +108,7 @@ describe('Done Tool Integration Tests', () => {
|
|||
await startHandler({ toolCallId, messages: [] });
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: 'This is the complete final response',
|
||||
};
|
||||
|
||||
|
@ -168,6 +169,7 @@ All operations completed successfully.`;
|
|||
expect(state.finalResponse).toBe(expectedResponse);
|
||||
|
||||
const input: DoneToolInput = {
|
||||
assetsToReturn: [],
|
||||
finalResponse: expectedResponse,
|
||||
};
|
||||
|
||||
|
@ -207,14 +209,14 @@ All operations completed successfully.`;
|
|||
|
||||
await startHandler1({ toolCallId: toolCallId1, messages: [] });
|
||||
await finishHandler1({
|
||||
input: { finalResponse: 'First response' },
|
||||
input: { assetsToReturn: [], finalResponse: 'First response' },
|
||||
toolCallId: toolCallId1,
|
||||
messages: [],
|
||||
});
|
||||
|
||||
await startHandler2({ toolCallId: toolCallId2, messages: [] });
|
||||
await finishHandler2({
|
||||
input: { finalResponse: 'Second response' },
|
||||
input: { assetsToReturn: [], finalResponse: 'Second response' },
|
||||
toolCallId: toolCallId2,
|
||||
messages: [],
|
||||
});
|
||||
|
|
|
@ -83,7 +83,47 @@ export async function runAnalystWorkflow(
|
|||
// Add all messages from create-todos step (tool call, result, and user message)
|
||||
messages.push(...todos.messages);
|
||||
|
||||
const thinkAndPrepAgentStepResults = await runThinkAndPrepAgentStep({
|
||||
// const thinkAndPrepAgentStepResults = await runThinkAndPrepAgentStep({
|
||||
// options: {
|
||||
// messageId: input.messageId,
|
||||
// chatId: input.chatId,
|
||||
// organizationId: input.organizationId,
|
||||
// dataSourceId: input.dataSourceId,
|
||||
// dataSourceSyntax: input.dataSourceSyntax,
|
||||
// userId: input.userId,
|
||||
// sql_dialect_guidance: input.dataSourceSyntax,
|
||||
// datasets: input.datasets,
|
||||
// workflowStartTime,
|
||||
// analysisMode,
|
||||
// analystInstructions,
|
||||
// organizationDocs,
|
||||
// userPersonalizationMessageContent,
|
||||
// },
|
||||
// streamOptions: {
|
||||
// messages,
|
||||
// },
|
||||
// });
|
||||
|
||||
// console.info('[runAnalystWorkflow] DEBUG: Think-and-prep results', {
|
||||
// workflowId,
|
||||
// messageId: input.messageId,
|
||||
// earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
|
||||
// messageCount: thinkAndPrepAgentStepResults.messages.length,
|
||||
// });
|
||||
|
||||
// messages.push(...thinkAndPrepAgentStepResults.messages);
|
||||
|
||||
// // Check if think-and-prep agent terminated early (clarifying question or direct response)
|
||||
let analystAgentStepResults = { messages: [] as ModelMessage[] };
|
||||
|
||||
// if (!thinkAndPrepAgentStepResults.earlyTermination) {
|
||||
// console.info('[runAnalystWorkflow] Running analyst agent step (early termination = false)', {
|
||||
// workflowId,
|
||||
// messageId: input.messageId,
|
||||
// earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
|
||||
// });
|
||||
|
||||
analystAgentStepResults = await runAnalystAgentStep({
|
||||
options: {
|
||||
messageId: input.messageId,
|
||||
chatId: input.chatId,
|
||||
|
@ -91,7 +131,6 @@ export async function runAnalystWorkflow(
|
|||
dataSourceId: input.dataSourceId,
|
||||
dataSourceSyntax: input.dataSourceSyntax,
|
||||
userId: input.userId,
|
||||
sql_dialect_guidance: input.dataSourceSyntax,
|
||||
datasets: input.datasets,
|
||||
workflowStartTime,
|
||||
analysisMode,
|
||||
|
@ -104,53 +143,14 @@ export async function runAnalystWorkflow(
|
|||
},
|
||||
});
|
||||
|
||||
console.info('[runAnalystWorkflow] DEBUG: Think-and-prep results', {
|
||||
workflowId,
|
||||
messageId: input.messageId,
|
||||
earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
|
||||
messageCount: thinkAndPrepAgentStepResults.messages.length,
|
||||
});
|
||||
|
||||
messages.push(...thinkAndPrepAgentStepResults.messages);
|
||||
|
||||
// Check if think-and-prep agent terminated early (clarifying question or direct response)
|
||||
let analystAgentStepResults = { messages: [] as ModelMessage[] };
|
||||
|
||||
if (!thinkAndPrepAgentStepResults.earlyTermination) {
|
||||
console.info('[runAnalystWorkflow] Running analyst agent step (early termination = false)', {
|
||||
workflowId,
|
||||
messageId: input.messageId,
|
||||
earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
|
||||
});
|
||||
|
||||
analystAgentStepResults = await runAnalystAgentStep({
|
||||
options: {
|
||||
messageId: input.messageId,
|
||||
chatId: input.chatId,
|
||||
organizationId: input.organizationId,
|
||||
dataSourceId: input.dataSourceId,
|
||||
dataSourceSyntax: input.dataSourceSyntax,
|
||||
userId: input.userId,
|
||||
datasets: input.datasets,
|
||||
workflowStartTime,
|
||||
analysisMode,
|
||||
analystInstructions,
|
||||
organizationDocs,
|
||||
userPersonalizationMessageContent,
|
||||
},
|
||||
streamOptions: {
|
||||
messages,
|
||||
},
|
||||
});
|
||||
|
||||
messages.push(...analystAgentStepResults.messages);
|
||||
} else {
|
||||
console.info('[runAnalystWorkflow] DEBUG: SKIPPING analyst agent due to early termination', {
|
||||
workflowId,
|
||||
messageId: input.messageId,
|
||||
earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
|
||||
});
|
||||
}
|
||||
messages.push(...analystAgentStepResults.messages);
|
||||
// } else {
|
||||
// console.info('[runAnalystWorkflow] DEBUG: SKIPPING analyst agent due to early termination', {
|
||||
// workflowId,
|
||||
// messageId: input.messageId,
|
||||
// earlyTermination: thinkAndPrepAgentStepResults.earlyTermination,
|
||||
// });
|
||||
// }
|
||||
|
||||
// Extract all tool calls from messages
|
||||
const allToolCalls = extractToolCallsFromMessages(messages);
|
||||
|
|
|
@ -0,0 +1,338 @@
|
|||
import { randomUUID } from 'node:crypto';
|
||||
import type { ModelMessage } from 'ai';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { getChatConversationHistory } from './chatConversationHistory';
|
||||
|
||||
// Mock the database connection and queries
|
||||
vi.mock('../../connection', () => ({
|
||||
db: {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('getChatConversationHistory - Orphaned Tool Call Cleanup', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should remove orphaned tool calls (tool calls without matching results)', async () => {
|
||||
// Mock database to return messages with orphaned tool calls
|
||||
const mockMessages: ModelMessage[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'test question',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-call',
|
||||
toolCallId: 'orphaned-call-123',
|
||||
toolName: 'sequentialThinking',
|
||||
input: { thought: 'test thought', nextThoughtNeeded: false },
|
||||
},
|
||||
],
|
||||
},
|
||||
// No tool result for orphaned-call-123
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-call',
|
||||
toolCallId: 'valid-call-456',
|
||||
toolName: 'executeSql',
|
||||
input: { statements: ['SELECT 1'] },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-result',
|
||||
toolCallId: 'valid-call-456',
|
||||
toolName: 'executeSql',
|
||||
output: { type: 'json', value: '{"results":[]}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
// Mock the database query functions
|
||||
const { db } = await import('../../connection');
|
||||
vi.mocked(db.orderBy).mockResolvedValue([
|
||||
{
|
||||
id: 'msg-1',
|
||||
rawLlmMessages: mockMessages,
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
isCompleted: true,
|
||||
},
|
||||
]);
|
||||
vi.mocked(db.limit).mockResolvedValue([
|
||||
{
|
||||
chatId: 'chat-123',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await getChatConversationHistory({
|
||||
messageId: randomUUID(),
|
||||
});
|
||||
|
||||
// Should have removed the orphaned tool call but kept the valid one
|
||||
expect(result).toHaveLength(3); // user, assistant (with valid tool call), tool result
|
||||
|
||||
// Find the assistant message
|
||||
const assistantMessages = result.filter((m) => m.role === 'assistant');
|
||||
expect(assistantMessages).toHaveLength(1);
|
||||
|
||||
// The remaining assistant message should only have the valid tool call
|
||||
const assistantContent = assistantMessages[0]?.content;
|
||||
expect(Array.isArray(assistantContent)).toBe(true);
|
||||
if (Array.isArray(assistantContent)) {
|
||||
expect(assistantContent).toHaveLength(1);
|
||||
expect(assistantContent[0]).toMatchObject({
|
||||
type: 'tool-call',
|
||||
toolCallId: 'valid-call-456',
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('should keep assistant messages with valid tool calls', async () => {
|
||||
const mockMessages: ModelMessage[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-call',
|
||||
toolCallId: 'call-123',
|
||||
toolName: 'testTool',
|
||||
input: {},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-result',
|
||||
toolCallId: 'call-123',
|
||||
toolName: 'testTool',
|
||||
output: { type: 'text', value: 'result' },
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { db } = await import('../../connection');
|
||||
vi.mocked(db.orderBy).mockResolvedValue([
|
||||
{
|
||||
id: 'msg-1',
|
||||
rawLlmMessages: mockMessages,
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
isCompleted: true,
|
||||
},
|
||||
]);
|
||||
vi.mocked(db.limit).mockResolvedValue([
|
||||
{
|
||||
chatId: 'chat-123',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await getChatConversationHistory({
|
||||
messageId: randomUUID(),
|
||||
});
|
||||
|
||||
// Should keep both messages (tool call and result)
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0]?.role).toBe('assistant');
|
||||
expect(result[1]?.role).toBe('tool');
|
||||
});
|
||||
|
||||
it('should keep assistant messages that have at least one valid tool call (even if some are orphaned)', async () => {
|
||||
const mockMessages: ModelMessage[] = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Let me analyze this',
|
||||
},
|
||||
{
|
||||
type: 'tool-call',
|
||||
toolCallId: 'orphaned-123',
|
||||
toolName: 'orphanedTool',
|
||||
input: {},
|
||||
},
|
||||
{
|
||||
type: 'tool-call',
|
||||
toolCallId: 'valid-456',
|
||||
toolName: 'validTool',
|
||||
input: {},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'tool',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-result',
|
||||
toolCallId: 'valid-456',
|
||||
toolName: 'validTool',
|
||||
output: { type: 'text', value: 'success' },
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { db } = await import('../../connection');
|
||||
vi.mocked(db.orderBy).mockResolvedValue([
|
||||
{
|
||||
id: 'msg-1',
|
||||
rawLlmMessages: mockMessages,
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
isCompleted: true,
|
||||
},
|
||||
]);
|
||||
vi.mocked(db.limit).mockResolvedValue([
|
||||
{
|
||||
chatId: 'chat-123',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await getChatConversationHistory({
|
||||
messageId: randomUUID(),
|
||||
});
|
||||
|
||||
// Should keep the assistant message because it has at least one valid tool call
|
||||
// Note: We keep the entire message including the orphaned tool call
|
||||
expect(result).toHaveLength(2);
|
||||
|
||||
const assistantMessage = result.find((m) => m.role === 'assistant');
|
||||
expect(assistantMessage).toBeDefined();
|
||||
|
||||
const content = assistantMessage?.content;
|
||||
expect(Array.isArray(content)).toBe(true);
|
||||
if (Array.isArray(content)) {
|
||||
// Should have all content including the orphaned tool call (we don't modify the message)
|
||||
expect(content).toHaveLength(3);
|
||||
}
|
||||
});
|
||||
|
||||
it('should remove assistant messages that only contain orphaned tool calls', async () => {
|
||||
const mockMessages: ModelMessage[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'test',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-call',
|
||||
toolCallId: 'orphaned-only',
|
||||
toolName: 'orphanedTool',
|
||||
input: {},
|
||||
},
|
||||
],
|
||||
},
|
||||
// No tool result for orphaned-only
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'This is a text response',
|
||||
},
|
||||
];
|
||||
|
||||
const { db } = await import('../../connection');
|
||||
vi.mocked(db.orderBy).mockResolvedValue([
|
||||
{
|
||||
id: 'msg-1',
|
||||
rawLlmMessages: mockMessages,
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
isCompleted: true,
|
||||
},
|
||||
]);
|
||||
vi.mocked(db.limit).mockResolvedValue([
|
||||
{
|
||||
chatId: 'chat-123',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await getChatConversationHistory({
|
||||
messageId: randomUUID(),
|
||||
});
|
||||
|
||||
// Should have removed the assistant message with only orphaned tool call
|
||||
expect(result).toHaveLength(2); // user + assistant text response
|
||||
expect(result[0]?.role).toBe('user');
|
||||
expect(result[1]?.role).toBe('assistant');
|
||||
expect(result[1]?.content).toBe('This is a text response');
|
||||
});
|
||||
|
||||
it('should handle empty message arrays', async () => {
|
||||
const { db } = await import('../../connection');
|
||||
vi.mocked(db.orderBy).mockResolvedValue([
|
||||
{
|
||||
id: 'msg-1',
|
||||
rawLlmMessages: [],
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
isCompleted: true,
|
||||
},
|
||||
]);
|
||||
vi.mocked(db.limit).mockResolvedValue([
|
||||
{
|
||||
chatId: 'chat-123',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await getChatConversationHistory({
|
||||
messageId: randomUUID(),
|
||||
});
|
||||
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle messages with no tool calls', async () => {
|
||||
const mockMessages: ModelMessage[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hi there!',
|
||||
},
|
||||
];
|
||||
|
||||
const { db } = await import('../../connection');
|
||||
vi.mocked(db.orderBy).mockResolvedValue([
|
||||
{
|
||||
id: 'msg-1',
|
||||
rawLlmMessages: mockMessages,
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
isCompleted: true,
|
||||
},
|
||||
]);
|
||||
vi.mocked(db.limit).mockResolvedValue([
|
||||
{
|
||||
chatId: 'chat-123',
|
||||
createdAt: '2025-01-01T00:00:00Z',
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await getChatConversationHistory({
|
||||
messageId: randomUUID(),
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result).toEqual(mockMessages);
|
||||
});
|
||||
});
|
|
@ -324,6 +324,80 @@ export const ChatConversationHistoryOutputSchema = z.array(z.custom<ModelMessage
|
|||
export type ChatConversationHistoryInput = z.infer<typeof ChatConversationHistoryInputSchema>;
|
||||
export type ChatConversationHistoryOutput = z.infer<typeof ChatConversationHistoryOutputSchema>;
|
||||
|
||||
/**
|
||||
* Removes orphaned assistant messages with tool calls that have no matching tool results
|
||||
* An orphaned message is an assistant message where ALL its tool calls lack corresponding tool results
|
||||
*/
|
||||
function removeOrphanedToolCalls(messages: ModelMessage[]): ModelMessage[] {
|
||||
// Build a Set of all tool call IDs that have results (from tool role messages)
|
||||
const toolCallIdsWithResults = new Set<string>();
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.role === 'tool' && Array.isArray(message.content)) {
|
||||
for (const part of message.content) {
|
||||
if (
|
||||
typeof part === 'object' &&
|
||||
part !== null &&
|
||||
'type' in part &&
|
||||
part.type === 'tool-result' &&
|
||||
'toolCallId' in part &&
|
||||
typeof part.toolCallId === 'string'
|
||||
) {
|
||||
toolCallIdsWithResults.add(part.toolCallId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out assistant messages where ALL tool calls are orphaned
|
||||
const filteredMessages: ModelMessage[] = [];
|
||||
|
||||
for (const message of messages) {
|
||||
// Only check assistant messages with array content
|
||||
if (message.role === 'assistant' && Array.isArray(message.content)) {
|
||||
// Extract tool call IDs from this message
|
||||
const toolCallIds: string[] = [];
|
||||
|
||||
for (const part of message.content) {
|
||||
if (
|
||||
typeof part === 'object' &&
|
||||
part !== null &&
|
||||
'type' in part &&
|
||||
part.type === 'tool-call' &&
|
||||
'toolCallId' in part &&
|
||||
typeof part.toolCallId === 'string'
|
||||
) {
|
||||
toolCallIds.push(part.toolCallId);
|
||||
}
|
||||
}
|
||||
|
||||
// If this message has no tool calls, keep it
|
||||
if (toolCallIds.length === 0) {
|
||||
filteredMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if ANY of the tool calls have results
|
||||
const hasAnyResults = toolCallIds.some((id) => toolCallIdsWithResults.has(id));
|
||||
|
||||
if (hasAnyResults) {
|
||||
// At least one tool call has a result, keep the message
|
||||
filteredMessages.push(message);
|
||||
} else {
|
||||
// ALL tool calls are orphaned, skip this message entirely
|
||||
console.warn('[chatConversationHistory] Removing orphaned assistant message:', {
|
||||
orphanedToolCallIds: toolCallIds,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Keep all non-assistant messages (including tool role messages)
|
||||
filteredMessages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
return filteredMessages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get conversation history for a chat up to and including a specific message
|
||||
* Finds the chat from the given messageId, then merges and deduplicates all rawLlmMessages
|
||||
|
@ -389,9 +463,12 @@ export async function getChatConversationHistory(
|
|||
// Since we're merging from multiple messages, we should preserve the order they appear
|
||||
const deduplicatedMessages = Array.from(uniqueMessagesMap.values());
|
||||
|
||||
// Remove orphaned tool calls (tool calls without matching tool results)
|
||||
const cleanedMessages = removeOrphanedToolCalls(deduplicatedMessages);
|
||||
|
||||
// Validate output
|
||||
try {
|
||||
return ChatConversationHistoryOutputSchema.parse(deduplicatedMessages);
|
||||
return ChatConversationHistoryOutputSchema.parse(cleanedMessages);
|
||||
} catch (validationError) {
|
||||
throw new Error(
|
||||
`Output validation failed: ${validationError instanceof Error ? validationError.message : 'Invalid output format'}`
|
||||
|
|
Loading…
Reference in New Issue