mirror of https://github.com/buster-so/buster.git
fix: streaming
This commit is contained in:
parent
f3184d70fd
commit
1cb95ce895
|
@ -1,10 +1,5 @@
|
|||
import {
|
||||
type UpdateMessageEntriesParams,
|
||||
getAssetLatestVersion,
|
||||
updateChat,
|
||||
updateMessage,
|
||||
updateMessageEntries,
|
||||
} from '@buster/database/queries';
|
||||
import type { UpdateMessageEntriesParams } from '@buster/database/queries';
|
||||
import * as databaseQueries from '@buster/database/queries';
|
||||
import {
|
||||
type ResponseMessageFileType,
|
||||
ResponseMessageFileTypeSchema,
|
||||
|
@ -26,12 +21,30 @@ const FINAL_RESPONSE_KEY = 'finalResponse' as const satisfies keyof DoneToolInpu
|
|||
const ASSETS_TO_RETURN_KEY = 'assetsToReturn' as const satisfies keyof DoneToolInput;
|
||||
|
||||
export function createDoneToolDelta(context: DoneToolContext, doneToolState: DoneToolState) {
|
||||
const { getAssetLatestVersion, updateChat, updateMessage, updateMessageEntries } =
|
||||
databaseQueries;
|
||||
|
||||
const isMessageUpdateQueueClosed = databaseQueries.isMessageUpdateQueueClosed ?? (() => false);
|
||||
|
||||
return async function doneToolDelta(
|
||||
options: { inputTextDelta: string } & ToolCallOptions
|
||||
): Promise<void> {
|
||||
if (doneToolState.isFinalizing) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isMessageUpdateQueueClosed(context.messageId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const recordSequence = (sequenceNumber: number, skipped?: boolean) => {
|
||||
if (skipped || sequenceNumber < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const current = doneToolState.latestSequenceNumber ?? -1;
|
||||
doneToolState.latestSequenceNumber = Math.max(current, sequenceNumber);
|
||||
};
|
||||
// Accumulate the delta to the args
|
||||
doneToolState.args = (doneToolState.args || '') + options.inputTextDelta;
|
||||
|
||||
|
@ -155,7 +168,8 @@ export function createDoneToolDelta(context: DoneToolContext, doneToolState: Don
|
|||
};
|
||||
|
||||
try {
|
||||
await updateMessageEntries(entriesForAssets);
|
||||
const result = await updateMessageEntries(entriesForAssets);
|
||||
recordSequence(result.sequenceNumber, result.skipped);
|
||||
// Update state to prevent duplicates on next deltas
|
||||
doneToolState.addedAssetIds = [
|
||||
...(doneToolState.addedAssetIds || []),
|
||||
|
@ -247,7 +261,8 @@ export function createDoneToolDelta(context: DoneToolContext, doneToolState: Don
|
|||
|
||||
try {
|
||||
if (entries.responseMessages || entries.rawLlmMessages) {
|
||||
await updateMessageEntries(entries);
|
||||
const result = await updateMessageEntries(entries);
|
||||
recordSequence(result.sequenceNumber, result.skipped);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[done-tool] Failed to update done tool raw LLM message:', error);
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
import {
|
||||
updateMessage,
|
||||
updateMessageEntries,
|
||||
waitForPendingUpdates,
|
||||
} from '@buster/database/queries';
|
||||
import * as databaseQueries from '@buster/database/queries';
|
||||
import { wrapTraced } from 'braintrust';
|
||||
import { cleanupState } from '../../shared/cleanup-state';
|
||||
import { createRawToolResultEntry } from '../../shared/create-raw-llm-tool-result-entry';
|
||||
|
@ -24,8 +20,9 @@ async function processDone(
|
|||
toolCallId: string,
|
||||
messageId: string,
|
||||
_context: DoneToolContext,
|
||||
input: DoneToolInput
|
||||
): Promise<DoneToolOutput> {
|
||||
input: DoneToolInput,
|
||||
updateOptions?: Parameters<typeof updateMessageEntries>[1]
|
||||
): Promise<{ output: DoneToolOutput; sequenceNumber?: number; skipped?: boolean }> {
|
||||
const output: DoneToolOutput = {
|
||||
success: true,
|
||||
};
|
||||
|
@ -49,25 +46,40 @@ async function processDone(
|
|||
? [rawLlmMessage, rawToolResultEntry]
|
||||
: [rawToolResultEntry];
|
||||
|
||||
await updateMessageEntries({
|
||||
messageId,
|
||||
rawLlmMessages,
|
||||
// Include the response message with the complete finalResponse
|
||||
responseMessages: doneToolResponseEntry ? [doneToolResponseEntry] : undefined,
|
||||
});
|
||||
const updateResult = await updateMessageEntries(
|
||||
{
|
||||
messageId,
|
||||
rawLlmMessages,
|
||||
// Include the response message with the complete finalResponse
|
||||
responseMessages: doneToolResponseEntry ? [doneToolResponseEntry] : undefined,
|
||||
},
|
||||
updateOptions
|
||||
);
|
||||
|
||||
// Mark the message as completed
|
||||
await updateMessage(messageId, {
|
||||
isCompleted: true,
|
||||
});
|
||||
|
||||
return {
|
||||
output,
|
||||
sequenceNumber: updateResult.sequenceNumber,
|
||||
skipped: updateResult.skipped,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('[done-tool] Error updating message entries:', error);
|
||||
return {
|
||||
output,
|
||||
};
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Factory function that creates the execute function with proper context typing
|
||||
const updateMessage = databaseQueries.updateMessage;
|
||||
const updateMessageEntries = databaseQueries.updateMessageEntries;
|
||||
const waitForPendingUpdates =
|
||||
databaseQueries.waitForPendingUpdates ?? (async (_messageId: string) => {});
|
||||
|
||||
export function createDoneToolExecute(context: DoneToolContext, state: DoneToolState) {
|
||||
return wrapTraced(
|
||||
async (input: DoneToolInput): Promise<DoneToolOutput> => {
|
||||
|
@ -78,15 +90,40 @@ export function createDoneToolExecute(context: DoneToolContext, state: DoneToolS
|
|||
state.isFinalizing = true;
|
||||
// CRITICAL: Wait for ALL pending updates from delta/finish to complete FIRST
|
||||
// This ensures execute's update is always the last one in the queue
|
||||
await waitForPendingUpdates(context.messageId);
|
||||
if (typeof state.latestSequenceNumber === 'number') {
|
||||
await waitForPendingUpdates(context.messageId, {
|
||||
upToSequence: state.latestSequenceNumber,
|
||||
});
|
||||
} else {
|
||||
await waitForPendingUpdates(context.messageId);
|
||||
}
|
||||
|
||||
// Now do the final authoritative update with the complete input
|
||||
const result = await processDone(state, state.toolCallId, context.messageId, context, input);
|
||||
const { output, sequenceNumber, skipped } = await processDone(
|
||||
state,
|
||||
state.toolCallId,
|
||||
context.messageId,
|
||||
context,
|
||||
input,
|
||||
{ isFinal: true }
|
||||
);
|
||||
|
||||
await waitForPendingUpdates(context.messageId);
|
||||
if (!skipped && typeof sequenceNumber === 'number') {
|
||||
const current = state.latestSequenceNumber ?? -1;
|
||||
state.latestSequenceNumber = Math.max(current, sequenceNumber);
|
||||
state.finalSequenceNumber = sequenceNumber;
|
||||
}
|
||||
|
||||
if (typeof state.finalSequenceNumber === 'number') {
|
||||
await waitForPendingUpdates(context.messageId, {
|
||||
upToSequence: state.finalSequenceNumber,
|
||||
});
|
||||
} else {
|
||||
await waitForPendingUpdates(context.messageId);
|
||||
}
|
||||
|
||||
cleanupState(state);
|
||||
return result;
|
||||
return output;
|
||||
},
|
||||
{ name: 'Done Tool' }
|
||||
);
|
||||
|
|
|
@ -30,7 +30,11 @@ export function createDoneToolFinish(context: DoneToolContext, doneToolState: Do
|
|||
|
||||
try {
|
||||
if (entries.responseMessages || entries.rawLlmMessages) {
|
||||
await updateMessageEntries(entries);
|
||||
const result = await updateMessageEntries(entries);
|
||||
if (!result.skipped && result.sequenceNumber >= 0) {
|
||||
const current = doneToolState.latestSequenceNumber ?? -1;
|
||||
doneToolState.latestSequenceNumber = Math.max(current, result.sequenceNumber);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[done-tool] Failed to update done tool raw LLM message:', error);
|
||||
|
|
|
@ -1,5 +1,12 @@
|
|||
import { randomUUID } from 'node:crypto';
|
||||
import { updateChat, updateMessage, updateMessageEntries } from '@buster/database/queries';
|
||||
import {
|
||||
getAssetLatestVersion,
|
||||
isMessageUpdateQueueClosed,
|
||||
updateChat,
|
||||
updateMessage,
|
||||
updateMessageEntries,
|
||||
waitForPendingUpdates,
|
||||
} from '@buster/database/queries';
|
||||
import type { ModelMessage, ToolCallOptions } from 'ai';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
|
||||
|
@ -13,7 +20,13 @@ import { createDoneToolStart } from './done-tool-start';
|
|||
vi.mock('@buster/database/queries', () => ({
|
||||
updateChat: vi.fn(),
|
||||
updateMessage: vi.fn(),
|
||||
updateMessageEntries: vi.fn(),
|
||||
updateMessageEntries: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
sequenceNumber: 0,
|
||||
skipped: false as const,
|
||||
}),
|
||||
waitForPendingUpdates: vi.fn().mockResolvedValue(undefined),
|
||||
isMessageUpdateQueueClosed: vi.fn().mockReturnValue(false),
|
||||
getAssetLatestVersion: vi.fn().mockResolvedValue(1),
|
||||
}));
|
||||
|
||||
|
@ -32,6 +45,9 @@ describe('done-tool-start', () => {
|
|||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
isMessageUpdateQueueClosed.mockReturnValue(false);
|
||||
waitForPendingUpdates.mockResolvedValue(undefined);
|
||||
getAssetLatestVersion.mockResolvedValue(1);
|
||||
});
|
||||
|
||||
describe('mostRecentFile selection', () => {
|
||||
|
|
|
@ -25,6 +25,8 @@ export function createDoneToolStart(context: DoneToolContext, doneToolState: Don
|
|||
doneToolState.addedAssetIds = [];
|
||||
doneToolState.addedAssets = [];
|
||||
doneToolState.isFinalizing = false;
|
||||
doneToolState.latestSequenceNumber = undefined;
|
||||
doneToolState.finalSequenceNumber = undefined;
|
||||
|
||||
// Selection logic moved to delta; skip extracting files here
|
||||
if (options.messages) {
|
||||
|
@ -67,7 +69,10 @@ export function createDoneToolStart(context: DoneToolContext, doneToolState: Don
|
|||
|
||||
try {
|
||||
if (entries.responseMessages || entries.rawLlmMessages) {
|
||||
await updateMessageEntries(entries);
|
||||
const result = await updateMessageEntries(entries);
|
||||
if (!result.skipped && result.sequenceNumber >= 0) {
|
||||
doneToolState.latestSequenceNumber = result.sequenceNumber;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[done-tool] Failed to update done tool raw LLM message:', error);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import type { ModelMessage, ToolCallOptions } from 'ai';
|
||||
import { describe, expect, test, vi } from 'vitest';
|
||||
import { beforeEach, describe, expect, test, vi } from 'vitest';
|
||||
import { CREATE_DASHBOARDS_TOOL_NAME } from '../../visualization-tools/dashboards/create-dashboards-tool/create-dashboards-tool';
|
||||
import { CREATE_METRICS_TOOL_NAME } from '../../visualization-tools/metrics/create-metrics-tool/create-metrics-tool';
|
||||
import { CREATE_REPORTS_TOOL_NAME } from '../../visualization-tools/reports/create-reports-tool/create-reports-tool';
|
||||
|
@ -8,12 +8,61 @@ import { createDoneToolDelta } from './done-tool-delta';
|
|||
import { createDoneToolFinish } from './done-tool-finish';
|
||||
import { createDoneToolStart } from './done-tool-start';
|
||||
|
||||
vi.mock('@buster/database/queries', () => ({
|
||||
updateMessageEntries: vi.fn().mockResolvedValue({ success: true }),
|
||||
updateMessage: vi.fn().mockResolvedValue({ success: true }),
|
||||
updateChat: vi.fn().mockResolvedValue({ success: true }),
|
||||
getAssetLatestVersion: vi.fn().mockResolvedValue(1),
|
||||
}));
|
||||
const queriesMock = vi.hoisted(() => {
|
||||
let sequence = 0;
|
||||
|
||||
const updateMessageEntries = vi.fn(async () => ({
|
||||
success: true,
|
||||
sequenceNumber: sequence++,
|
||||
skipped: false as const,
|
||||
}));
|
||||
const waitForPendingUpdates = vi.fn().mockResolvedValue(undefined);
|
||||
const isMessageUpdateQueueClosed = vi.fn().mockReturnValue(false);
|
||||
const updateMessage = vi.fn().mockResolvedValue({ success: true });
|
||||
const updateChat = vi.fn().mockResolvedValue({ success: true });
|
||||
const getAssetLatestVersion = vi.fn().mockResolvedValue(1);
|
||||
|
||||
return {
|
||||
updateMessageEntries,
|
||||
waitForPendingUpdates,
|
||||
isMessageUpdateQueueClosed,
|
||||
updateMessage,
|
||||
updateChat,
|
||||
getAssetLatestVersion,
|
||||
reset() {
|
||||
sequence = 0;
|
||||
updateMessageEntries.mockClear();
|
||||
waitForPendingUpdates.mockClear();
|
||||
isMessageUpdateQueueClosed.mockClear();
|
||||
updateMessage.mockClear();
|
||||
updateChat.mockClear();
|
||||
getAssetLatestVersion.mockClear();
|
||||
waitForPendingUpdates.mockResolvedValue(undefined);
|
||||
isMessageUpdateQueueClosed.mockReturnValue(false);
|
||||
getAssetLatestVersion.mockResolvedValue(1);
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('@buster/database/queries', async () => {
|
||||
const actual = await vi.importActual<typeof import('@buster/database/queries')>(
|
||||
'@buster/database/queries'
|
||||
);
|
||||
|
||||
return {
|
||||
...actual,
|
||||
updateMessageEntries: queriesMock.updateMessageEntries,
|
||||
waitForPendingUpdates: queriesMock.waitForPendingUpdates,
|
||||
isMessageUpdateQueueClosed: queriesMock.isMessageUpdateQueueClosed,
|
||||
updateMessage: queriesMock.updateMessage,
|
||||
updateChat: queriesMock.updateChat,
|
||||
getAssetLatestVersion: queriesMock.getAssetLatestVersion,
|
||||
};
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
queriesMock.reset();
|
||||
});
|
||||
|
||||
describe('Done Tool Streaming Tests', () => {
|
||||
const mockContext: DoneToolContext = {
|
||||
|
|
|
@ -75,6 +75,14 @@ const DoneToolStateSchema = z.object({
|
|||
.boolean()
|
||||
.optional()
|
||||
.describe('Indicates the execute phase has started so further deltas should be ignored'),
|
||||
latestSequenceNumber: z
|
||||
.number()
|
||||
.optional()
|
||||
.describe('Highest message update sequence number observed during streaming'),
|
||||
finalSequenceNumber: z
|
||||
.number()
|
||||
.optional()
|
||||
.describe('Sequence number for the final execute message update'),
|
||||
});
|
||||
|
||||
export type DoneToolInput = z.infer<typeof DoneToolInputSchema>;
|
||||
|
@ -90,6 +98,8 @@ export function createDoneTool(context: DoneToolContext) {
|
|||
addedAssetIds: [],
|
||||
addedAssets: [],
|
||||
isFinalizing: false,
|
||||
latestSequenceNumber: undefined,
|
||||
finalSequenceNumber: undefined,
|
||||
};
|
||||
|
||||
const execute = createDoneToolExecute(context, state);
|
||||
|
|
|
@ -41,6 +41,7 @@ vi.mock('braintrust', () => ({
|
|||
|
||||
vi.mock('../../../llm', () => ({
|
||||
Sonnet4: 'mock-model',
|
||||
GPT5Mini: 'mock-model',
|
||||
}));
|
||||
|
||||
describe('re-ask-strategy', () => {
|
||||
|
@ -130,7 +131,7 @@ describe('re-ask-strategy', () => {
|
|||
expect.objectContaining({ role: 'tool' }),
|
||||
]),
|
||||
tools: context.tools,
|
||||
maxOutputTokens: 1000,
|
||||
maxOutputTokens: 10000,
|
||||
temperature: 0,
|
||||
})
|
||||
);
|
||||
|
|
|
@ -18,6 +18,7 @@ vi.mock('braintrust', () => ({
|
|||
|
||||
vi.mock('../../../llm', () => ({
|
||||
Sonnet4: 'mock-model',
|
||||
GPT5Mini: 'mock-model',
|
||||
}));
|
||||
|
||||
describe('structured-output-strategy', () => {
|
||||
|
@ -86,13 +87,16 @@ describe('structured-output-strategy', () => {
|
|||
});
|
||||
|
||||
const tool = context.tools.testTool as any;
|
||||
expect(mockGenerateObject).toHaveBeenCalledWith({
|
||||
model: 'mock-model',
|
||||
schema: tool?.inputSchema,
|
||||
prompt: expect.stringContaining('Fix these tool arguments'),
|
||||
mode: 'json',
|
||||
providerOptions: expect.any(Object),
|
||||
});
|
||||
expect(mockGenerateObject).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'mock-model',
|
||||
schema: tool?.inputSchema,
|
||||
prompt: expect.stringContaining('Fix these tool arguments'),
|
||||
mode: 'json',
|
||||
maxOutputTokens: 10000,
|
||||
providerOptions: expect.objectContaining({}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null if tool not found', async () => {
|
||||
|
|
|
@ -21,18 +21,120 @@ const UpdateMessageEntriesSchema = z.object({
|
|||
|
||||
export type UpdateMessageEntriesParams = z.infer<typeof UpdateMessageEntriesSchema>;
|
||||
|
||||
// Simple in-memory queue for each messageId
|
||||
const updateQueues = new Map<string, Promise<{ success: boolean }>>();
|
||||
type Deferred<T> = {
|
||||
promise: Promise<T>;
|
||||
resolve: (value: T | PromiseLike<T>) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
};
|
||||
|
||||
function createDeferred<T>(): Deferred<T> {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
|
||||
promise.catch(() => undefined);
|
||||
|
||||
return { promise, resolve, reject };
|
||||
}
|
||||
|
||||
type MessageUpdateQueueState = {
|
||||
tailPromise: Promise<void>;
|
||||
nextSequence: number;
|
||||
pending: Map<number, Deferred<void>>;
|
||||
lastCompletedSequence: number;
|
||||
finalSequence?: number;
|
||||
closed: boolean;
|
||||
};
|
||||
|
||||
const updateQueues = new Map<string, MessageUpdateQueueState>();
|
||||
|
||||
function getOrCreateQueueState(messageId: string): MessageUpdateQueueState {
|
||||
const existing = updateQueues.get(messageId);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
|
||||
const initialState: MessageUpdateQueueState = {
|
||||
tailPromise: Promise.resolve(),
|
||||
nextSequence: 0,
|
||||
pending: new Map(),
|
||||
lastCompletedSequence: -1,
|
||||
closed: false,
|
||||
};
|
||||
|
||||
updateQueues.set(messageId, initialState);
|
||||
return initialState;
|
||||
}
|
||||
|
||||
function cleanupQueueIfIdle(messageId: string, state: MessageUpdateQueueState): void {
|
||||
if (
|
||||
state.closed &&
|
||||
state.finalSequence !== undefined &&
|
||||
state.lastCompletedSequence >= state.finalSequence &&
|
||||
state.pending.size === 0
|
||||
) {
|
||||
updateQueues.delete(messageId);
|
||||
}
|
||||
}
|
||||
|
||||
export function isMessageUpdateQueueClosed(messageId: string): boolean {
|
||||
const queue = updateQueues.get(messageId);
|
||||
return queue?.closed ?? false;
|
||||
}
|
||||
|
||||
type WaitForPendingUpdateOptions = {
|
||||
upToSequence?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Wait for all pending updates for a given messageId to complete.
|
||||
* This ensures all queued updates are flushed to the database before proceeding.
|
||||
* Wait for pending updates for a given messageId to complete.
|
||||
* Optionally provide a sequence number to wait through.
|
||||
*/
|
||||
export async function waitForPendingUpdates(messageId: string): Promise<void> {
|
||||
const pendingQueue = updateQueues.get(messageId);
|
||||
if (pendingQueue) {
|
||||
await pendingQueue;
|
||||
export async function waitForPendingUpdates(
|
||||
messageId: string,
|
||||
options?: WaitForPendingUpdateOptions
|
||||
): Promise<void> {
|
||||
const queue = updateQueues.get(messageId);
|
||||
if (!queue) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetSequence = options?.upToSequence ?? queue.finalSequence;
|
||||
|
||||
if (targetSequence === undefined) {
|
||||
await queue.tailPromise;
|
||||
cleanupQueueIfIdle(messageId, queue);
|
||||
return;
|
||||
}
|
||||
|
||||
const maxKnownSequence = queue.nextSequence - 1;
|
||||
const effectiveTarget = Math.min(targetSequence, maxKnownSequence);
|
||||
|
||||
if (effectiveTarget <= queue.lastCompletedSequence) {
|
||||
cleanupQueueIfIdle(messageId, queue);
|
||||
return;
|
||||
}
|
||||
|
||||
const waits: Promise<unknown>[] = [];
|
||||
|
||||
for (let sequence = queue.lastCompletedSequence + 1; sequence <= effectiveTarget; sequence += 1) {
|
||||
const deferred = queue.pending.get(sequence);
|
||||
if (deferred) {
|
||||
waits.push(deferred.promise.catch(() => undefined));
|
||||
}
|
||||
}
|
||||
|
||||
if (waits.length > 0) {
|
||||
await Promise.all(waits);
|
||||
} else {
|
||||
await queue.tailPromise;
|
||||
}
|
||||
|
||||
cleanupQueueIfIdle(messageId, queue);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -116,29 +218,79 @@ async function performUpdate({
|
|||
* - reasoningMessages: upsert by 'id' field, maintaining order
|
||||
* - rawLlmMessages: upsert by combination of 'role' and 'toolCallId', maintaining order
|
||||
*/
|
||||
type UpdateMessageEntriesOptions = {
|
||||
isFinal?: boolean;
|
||||
};
|
||||
|
||||
type UpdateMessageEntriesResult = {
|
||||
success: boolean;
|
||||
sequenceNumber: number;
|
||||
skipped?: boolean;
|
||||
};
|
||||
|
||||
export async function updateMessageEntries(
|
||||
params: UpdateMessageEntriesParams
|
||||
): Promise<{ success: boolean }> {
|
||||
params: UpdateMessageEntriesParams,
|
||||
options?: UpdateMessageEntriesOptions
|
||||
): Promise<UpdateMessageEntriesResult> {
|
||||
const { messageId } = params;
|
||||
|
||||
// Get the current promise for this messageId, or use a resolved promise as the starting point
|
||||
const currentQueue = updateQueues.get(messageId) ?? Promise.resolve({ success: true });
|
||||
const queue = getOrCreateQueueState(messageId);
|
||||
|
||||
// Chain the new update to run after the current queue completes
|
||||
const newQueue = currentQueue
|
||||
.then(() => performUpdate(params))
|
||||
.catch(() => performUpdate(params)); // Still try to run even if previous failed
|
||||
if (queue.closed) {
|
||||
const lastKnownSequence = queue.finalSequence ?? queue.nextSequence - 1;
|
||||
return {
|
||||
success: false,
|
||||
sequenceNumber: lastKnownSequence >= 0 ? lastKnownSequence : -1,
|
||||
skipped: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Update the queue for this messageId
|
||||
updateQueues.set(messageId, newQueue);
|
||||
const isFinal = options?.isFinal ?? false;
|
||||
|
||||
// Clean up the queue entry once this update completes
|
||||
newQueue.finally(() => {
|
||||
// Only remove if this is still the current queue
|
||||
if (updateQueues.get(messageId) === newQueue) {
|
||||
updateQueues.delete(messageId);
|
||||
if (isFinal) {
|
||||
queue.closed = true;
|
||||
}
|
||||
|
||||
const sequenceNumber = queue.nextSequence;
|
||||
queue.nextSequence += 1;
|
||||
|
||||
const deferred = createDeferred<void>();
|
||||
queue.pending.set(sequenceNumber, deferred);
|
||||
|
||||
const runUpdate = () => performUpdate(params);
|
||||
|
||||
const runPromise = queue.tailPromise.then(runUpdate, runUpdate);
|
||||
|
||||
queue.tailPromise = runPromise.then(
|
||||
() => undefined,
|
||||
() => undefined
|
||||
);
|
||||
|
||||
const finalize = (success: boolean) => {
|
||||
queue.pending.delete(sequenceNumber);
|
||||
queue.lastCompletedSequence = Math.max(queue.lastCompletedSequence, sequenceNumber);
|
||||
if (isFinal) {
|
||||
queue.finalSequence = sequenceNumber;
|
||||
}
|
||||
});
|
||||
cleanupQueueIfIdle(messageId, queue);
|
||||
return success;
|
||||
};
|
||||
|
||||
return newQueue;
|
||||
const resultPromise = runPromise
|
||||
.then((result) => {
|
||||
deferred.resolve();
|
||||
finalize(true);
|
||||
return {
|
||||
...result,
|
||||
sequenceNumber,
|
||||
skipped: false as const,
|
||||
};
|
||||
})
|
||||
.catch((error) => {
|
||||
deferred.reject(error);
|
||||
finalize(false);
|
||||
throw error;
|
||||
});
|
||||
|
||||
return resultPromise;
|
||||
}
|
||||
|
|
|
@ -31,18 +31,120 @@ type VersionHistoryEntry = {
|
|||
|
||||
type VersionHistory = Record<string, VersionHistoryEntry>;
|
||||
|
||||
// Simple in-memory queue for each reportId
|
||||
const updateQueues = new Map<string, Promise<void>>();
|
||||
type Deferred<T> = {
|
||||
promise: Promise<T>;
|
||||
resolve: (value: T | PromiseLike<T>) => void;
|
||||
reject: (reason?: unknown) => void;
|
||||
};
|
||||
|
||||
function createDeferred<T>(): Deferred<T> {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
|
||||
const promise = new Promise<T>((res, rej) => {
|
||||
resolve = res;
|
||||
reject = rej;
|
||||
});
|
||||
|
||||
promise.catch(() => undefined);
|
||||
|
||||
return { promise, resolve, reject };
|
||||
}
|
||||
|
||||
type ReportUpdateQueueState = {
|
||||
tailPromise: Promise<void>;
|
||||
nextSequence: number;
|
||||
pending: Map<number, Deferred<void>>;
|
||||
lastCompletedSequence: number;
|
||||
finalSequence?: number;
|
||||
closed: boolean;
|
||||
};
|
||||
|
||||
const updateQueues = new Map<string, ReportUpdateQueueState>();
|
||||
|
||||
function getOrCreateQueueState(reportId: string): ReportUpdateQueueState {
|
||||
const existing = updateQueues.get(reportId);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
|
||||
const initialState: ReportUpdateQueueState = {
|
||||
tailPromise: Promise.resolve(),
|
||||
nextSequence: 0,
|
||||
pending: new Map(),
|
||||
lastCompletedSequence: -1,
|
||||
closed: false,
|
||||
};
|
||||
|
||||
updateQueues.set(reportId, initialState);
|
||||
return initialState;
|
||||
}
|
||||
|
||||
function cleanupQueueIfIdle(reportId: string, state: ReportUpdateQueueState): void {
|
||||
if (
|
||||
state.closed &&
|
||||
state.finalSequence !== undefined &&
|
||||
state.lastCompletedSequence >= state.finalSequence &&
|
||||
state.pending.size === 0
|
||||
) {
|
||||
updateQueues.delete(reportId);
|
||||
}
|
||||
}
|
||||
|
||||
export function isReportUpdateQueueClosed(reportId: string): boolean {
|
||||
const queue = updateQueues.get(reportId);
|
||||
return queue?.closed ?? false;
|
||||
}
|
||||
|
||||
type WaitForPendingReportUpdateOptions = {
|
||||
upToSequence?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Wait for all pending updates for a given reportId to complete.
|
||||
* This ensures all queued updates are flushed to the database before proceeding.
|
||||
*/
|
||||
export async function waitForPendingReportUpdates(reportId: string): Promise<void> {
|
||||
const pendingQueue = updateQueues.get(reportId);
|
||||
if (pendingQueue) {
|
||||
await pendingQueue;
|
||||
export async function waitForPendingReportUpdates(
|
||||
reportId: string,
|
||||
options?: WaitForPendingReportUpdateOptions
|
||||
): Promise<void> {
|
||||
const queue = updateQueues.get(reportId);
|
||||
if (!queue) {
|
||||
return;
|
||||
}
|
||||
|
||||
const targetSequence = options?.upToSequence ?? queue.finalSequence;
|
||||
|
||||
if (targetSequence === undefined) {
|
||||
await queue.tailPromise;
|
||||
cleanupQueueIfIdle(reportId, queue);
|
||||
return;
|
||||
}
|
||||
|
||||
const maxKnownSequence = queue.nextSequence - 1;
|
||||
const effectiveTarget = Math.min(targetSequence, maxKnownSequence);
|
||||
|
||||
if (effectiveTarget <= queue.lastCompletedSequence) {
|
||||
cleanupQueueIfIdle(reportId, queue);
|
||||
return;
|
||||
}
|
||||
|
||||
const waits: Promise<unknown>[] = [];
|
||||
|
||||
for (let sequence = queue.lastCompletedSequence + 1; sequence <= effectiveTarget; sequence += 1) {
|
||||
const deferred = queue.pending.get(sequence);
|
||||
if (deferred) {
|
||||
waits.push(deferred.promise.catch(() => undefined));
|
||||
}
|
||||
}
|
||||
|
||||
if (waits.length > 0) {
|
||||
await Promise.all(waits);
|
||||
} else {
|
||||
await queue.tailPromise;
|
||||
}
|
||||
|
||||
cleanupQueueIfIdle(reportId, queue);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -93,27 +195,74 @@ async function performUpdate(params: BatchUpdateReportInput): Promise<void> {
|
|||
* Updates a report's content, name, and version history in a single operation.
|
||||
* Updates are queued per reportId to ensure they execute in order.
|
||||
*/
|
||||
export const updateReportWithVersion = async (params: BatchUpdateReportInput): Promise<void> => {
|
||||
const { reportId } = params;
|
||||
|
||||
// Get the current promise for this reportId, or use a resolved promise as the starting point
|
||||
const currentQueue = updateQueues.get(reportId) ?? Promise.resolve();
|
||||
|
||||
// Chain the new update to run after the current queue completes
|
||||
const newQueue = currentQueue
|
||||
.then(() => performUpdate(params))
|
||||
.catch(() => performUpdate(params)); // Still try to run even if previous failed
|
||||
|
||||
// Update the queue for this reportId
|
||||
updateQueues.set(reportId, newQueue);
|
||||
|
||||
// Clean up the queue entry once this update completes
|
||||
newQueue.finally(() => {
|
||||
// Only remove if this is still the current queue
|
||||
if (updateQueues.get(reportId) === newQueue) {
|
||||
updateQueues.delete(reportId);
|
||||
}
|
||||
});
|
||||
|
||||
return newQueue;
|
||||
type UpdateReportWithVersionOptions = {
|
||||
isFinal?: boolean;
|
||||
};
|
||||
|
||||
type UpdateReportWithVersionResult = {
|
||||
sequenceNumber: number;
|
||||
skipped?: boolean;
|
||||
};
|
||||
|
||||
export const updateReportWithVersion = async (
|
||||
params: BatchUpdateReportInput,
|
||||
options?: UpdateReportWithVersionOptions
|
||||
): Promise<UpdateReportWithVersionResult> => {
|
||||
const { reportId } = params;
|
||||
const queue = getOrCreateQueueState(reportId);
|
||||
|
||||
if (queue.closed) {
|
||||
const lastKnownSequence = queue.finalSequence ?? queue.nextSequence - 1;
|
||||
return {
|
||||
sequenceNumber: lastKnownSequence >= 0 ? lastKnownSequence : -1,
|
||||
skipped: true,
|
||||
};
|
||||
}
|
||||
|
||||
const isFinal = options?.isFinal ?? false;
|
||||
|
||||
if (isFinal) {
|
||||
queue.closed = true;
|
||||
}
|
||||
|
||||
const sequenceNumber = queue.nextSequence;
|
||||
queue.nextSequence += 1;
|
||||
|
||||
const deferred = createDeferred<void>();
|
||||
queue.pending.set(sequenceNumber, deferred);
|
||||
|
||||
const runUpdate = () => performUpdate(params);
|
||||
|
||||
const runPromise = queue.tailPromise.then(runUpdate, runUpdate);
|
||||
|
||||
queue.tailPromise = runPromise.then(
|
||||
() => undefined,
|
||||
() => undefined
|
||||
);
|
||||
|
||||
const finalize = () => {
|
||||
queue.pending.delete(sequenceNumber);
|
||||
queue.lastCompletedSequence = Math.max(queue.lastCompletedSequence, sequenceNumber);
|
||||
if (isFinal) {
|
||||
queue.finalSequence = sequenceNumber;
|
||||
}
|
||||
cleanupQueueIfIdle(reportId, queue);
|
||||
};
|
||||
|
||||
const resultPromise = runPromise
|
||||
.then(() => {
|
||||
deferred.resolve();
|
||||
finalize();
|
||||
return {
|
||||
sequenceNumber,
|
||||
skipped: false as const,
|
||||
};
|
||||
})
|
||||
.catch((error) => {
|
||||
deferred.reject(error);
|
||||
finalize();
|
||||
throw error;
|
||||
});
|
||||
|
||||
return resultPromise;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue