From 1cb95ce8956c36f2fe6024dabfa1a24916ac2fbf Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 30 Sep 2025 09:44:00 -0600 Subject: [PATCH] fix: streaming --- .../done-tool/done-tool-delta.ts | 33 ++- .../done-tool/done-tool-execute.ts | 75 +++++-- .../done-tool/done-tool-finish.ts | 6 +- .../done-tool/done-tool-start.test.ts | 20 +- .../done-tool/done-tool-start.ts | 7 +- .../done-tool/done-tool-streaming.test.ts | 63 +++++- .../done-tool/done-tool.ts | 10 + .../strategies/re-ask-strategy.test.ts | 3 +- .../structured-output-strategy.test.ts | 18 +- .../messages/update-message-entries.ts | 202 ++++++++++++++--- .../queries/reports/batch-update-report.ts | 207 +++++++++++++++--- 11 files changed, 543 insertions(+), 101 deletions(-) diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool-delta.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool-delta.ts index 1213f4864..fe84b4676 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool-delta.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool-delta.ts @@ -1,10 +1,5 @@ -import { - type UpdateMessageEntriesParams, - getAssetLatestVersion, - updateChat, - updateMessage, - updateMessageEntries, -} from '@buster/database/queries'; +import type { UpdateMessageEntriesParams } from '@buster/database/queries'; +import * as databaseQueries from '@buster/database/queries'; import { type ResponseMessageFileType, ResponseMessageFileTypeSchema, @@ -26,12 +21,30 @@ const FINAL_RESPONSE_KEY = 'finalResponse' as const satisfies keyof DoneToolInpu const ASSETS_TO_RETURN_KEY = 'assetsToReturn' as const satisfies keyof DoneToolInput; export function createDoneToolDelta(context: DoneToolContext, doneToolState: DoneToolState) { + const { getAssetLatestVersion, updateChat, updateMessage, updateMessageEntries } = + databaseQueries; + + const isMessageUpdateQueueClosed = databaseQueries.isMessageUpdateQueueClosed ?? (() => false); + return async function doneToolDelta( options: { inputTextDelta: string } & ToolCallOptions ): Promise { if (doneToolState.isFinalizing) { return; } + + if (isMessageUpdateQueueClosed(context.messageId)) { + return; + } + + const recordSequence = (sequenceNumber: number, skipped?: boolean) => { + if (skipped || sequenceNumber < 0) { + return; + } + + const current = doneToolState.latestSequenceNumber ?? -1; + doneToolState.latestSequenceNumber = Math.max(current, sequenceNumber); + }; // Accumulate the delta to the args doneToolState.args = (doneToolState.args || '') + options.inputTextDelta; @@ -155,7 +168,8 @@ export function createDoneToolDelta(context: DoneToolContext, doneToolState: Don }; try { - await updateMessageEntries(entriesForAssets); + const result = await updateMessageEntries(entriesForAssets); + recordSequence(result.sequenceNumber, result.skipped); // Update state to prevent duplicates on next deltas doneToolState.addedAssetIds = [ ...(doneToolState.addedAssetIds || []), @@ -247,7 +261,8 @@ export function createDoneToolDelta(context: DoneToolContext, doneToolState: Don try { if (entries.responseMessages || entries.rawLlmMessages) { - await updateMessageEntries(entries); + const result = await updateMessageEntries(entries); + recordSequence(result.sequenceNumber, result.skipped); } } catch (error) { console.error('[done-tool] Failed to update done tool raw LLM message:', error); diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool-execute.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool-execute.ts index 6a86fc382..9646107f6 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool-execute.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool-execute.ts @@ -1,8 +1,4 @@ -import { - updateMessage, - updateMessageEntries, - waitForPendingUpdates, -} from '@buster/database/queries'; +import * as databaseQueries from '@buster/database/queries'; import { wrapTraced } from 'braintrust'; import { cleanupState } from '../../shared/cleanup-state'; import { createRawToolResultEntry } from '../../shared/create-raw-llm-tool-result-entry'; @@ -24,8 +20,9 @@ async function processDone( toolCallId: string, messageId: string, _context: DoneToolContext, - input: DoneToolInput -): Promise { + input: DoneToolInput, + updateOptions?: Parameters[1] +): Promise<{ output: DoneToolOutput; sequenceNumber?: number; skipped?: boolean }> { const output: DoneToolOutput = { success: true, }; @@ -49,25 +46,40 @@ async function processDone( ? [rawLlmMessage, rawToolResultEntry] : [rawToolResultEntry]; - await updateMessageEntries({ - messageId, - rawLlmMessages, - // Include the response message with the complete finalResponse - responseMessages: doneToolResponseEntry ? [doneToolResponseEntry] : undefined, - }); + const updateResult = await updateMessageEntries( + { + messageId, + rawLlmMessages, + // Include the response message with the complete finalResponse + responseMessages: doneToolResponseEntry ? [doneToolResponseEntry] : undefined, + }, + updateOptions + ); // Mark the message as completed await updateMessage(messageId, { isCompleted: true, }); + + return { + output, + sequenceNumber: updateResult.sequenceNumber, + skipped: updateResult.skipped, + }; } catch (error) { console.error('[done-tool] Error updating message entries:', error); + return { + output, + }; } - - return output; } // Factory function that creates the execute function with proper context typing +const updateMessage = databaseQueries.updateMessage; +const updateMessageEntries = databaseQueries.updateMessageEntries; +const waitForPendingUpdates = + databaseQueries.waitForPendingUpdates ?? (async (_messageId: string) => {}); + export function createDoneToolExecute(context: DoneToolContext, state: DoneToolState) { return wrapTraced( async (input: DoneToolInput): Promise => { @@ -78,15 +90,40 @@ export function createDoneToolExecute(context: DoneToolContext, state: DoneToolS state.isFinalizing = true; // CRITICAL: Wait for ALL pending updates from delta/finish to complete FIRST // This ensures execute's update is always the last one in the queue - await waitForPendingUpdates(context.messageId); + if (typeof state.latestSequenceNumber === 'number') { + await waitForPendingUpdates(context.messageId, { + upToSequence: state.latestSequenceNumber, + }); + } else { + await waitForPendingUpdates(context.messageId); + } // Now do the final authoritative update with the complete input - const result = await processDone(state, state.toolCallId, context.messageId, context, input); + const { output, sequenceNumber, skipped } = await processDone( + state, + state.toolCallId, + context.messageId, + context, + input, + { isFinal: true } + ); - await waitForPendingUpdates(context.messageId); + if (!skipped && typeof sequenceNumber === 'number') { + const current = state.latestSequenceNumber ?? -1; + state.latestSequenceNumber = Math.max(current, sequenceNumber); + state.finalSequenceNumber = sequenceNumber; + } + + if (typeof state.finalSequenceNumber === 'number') { + await waitForPendingUpdates(context.messageId, { + upToSequence: state.finalSequenceNumber, + }); + } else { + await waitForPendingUpdates(context.messageId); + } cleanupState(state); - return result; + return output; }, { name: 'Done Tool' } ); diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool-finish.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool-finish.ts index 0d68d6e9f..d11b6d04b 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool-finish.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool-finish.ts @@ -30,7 +30,11 @@ export function createDoneToolFinish(context: DoneToolContext, doneToolState: Do try { if (entries.responseMessages || entries.rawLlmMessages) { - await updateMessageEntries(entries); + const result = await updateMessageEntries(entries); + if (!result.skipped && result.sequenceNumber >= 0) { + const current = doneToolState.latestSequenceNumber ?? -1; + doneToolState.latestSequenceNumber = Math.max(current, result.sequenceNumber); + } } } catch (error) { console.error('[done-tool] Failed to update done tool raw LLM message:', error); diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.test.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.test.ts index f02169503..a95a8f846 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.test.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.test.ts @@ -1,5 +1,12 @@ import { randomUUID } from 'node:crypto'; -import { updateChat, updateMessage, updateMessageEntries } from '@buster/database/queries'; +import { + getAssetLatestVersion, + isMessageUpdateQueueClosed, + updateChat, + updateMessage, + updateMessageEntries, + waitForPendingUpdates, +} from '@buster/database/queries'; import type { ModelMessage, ToolCallOptions } from 'ai'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import { CREATE_DASHBOARDS_TOOL_NAME } from '../../visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool'; @@ -13,7 +20,13 @@ import { createDoneToolStart } from './done-tool-start'; vi.mock('@buster/database/queries', () => ({ updateChat: vi.fn(), updateMessage: vi.fn(), - updateMessageEntries: vi.fn(), + updateMessageEntries: vi.fn().mockResolvedValue({ + success: true, + sequenceNumber: 0, + skipped: false as const, + }), + waitForPendingUpdates: vi.fn().mockResolvedValue(undefined), + isMessageUpdateQueueClosed: vi.fn().mockReturnValue(false), getAssetLatestVersion: vi.fn().mockResolvedValue(1), })); @@ -32,6 +45,9 @@ describe('done-tool-start', () => { beforeEach(() => { vi.clearAllMocks(); + isMessageUpdateQueueClosed.mockReturnValue(false); + waitForPendingUpdates.mockResolvedValue(undefined); + getAssetLatestVersion.mockResolvedValue(1); }); describe('mostRecentFile selection', () => { diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.ts index 7f78c978c..e2121e3a3 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool-start.ts @@ -25,6 +25,8 @@ export function createDoneToolStart(context: DoneToolContext, doneToolState: Don doneToolState.addedAssetIds = []; doneToolState.addedAssets = []; doneToolState.isFinalizing = false; + doneToolState.latestSequenceNumber = undefined; + doneToolState.finalSequenceNumber = undefined; // Selection logic moved to delta; skip extracting files here if (options.messages) { @@ -67,7 +69,10 @@ export function createDoneToolStart(context: DoneToolContext, doneToolState: Don try { if (entries.responseMessages || entries.rawLlmMessages) { - await updateMessageEntries(entries); + const result = await updateMessageEntries(entries); + if (!result.skipped && result.sequenceNumber >= 0) { + doneToolState.latestSequenceNumber = result.sequenceNumber; + } } } catch (error) { console.error('[done-tool] Failed to update done tool raw LLM message:', error); diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool-streaming.test.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool-streaming.test.ts index 0f4adf977..1ba7e6b42 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool-streaming.test.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool-streaming.test.ts @@ -1,5 +1,5 @@ import type { ModelMessage, ToolCallOptions } from 'ai'; -import { describe, expect, test, vi } from 'vitest'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; import { CREATE_DASHBOARDS_TOOL_NAME } from '../../visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool'; import { CREATE_METRICS_TOOL_NAME } from '../../visualization-tools/metrics/create-metrics-tool/create-metrics-tool'; import { CREATE_REPORTS_TOOL_NAME } from '../../visualization-tools/reports/create-reports-tool/create-reports-tool'; @@ -8,12 +8,61 @@ import { createDoneToolDelta } from './done-tool-delta'; import { createDoneToolFinish } from './done-tool-finish'; import { createDoneToolStart } from './done-tool-start'; -vi.mock('@buster/database/queries', () => ({ - updateMessageEntries: vi.fn().mockResolvedValue({ success: true }), - updateMessage: vi.fn().mockResolvedValue({ success: true }), - updateChat: vi.fn().mockResolvedValue({ success: true }), - getAssetLatestVersion: vi.fn().mockResolvedValue(1), -})); +const queriesMock = vi.hoisted(() => { + let sequence = 0; + + const updateMessageEntries = vi.fn(async () => ({ + success: true, + sequenceNumber: sequence++, + skipped: false as const, + })); + const waitForPendingUpdates = vi.fn().mockResolvedValue(undefined); + const isMessageUpdateQueueClosed = vi.fn().mockReturnValue(false); + const updateMessage = vi.fn().mockResolvedValue({ success: true }); + const updateChat = vi.fn().mockResolvedValue({ success: true }); + const getAssetLatestVersion = vi.fn().mockResolvedValue(1); + + return { + updateMessageEntries, + waitForPendingUpdates, + isMessageUpdateQueueClosed, + updateMessage, + updateChat, + getAssetLatestVersion, + reset() { + sequence = 0; + updateMessageEntries.mockClear(); + waitForPendingUpdates.mockClear(); + isMessageUpdateQueueClosed.mockClear(); + updateMessage.mockClear(); + updateChat.mockClear(); + getAssetLatestVersion.mockClear(); + waitForPendingUpdates.mockResolvedValue(undefined); + isMessageUpdateQueueClosed.mockReturnValue(false); + getAssetLatestVersion.mockResolvedValue(1); + }, + }; +}); + +vi.mock('@buster/database/queries', async () => { + const actual = await vi.importActual( + '@buster/database/queries' + ); + + return { + ...actual, + updateMessageEntries: queriesMock.updateMessageEntries, + waitForPendingUpdates: queriesMock.waitForPendingUpdates, + isMessageUpdateQueueClosed: queriesMock.isMessageUpdateQueueClosed, + updateMessage: queriesMock.updateMessage, + updateChat: queriesMock.updateChat, + getAssetLatestVersion: queriesMock.getAssetLatestVersion, + }; +}); + +beforeEach(() => { + queriesMock.reset(); +}); describe('Done Tool Streaming Tests', () => { const mockContext: DoneToolContext = { diff --git a/packages/ai/src/tools/communication-tools/done-tool/done-tool.ts b/packages/ai/src/tools/communication-tools/done-tool/done-tool.ts index 44532b15d..59045db74 100644 --- a/packages/ai/src/tools/communication-tools/done-tool/done-tool.ts +++ b/packages/ai/src/tools/communication-tools/done-tool/done-tool.ts @@ -75,6 +75,14 @@ const DoneToolStateSchema = z.object({ .boolean() .optional() .describe('Indicates the execute phase has started so further deltas should be ignored'), + latestSequenceNumber: z + .number() + .optional() + .describe('Highest message update sequence number observed during streaming'), + finalSequenceNumber: z + .number() + .optional() + .describe('Sequence number for the final execute message update'), }); export type DoneToolInput = z.infer; @@ -90,6 +98,8 @@ export function createDoneTool(context: DoneToolContext) { addedAssetIds: [], addedAssets: [], isFinalizing: false, + latestSequenceNumber: undefined, + finalSequenceNumber: undefined, }; const execute = createDoneToolExecute(context, state); diff --git a/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts b/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts index 0a898ffb4..359e653f0 100644 --- a/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts +++ b/packages/ai/src/utils/tool-call-repair/strategies/re-ask-strategy.test.ts @@ -41,6 +41,7 @@ vi.mock('braintrust', () => ({ vi.mock('../../../llm', () => ({ Sonnet4: 'mock-model', + GPT5Mini: 'mock-model', })); describe('re-ask-strategy', () => { @@ -130,7 +131,7 @@ describe('re-ask-strategy', () => { expect.objectContaining({ role: 'tool' }), ]), tools: context.tools, - maxOutputTokens: 1000, + maxOutputTokens: 10000, temperature: 0, }) ); diff --git a/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts b/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts index fc8c902f2..6f80e05e1 100644 --- a/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts +++ b/packages/ai/src/utils/tool-call-repair/strategies/structured-output-strategy.test.ts @@ -18,6 +18,7 @@ vi.mock('braintrust', () => ({ vi.mock('../../../llm', () => ({ Sonnet4: 'mock-model', + GPT5Mini: 'mock-model', })); describe('structured-output-strategy', () => { @@ -86,13 +87,16 @@ describe('structured-output-strategy', () => { }); const tool = context.tools.testTool as any; - expect(mockGenerateObject).toHaveBeenCalledWith({ - model: 'mock-model', - schema: tool?.inputSchema, - prompt: expect.stringContaining('Fix these tool arguments'), - mode: 'json', - providerOptions: expect.any(Object), - }); + expect(mockGenerateObject).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'mock-model', + schema: tool?.inputSchema, + prompt: expect.stringContaining('Fix these tool arguments'), + mode: 'json', + maxOutputTokens: 10000, + providerOptions: expect.objectContaining({}), + }) + ); }); it('should return null if tool not found', async () => { diff --git a/packages/database/src/queries/messages/update-message-entries.ts b/packages/database/src/queries/messages/update-message-entries.ts index b7f705360..677461b65 100644 --- a/packages/database/src/queries/messages/update-message-entries.ts +++ b/packages/database/src/queries/messages/update-message-entries.ts @@ -21,18 +21,120 @@ const UpdateMessageEntriesSchema = z.object({ export type UpdateMessageEntriesParams = z.infer; -// Simple in-memory queue for each messageId -const updateQueues = new Map>(); +type Deferred = { + promise: Promise; + resolve: (value: T | PromiseLike) => void; + reject: (reason?: unknown) => void; +}; + +function createDeferred(): Deferred { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + promise.catch(() => undefined); + + return { promise, resolve, reject }; +} + +type MessageUpdateQueueState = { + tailPromise: Promise; + nextSequence: number; + pending: Map>; + lastCompletedSequence: number; + finalSequence?: number; + closed: boolean; +}; + +const updateQueues = new Map(); + +function getOrCreateQueueState(messageId: string): MessageUpdateQueueState { + const existing = updateQueues.get(messageId); + if (existing) { + return existing; + } + + const initialState: MessageUpdateQueueState = { + tailPromise: Promise.resolve(), + nextSequence: 0, + pending: new Map(), + lastCompletedSequence: -1, + closed: false, + }; + + updateQueues.set(messageId, initialState); + return initialState; +} + +function cleanupQueueIfIdle(messageId: string, state: MessageUpdateQueueState): void { + if ( + state.closed && + state.finalSequence !== undefined && + state.lastCompletedSequence >= state.finalSequence && + state.pending.size === 0 + ) { + updateQueues.delete(messageId); + } +} + +export function isMessageUpdateQueueClosed(messageId: string): boolean { + const queue = updateQueues.get(messageId); + return queue?.closed ?? false; +} + +type WaitForPendingUpdateOptions = { + upToSequence?: number; +}; /** - * Wait for all pending updates for a given messageId to complete. - * This ensures all queued updates are flushed to the database before proceeding. + * Wait for pending updates for a given messageId to complete. + * Optionally provide a sequence number to wait through. */ -export async function waitForPendingUpdates(messageId: string): Promise { - const pendingQueue = updateQueues.get(messageId); - if (pendingQueue) { - await pendingQueue; +export async function waitForPendingUpdates( + messageId: string, + options?: WaitForPendingUpdateOptions +): Promise { + const queue = updateQueues.get(messageId); + if (!queue) { + return; } + + const targetSequence = options?.upToSequence ?? queue.finalSequence; + + if (targetSequence === undefined) { + await queue.tailPromise; + cleanupQueueIfIdle(messageId, queue); + return; + } + + const maxKnownSequence = queue.nextSequence - 1; + const effectiveTarget = Math.min(targetSequence, maxKnownSequence); + + if (effectiveTarget <= queue.lastCompletedSequence) { + cleanupQueueIfIdle(messageId, queue); + return; + } + + const waits: Promise[] = []; + + for (let sequence = queue.lastCompletedSequence + 1; sequence <= effectiveTarget; sequence += 1) { + const deferred = queue.pending.get(sequence); + if (deferred) { + waits.push(deferred.promise.catch(() => undefined)); + } + } + + if (waits.length > 0) { + await Promise.all(waits); + } else { + await queue.tailPromise; + } + + cleanupQueueIfIdle(messageId, queue); } /** @@ -116,29 +218,79 @@ async function performUpdate({ * - reasoningMessages: upsert by 'id' field, maintaining order * - rawLlmMessages: upsert by combination of 'role' and 'toolCallId', maintaining order */ +type UpdateMessageEntriesOptions = { + isFinal?: boolean; +}; + +type UpdateMessageEntriesResult = { + success: boolean; + sequenceNumber: number; + skipped?: boolean; +}; + export async function updateMessageEntries( - params: UpdateMessageEntriesParams -): Promise<{ success: boolean }> { + params: UpdateMessageEntriesParams, + options?: UpdateMessageEntriesOptions +): Promise { const { messageId } = params; - // Get the current promise for this messageId, or use a resolved promise as the starting point - const currentQueue = updateQueues.get(messageId) ?? Promise.resolve({ success: true }); + const queue = getOrCreateQueueState(messageId); - // Chain the new update to run after the current queue completes - const newQueue = currentQueue - .then(() => performUpdate(params)) - .catch(() => performUpdate(params)); // Still try to run even if previous failed + if (queue.closed) { + const lastKnownSequence = queue.finalSequence ?? queue.nextSequence - 1; + return { + success: false, + sequenceNumber: lastKnownSequence >= 0 ? lastKnownSequence : -1, + skipped: true, + }; + } - // Update the queue for this messageId - updateQueues.set(messageId, newQueue); + const isFinal = options?.isFinal ?? false; - // Clean up the queue entry once this update completes - newQueue.finally(() => { - // Only remove if this is still the current queue - if (updateQueues.get(messageId) === newQueue) { - updateQueues.delete(messageId); + if (isFinal) { + queue.closed = true; + } + + const sequenceNumber = queue.nextSequence; + queue.nextSequence += 1; + + const deferred = createDeferred(); + queue.pending.set(sequenceNumber, deferred); + + const runUpdate = () => performUpdate(params); + + const runPromise = queue.tailPromise.then(runUpdate, runUpdate); + + queue.tailPromise = runPromise.then( + () => undefined, + () => undefined + ); + + const finalize = (success: boolean) => { + queue.pending.delete(sequenceNumber); + queue.lastCompletedSequence = Math.max(queue.lastCompletedSequence, sequenceNumber); + if (isFinal) { + queue.finalSequence = sequenceNumber; } - }); + cleanupQueueIfIdle(messageId, queue); + return success; + }; - return newQueue; + const resultPromise = runPromise + .then((result) => { + deferred.resolve(); + finalize(true); + return { + ...result, + sequenceNumber, + skipped: false as const, + }; + }) + .catch((error) => { + deferred.reject(error); + finalize(false); + throw error; + }); + + return resultPromise; } diff --git a/packages/database/src/queries/reports/batch-update-report.ts b/packages/database/src/queries/reports/batch-update-report.ts index 25041ba6f..fa5edbecc 100644 --- a/packages/database/src/queries/reports/batch-update-report.ts +++ b/packages/database/src/queries/reports/batch-update-report.ts @@ -31,18 +31,120 @@ type VersionHistoryEntry = { type VersionHistory = Record; -// Simple in-memory queue for each reportId -const updateQueues = new Map>(); +type Deferred = { + promise: Promise; + resolve: (value: T | PromiseLike) => void; + reject: (reason?: unknown) => void; +}; + +function createDeferred(): Deferred { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + promise.catch(() => undefined); + + return { promise, resolve, reject }; +} + +type ReportUpdateQueueState = { + tailPromise: Promise; + nextSequence: number; + pending: Map>; + lastCompletedSequence: number; + finalSequence?: number; + closed: boolean; +}; + +const updateQueues = new Map(); + +function getOrCreateQueueState(reportId: string): ReportUpdateQueueState { + const existing = updateQueues.get(reportId); + if (existing) { + return existing; + } + + const initialState: ReportUpdateQueueState = { + tailPromise: Promise.resolve(), + nextSequence: 0, + pending: new Map(), + lastCompletedSequence: -1, + closed: false, + }; + + updateQueues.set(reportId, initialState); + return initialState; +} + +function cleanupQueueIfIdle(reportId: string, state: ReportUpdateQueueState): void { + if ( + state.closed && + state.finalSequence !== undefined && + state.lastCompletedSequence >= state.finalSequence && + state.pending.size === 0 + ) { + updateQueues.delete(reportId); + } +} + +export function isReportUpdateQueueClosed(reportId: string): boolean { + const queue = updateQueues.get(reportId); + return queue?.closed ?? false; +} + +type WaitForPendingReportUpdateOptions = { + upToSequence?: number; +}; /** * Wait for all pending updates for a given reportId to complete. * This ensures all queued updates are flushed to the database before proceeding. */ -export async function waitForPendingReportUpdates(reportId: string): Promise { - const pendingQueue = updateQueues.get(reportId); - if (pendingQueue) { - await pendingQueue; +export async function waitForPendingReportUpdates( + reportId: string, + options?: WaitForPendingReportUpdateOptions +): Promise { + const queue = updateQueues.get(reportId); + if (!queue) { + return; } + + const targetSequence = options?.upToSequence ?? queue.finalSequence; + + if (targetSequence === undefined) { + await queue.tailPromise; + cleanupQueueIfIdle(reportId, queue); + return; + } + + const maxKnownSequence = queue.nextSequence - 1; + const effectiveTarget = Math.min(targetSequence, maxKnownSequence); + + if (effectiveTarget <= queue.lastCompletedSequence) { + cleanupQueueIfIdle(reportId, queue); + return; + } + + const waits: Promise[] = []; + + for (let sequence = queue.lastCompletedSequence + 1; sequence <= effectiveTarget; sequence += 1) { + const deferred = queue.pending.get(sequence); + if (deferred) { + waits.push(deferred.promise.catch(() => undefined)); + } + } + + if (waits.length > 0) { + await Promise.all(waits); + } else { + await queue.tailPromise; + } + + cleanupQueueIfIdle(reportId, queue); } /** @@ -93,27 +195,74 @@ async function performUpdate(params: BatchUpdateReportInput): Promise { * Updates a report's content, name, and version history in a single operation. * Updates are queued per reportId to ensure they execute in order. */ -export const updateReportWithVersion = async (params: BatchUpdateReportInput): Promise => { - const { reportId } = params; - - // Get the current promise for this reportId, or use a resolved promise as the starting point - const currentQueue = updateQueues.get(reportId) ?? Promise.resolve(); - - // Chain the new update to run after the current queue completes - const newQueue = currentQueue - .then(() => performUpdate(params)) - .catch(() => performUpdate(params)); // Still try to run even if previous failed - - // Update the queue for this reportId - updateQueues.set(reportId, newQueue); - - // Clean up the queue entry once this update completes - newQueue.finally(() => { - // Only remove if this is still the current queue - if (updateQueues.get(reportId) === newQueue) { - updateQueues.delete(reportId); - } - }); - - return newQueue; +type UpdateReportWithVersionOptions = { + isFinal?: boolean; +}; + +type UpdateReportWithVersionResult = { + sequenceNumber: number; + skipped?: boolean; +}; + +export const updateReportWithVersion = async ( + params: BatchUpdateReportInput, + options?: UpdateReportWithVersionOptions +): Promise => { + const { reportId } = params; + const queue = getOrCreateQueueState(reportId); + + if (queue.closed) { + const lastKnownSequence = queue.finalSequence ?? queue.nextSequence - 1; + return { + sequenceNumber: lastKnownSequence >= 0 ? lastKnownSequence : -1, + skipped: true, + }; + } + + const isFinal = options?.isFinal ?? false; + + if (isFinal) { + queue.closed = true; + } + + const sequenceNumber = queue.nextSequence; + queue.nextSequence += 1; + + const deferred = createDeferred(); + queue.pending.set(sequenceNumber, deferred); + + const runUpdate = () => performUpdate(params); + + const runPromise = queue.tailPromise.then(runUpdate, runUpdate); + + queue.tailPromise = runPromise.then( + () => undefined, + () => undefined + ); + + const finalize = () => { + queue.pending.delete(sequenceNumber); + queue.lastCompletedSequence = Math.max(queue.lastCompletedSequence, sequenceNumber); + if (isFinal) { + queue.finalSequence = sequenceNumber; + } + cleanupQueueIfIdle(reportId, queue); + }; + + const resultPromise = runPromise + .then(() => { + deferred.resolve(); + finalize(); + return { + sequenceNumber, + skipped: false as const, + }; + }) + .catch((error) => { + deferred.reject(error); + finalize(); + throw error; + }); + + return resultPromise; };