fix the raw llm message save.

This commit is contained in:
dal 2025-03-07 11:01:48 -07:00
parent 47e1558e2e
commit 456f117cd7
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 69 additions and 47 deletions

View File

@ -521,18 +521,6 @@ diesel::table! {
} }
} }
diesel::table! {
threads (id) {
id -> Uuid,
title -> Text,
organization_id -> Uuid,
created_at -> Timestamptz,
updated_at -> Timestamptz,
deleted_at -> Nullable<Timestamptz>,
created_by -> Uuid,
}
}
diesel::table! { diesel::table! {
threads_deprecated (id) { threads_deprecated (id) {
id -> Uuid, id -> Uuid,
@ -649,8 +637,6 @@ diesel::joinable!(teams_to_users -> users (user_id));
diesel::joinable!(terms -> organizations (organization_id)); diesel::joinable!(terms -> organizations (organization_id));
diesel::joinable!(terms_to_datasets -> datasets (dataset_id)); diesel::joinable!(terms_to_datasets -> datasets (dataset_id));
diesel::joinable!(terms_to_datasets -> terms (term_id)); diesel::joinable!(terms_to_datasets -> terms (term_id));
diesel::joinable!(threads -> organizations (organization_id));
diesel::joinable!(threads -> users (created_by));
diesel::joinable!(threads_deprecated -> organizations (organization_id)); diesel::joinable!(threads_deprecated -> organizations (organization_id));
diesel::joinable!(threads_to_dashboards -> dashboards (dashboard_id)); diesel::joinable!(threads_to_dashboards -> dashboards (dashboard_id));
diesel::joinable!(threads_to_dashboards -> threads_deprecated (thread_id)); diesel::joinable!(threads_to_dashboards -> threads_deprecated (thread_id));
@ -689,7 +675,6 @@ diesel::allow_tables_to_appear_in_same_query!(
teams_to_users, teams_to_users,
terms, terms,
terms_to_datasets, terms_to_datasets,
threads,
threads_deprecated, threads_deprecated,
threads_to_dashboards, threads_to_dashboards,
user_favorites, user_favorites,

View File

@ -200,6 +200,7 @@ pub async fn post_chat_handler(
// Initialize raw_llm_messages with initial_messages // Initialize raw_llm_messages with initial_messages
let mut raw_llm_messages = initial_messages.clone(); let mut raw_llm_messages = initial_messages.clone();
let mut raw_response_message = String::new();
// Initialize the agent thread // Initialize the agent thread
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages); let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
@ -235,9 +236,26 @@ pub async fn post_chat_handler(
// Only store completed messages in raw_llm_messages // Only store completed messages in raw_llm_messages
match &msg { match &msg {
AgentMessage::Assistant { progress, .. } => { AgentMessage::Assistant {
progress, content, ..
} => {
if let Some(content) = content {
raw_response_message.push_str(&content);
}
if matches!(progress, MessageProgress::Complete) { if matches!(progress, MessageProgress::Complete) {
raw_llm_messages.push(msg.clone()); if raw_response_message.is_empty() {
raw_llm_messages.push(msg.clone());
} else {
raw_llm_messages.push(AgentMessage::Assistant {
id: None,
content: Some(raw_response_message.clone()),
name: None,
tool_calls: None,
progress: MessageProgress::Complete,
initial: false,
});
}
} }
} }
AgentMessage::Tool { progress, .. } => { AgentMessage::Tool { progress, .. } => {
@ -395,7 +413,8 @@ pub async fn post_chat_handler(
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> { fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
let mut response_messages = Vec::new(); let mut response_messages = Vec::new();
// Use a HashMap to track the latest reasoning message for each ID // Use a HashMap to track the latest reasoning message for each ID
let mut reasoning_map: std::collections::HashMap<String, Value> = std::collections::HashMap::new(); let mut reasoning_map: std::collections::HashMap<String, Value> =
std::collections::HashMap::new();
for container in containers { for container in containers {
match container { match container {
@ -426,7 +445,9 @@ fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Va
let should_include = match &reasoning.reasoning { let should_include = match &reasoning.reasoning {
BusterReasoningMessage::Pill(thought) => thought.status == "completed", BusterReasoningMessage::Pill(thought) => thought.status == "completed",
BusterReasoningMessage::File(file) => file.status == "completed", BusterReasoningMessage::File(file) => file.status == "completed",
BusterReasoningMessage::Text(text) => text.status.as_deref() == Some("completed"), BusterReasoningMessage::Text(text) => {
text.status.as_deref() == Some("completed")
}
}; };
if should_include { if should_include {
@ -703,7 +724,11 @@ pub async fn transform_message(
vec![] vec![]
} }
}; };
containers.extend(chat_messages.into_iter().map(|container| (container, ThreadEvent::GeneratingResponseMessage))); containers.extend(
chat_messages
.into_iter()
.map(|container| (container, ThreadEvent::GeneratingResponseMessage)),
);
// Add the "Finished reasoning" message if we're just starting // Add the "Finished reasoning" message if we're just starting
if initial { if initial {
@ -724,10 +749,7 @@ pub async fn transform_message(
message_id: *message_id, message_id: *message_id,
}); });
containers.push(( containers.push((reasoning_container, ThreadEvent::GeneratingResponseMessage));
reasoning_container,
ThreadEvent::GeneratingResponseMessage,
));
} }
Ok(containers) Ok(containers)
@ -746,7 +768,11 @@ pub async fn transform_message(
for reasoning_container in messages { for reasoning_container in messages {
// Only process file response messages when they're completed // Only process file response messages when they're completed
match &reasoning_container { match &reasoning_container {
BusterReasoningMessage::File(file) if matches!(progress, MessageProgress::Complete) && file.status == "completed" && file.message_type == "files" => { BusterReasoningMessage::File(file)
if matches!(progress, MessageProgress::Complete)
&& file.status == "completed"
&& file.message_type == "files" =>
{
// For each completed file, create and send a file response message // For each completed file, create and send a file response message
for (file_id, file_content) in &file.files { for (file_id, file_content) in &file.files {
let response_message = BusterChatMessage::File { let response_message = BusterChatMessage::File {
@ -767,11 +793,13 @@ pub async fn transform_message(
}; };
containers.push(( containers.push((
BusterContainer::ChatMessage(BusterChatMessageContainer { BusterContainer::ChatMessage(
response_message, BusterChatMessageContainer {
chat_id: *chat_id, response_message,
message_id: *message_id, chat_id: *chat_id,
}), message_id: *message_id,
},
),
ThreadEvent::GeneratingResponseMessage, ThreadEvent::GeneratingResponseMessage,
)); ));
} }
@ -780,11 +808,13 @@ pub async fn transform_message(
} }
containers.push(( containers.push((
BusterContainer::ReasoningMessage(BusterReasoningMessageContainer { BusterContainer::ReasoningMessage(
reasoning: reasoning_container, BusterReasoningMessageContainer {
chat_id: *chat_id, reasoning: reasoning_container,
message_id: *message_id, chat_id: *chat_id,
}), message_id: *message_id,
},
),
ThreadEvent::GeneratingReasoningMessage, ThreadEvent::GeneratingReasoningMessage,
)); ));
} }
@ -822,14 +852,18 @@ pub async fn transform_message(
) { ) {
Ok(messages) => messages Ok(messages) => messages
.into_iter() .into_iter()
.map(|container| ( .map(|container| {
BusterContainer::ReasoningMessage(BusterReasoningMessageContainer { (
reasoning: container, BusterContainer::ReasoningMessage(
chat_id: *chat_id, BusterReasoningMessageContainer {
message_id: *message_id, reasoning: container,
}), chat_id: *chat_id,
ThreadEvent::GeneratingReasoningMessage, message_id: *message_id,
)) },
),
ThreadEvent::GeneratingReasoningMessage,
)
})
.collect(), .collect(),
Err(e) => { Err(e) => {
tracing::warn!("Error transforming tool message '{}': {:?}", name_str, e); tracing::warn!("Error transforming tool message '{}': {:?}", name_str, e);
@ -1164,9 +1198,10 @@ fn transform_assistant_tool_message(
MessageProgress::Complete => { MessageProgress::Complete => {
// For completed files, only send the final state // For completed files, only send the final state
let mut updated_files = std::collections::HashMap::new(); let mut updated_files = std::collections::HashMap::new();
for (file_id, file_content) in file.files.iter() { for (file_id, file_content) in file.files.iter() {
let chunk_id = format!("{}_{}", file.id, file_content.file_name); let chunk_id =
format!("{}_{}", file.id, file_content.file_name);
let complete_text = tracker let complete_text = tracker
.get_complete_text(chunk_id.clone()) .get_complete_text(chunk_id.clone())
.unwrap_or_else(|| { .unwrap_or_else(|| {
@ -1193,10 +1228,12 @@ fn transform_assistant_tool_message(
let mut updated_files = std::collections::HashMap::new(); let mut updated_files = std::collections::HashMap::new();
for (file_id, file_content) in file.files.iter() { for (file_id, file_content) in file.files.iter() {
let chunk_id = format!("{}_{}", file.id, file_content.file_name); let chunk_id =
format!("{}_{}", file.id, file_content.file_name);
if let Some(chunk) = &file_content.file.text_chunk { if let Some(chunk) = &file_content.file.text_chunk {
let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone()); let delta =
tracker.add_chunk(chunk_id.clone(), chunk.clone());
if !delta.is_empty() { if !delta.is_empty() {
let mut updated_content = file_content.clone(); let mut updated_content = file_content.clone();