buster/packages/database/src/queries/messages/update-message-entries.ts

import type { ModelMessage } from 'ai';
import { type SQL, and, eq, isNull, sql } from 'drizzle-orm';
import { z } from 'zod';
import { db } from '../../connection';
import { messages } from '../../schema';
import { ReasoningMessageSchema, ResponseMessageSchema } from '../../schemas/message-schemas';

const UpdateMessageEntriesSchema = z.object({
  messageId: z.string().uuid(),
  rawLlmMessages: z.array(z.custom<ModelMessage>()).optional(),
  responseMessages: z.array(ResponseMessageSchema).optional(),
  reasoningMessages: z.array(ReasoningMessageSchema).optional(),
});

export type UpdateMessageEntriesParams = z.infer<typeof UpdateMessageEntriesSchema>;

/**
 * Updates message entries using optimized JSONB merge operations.
 * Performs batch upserts for multiple entries in a single database operation.
 *
 * Upsert logic:
 * - responseMessages: upsert by 'id' field
 * - reasoningMessages: upsert by 'id' field
 * - rawLlmMessages: upsert by combination of 'role' and 'toolCallId' in content array
 *   (handles both string content and array content with tool calls)
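 *
 * @example
 * // Illustrative sketch only: the entry payloads below are hypothetical placeholders,
 * // not the actual ResponseMessageSchema shape.
 * await updateMessageEntries({
 *   messageId: '3f2c8a4e-1b2d-4c5e-9f00-123456789abc', // placeholder UUID
 *   responseMessages: [{ id: 'resp-1', type: 'text', message: 'Done' }], // hypothetical shape
 * });
 * // A second call that reuses id 'resp-1' updates that entry in place instead of appending it.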
 */
export async function updateMessageEntries({
  messageId,
  rawLlmMessages,
  responseMessages,
  reasoningMessages,
}: UpdateMessageEntriesParams): Promise<{ success: boolean }> {
  try {
    const updates: Record<string, SQL | string> = { updatedAt: new Date().toISOString() };

    // Optimized merge for response messages
    if (responseMessages?.length) {
      const newData = JSON.stringify(responseMessages);
      updates.responseMessages = sql`
        CASE
          WHEN ${messages.responseMessages} IS NULL THEN ${newData}::jsonb
          ELSE (
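            -- new_map indexes the incoming entries by id; merged replaces any existing
            -- entries that share an id in place, then appends entries with brand-new ids.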
            WITH new_map AS (
              SELECT jsonb_object_agg(value->>'id', value) AS map
              FROM jsonb_array_elements(${newData}::jsonb) AS value
              WHERE value->>'id' IS NOT NULL
            ),
            merged AS (
              SELECT jsonb_agg(
                CASE
                  WHEN new_map.map ? (existing.value->>'id')
                  THEN new_map.map->(existing.value->>'id')
                  ELSE existing.value
                END
                ORDER BY existing.ordinality
              ) AS result
              FROM jsonb_array_elements(${messages.responseMessages}) WITH ORDINALITY AS existing(value, ordinality)
              CROSS JOIN new_map
              UNION ALL
              SELECT jsonb_agg(new_item.value ORDER BY new_item.ordinality)
              FROM jsonb_array_elements(${newData}::jsonb) WITH ORDINALITY AS new_item(value, ordinality)
              CROSS JOIN new_map
              WHERE NOT EXISTS (
                SELECT 1 FROM jsonb_array_elements(${messages.responseMessages}) AS existing
                WHERE existing.value->>'id' = new_item.value->>'id'
              )
            )
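            -- Flatten the aggregated rows (updated existing array plus appended new entries)
            -- back into a single JSONB array, defaulting to '[]' if nothing remains.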
            SELECT COALESCE(jsonb_agg(value), '[]'::jsonb)
            FROM (
              SELECT jsonb_array_elements(result) AS value
              FROM merged
              WHERE result IS NOT NULL
            ) t
          )
        END`;
    }

    // Optimized merge for reasoning messages (same id-keyed upsert as responseMessages above)
    if (reasoningMessages?.length) {
      const newData = JSON.stringify(reasoningMessages);
      updates.reasoning = sql`
        CASE
          WHEN ${messages.reasoning} IS NULL THEN ${newData}::jsonb
          ELSE (
            WITH new_map AS (
              SELECT jsonb_object_agg(value->>'id', value) AS map
              FROM jsonb_array_elements(${newData}::jsonb) AS value
              WHERE value->>'id' IS NOT NULL
            ),
            merged AS (
              SELECT jsonb_agg(
                CASE
                  WHEN new_map.map ? (existing.value->>'id')
                  THEN new_map.map->(existing.value->>'id')
                  ELSE existing.value
                END
                ORDER BY existing.ordinality
              ) AS result
              FROM jsonb_array_elements(${messages.reasoning}) WITH ORDINALITY AS existing(value, ordinality)
              CROSS JOIN new_map
              UNION ALL
              SELECT jsonb_agg(new_item.value ORDER BY new_item.ordinality)
              FROM jsonb_array_elements(${newData}::jsonb) WITH ORDINALITY AS new_item(value, ordinality)
              CROSS JOIN new_map
              WHERE NOT EXISTS (
                SELECT 1 FROM jsonb_array_elements(${messages.reasoning}) AS existing
                WHERE existing.value->>'id' = new_item.value->>'id'
              )
            )
            SELECT COALESCE(jsonb_agg(value), '[]'::jsonb)
            FROM (
              SELECT jsonb_array_elements(result) AS value
              FROM merged
              WHERE result IS NOT NULL
            ) t
          )
        END`;
    }

    // Optimized merge for raw LLM messages - handles both string and array content
    if (rawLlmMessages?.length) {
      const newData = JSON.stringify(rawLlmMessages);
      updates.rawLlmMessages = sql`
        CASE
          WHEN ${messages.rawLlmMessages} IS NULL THEN ${newData}::jsonb
          ELSE (
            WITH new_messages AS (
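              -- Each incoming message gets a dedup key: its role plus the sorted,
              -- comma-joined toolCallIds found in array content ('' when there are none).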
              SELECT
                value,
                ordinality AS input_order,
                value->>'role' AS role,
                COALESCE(
                  CASE
                    WHEN jsonb_typeof(value->'content') = 'array' THEN
                      (SELECT string_agg(c->>'toolCallId', ',' ORDER BY c->>'toolCallId')
                       FROM jsonb_array_elements(value->'content') c
                       WHERE c->>'toolCallId' IS NOT NULL)
                    ELSE NULL
                  END,
                  ''
                ) AS tool_calls,
                -- Extract toolCallId for tool result messages to find their corresponding call
                CASE
                  WHEN value->>'role' = 'tool' AND jsonb_typeof(value->'content') = 'array' THEN
                    value->'content'->0->>'toolCallId'
                  ELSE NULL
                END AS result_tool_call_id
              FROM jsonb_array_elements(${newData}::jsonb) WITH ORDINALITY AS t(value, ordinality)
            ),
            existing_messages AS (
              SELECT
                value,
                ordinality,
                value->>'role' AS role,
                COALESCE(
                  CASE
                    WHEN jsonb_typeof(value->'content') = 'array' THEN
                      (SELECT string_agg(c->>'toolCallId', ',' ORDER BY c->>'toolCallId')
                       FROM jsonb_array_elements(value->'content') c
                       WHERE c->>'toolCallId' IS NOT NULL)
                    ELSE NULL
                  END,
                  ''
                ) AS tool_calls
              FROM jsonb_array_elements(${messages.rawLlmMessages}) WITH ORDINALITY AS t(value, ordinality)
            ),
            -- Find positions of tool calls that are being updated
            tool_call_positions AS (
              SELECT
                n.tool_calls,
                MAX(e.ordinality) AS call_position
              FROM new_messages n
              JOIN existing_messages e ON n.role = e.role AND n.tool_calls = e.tool_calls
              WHERE n.role = 'assistant' AND n.tool_calls != ''
              GROUP BY n.tool_calls
            )
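            -- Rebuild the array: existing messages whose (role, tool_calls) key was not
            -- superseded keep their original positions; new rows are placed by the CASE below.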
            SELECT COALESCE(
              jsonb_agg(value ORDER BY ord),
              '[]'::jsonb
            )
            FROM (
              -- Keep existing messages that aren't being updated
              SELECT e.value, e.ordinality AS ord
              FROM existing_messages e
              WHERE NOT EXISTS (
                SELECT 1 FROM new_messages n
                WHERE n.role = e.role AND n.tool_calls = e.tool_calls
              )
              UNION ALL
              -- Add new messages with smart ordering
              SELECT
                n.value,
                CASE
                  -- Tool result: place immediately after its corresponding tool call
                  WHEN n.role = 'tool' AND n.result_tool_call_id IS NOT NULL THEN
                    COALESCE(
                      -- If the tool call was just updated, place result right after it
                      (SELECT tcp.call_position + 0.5
                       FROM tool_call_positions tcp
                       WHERE tcp.tool_calls LIKE '%' || n.result_tool_call_id || '%'
                       LIMIT 1),
                      -- Otherwise append at the end
                      1000000 + n.input_order
                    )
                  -- Regular messages and tool calls: append at end
                  ELSE 1000000 + n.input_order
                END AS ord
              FROM new_messages n
            ) combined
          )
        END`;
    }

    await db
      .update(messages)
      .set(updates)
      .where(and(eq(messages.id, messageId), isNull(messages.deletedAt)));

    return { success: true };
  } catch (error) {
    console.error('Failed to update message entries:', error);
    throw new Error(`Failed to update message entries for message ${messageId}`);
  }
}