From 26a939722861e6af0e90b455bc35be424c7e2598 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 21 Mar 2025 11:36:13 -0600 Subject: [PATCH] associate files correclty. --- .../handlers/src/chats/post_chat_handler.rs | 128 ++++++++++-------- 1 file changed, 74 insertions(+), 54 deletions(-) diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index 66b85d77b..474526f44 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -1,12 +1,15 @@ +use dashmap::DashMap; use middleware::AuthenticatedUser; use std::{collections::HashMap, time::Instant}; -use dashmap::DashMap; use agents::{ - tools::{file_tools::{ - common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput, - create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput, - }, planning_tools::CreatePlanOutput}, + tools::{ + file_tools::{ + common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput, + create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput, + }, + planning_tools::CreatePlanOutput, + }, AgentExt, AgentMessage, AgentThread, BusterSuperAgent, }; @@ -81,13 +84,15 @@ impl ChunkTracker { pub fn add_chunk(&self, chunk_id: String, new_chunk: String) -> String { // Compute delta and update in one operation using DashMap let mut delta_to_return = String::new(); - + { - self.chunks.entry(chunk_id.clone()).or_insert_with(|| ChunkState { - complete_text: String::new(), - last_seen_content: String::new(), - }); - + self.chunks + .entry(chunk_id.clone()) + .or_insert_with(|| ChunkState { + complete_text: String::new(), + last_seen_content: String::new(), + }); + // Now that we've initialized the entry if needed, get mutable access to update it if let Some(mut entry) = self.chunks.get_mut(&chunk_id) { // Calculate the delta @@ -107,9 +112,9 @@ impl ChunkTracker { } } }; - + delta_to_return = delta.clone(); - + // Update tracking state only if we found new content if !delta.is_empty() { entry.complete_text.push_str(&delta); @@ -117,12 +122,14 @@ impl ChunkTracker { } } } - + delta_to_return } pub fn get_complete_text(&self, chunk_id: String) -> Option { - self.chunks.get(&chunk_id).map(|state| state.complete_text.clone()) + self.chunks + .get(&chunk_id) + .map(|state| state.complete_text.clone()) } pub fn clear_chunk(&self, chunk_id: String) { @@ -235,12 +242,17 @@ pub async fn post_chat_handler( // Only store completed messages in raw_llm_messages match &msg { AgentMessage::Assistant { - progress, content, id, .. + progress, + content, + id, + .. } => { // Store chunks in the tracker to ensure deduplication if let Some(content_str) = content { // Use message ID as chunk ID, or generate a consistent one if missing - let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string()); + let chunk_id = id + .clone() + .unwrap_or_else(|| "assistant_message".to_string()); // Add to chunk tracker to handle deduplication chunk_tracker.add_chunk(chunk_id.clone(), content_str.clone()); } @@ -248,12 +260,15 @@ pub async fn post_chat_handler( if matches!(progress, MessageProgress::Complete) { if let Some(content_str) = content { // Use message ID as chunk ID, or generate a consistent one if missing - let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string()); - + let chunk_id = id + .clone() + .unwrap_or_else(|| "assistant_message".to_string()); + // Get the complete deduplicated text from the chunk tracker - let complete_text = chunk_tracker.get_complete_text(chunk_id.clone()) + let complete_text = chunk_tracker + .get_complete_text(chunk_id.clone()) .unwrap_or_else(|| content_str.clone()); - + // Create a new message with the deduplicated content raw_llm_messages.push(AgentMessage::Assistant { id: id.clone(), @@ -263,7 +278,7 @@ pub async fn post_chat_handler( progress: MessageProgress::Complete, initial: false, }); - + // Clear the chunk from the tracker chunk_tracker.clear_chunk(chunk_id); } else { @@ -285,7 +300,9 @@ pub async fn post_chat_handler( } // Always transform the message - match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker).await { + match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker) + .await + { Ok(containers) => { // Store all transformed containers for (container, _) in containers.clone() { @@ -369,7 +386,7 @@ pub async fn post_chat_handler( final_reasoning_message: format!("Reasoned for {} seconds", reasoning_duration), title: title.title.clone().unwrap_or_default(), raw_llm_messages: serde_json::to_value(&raw_llm_messages)?, - feedback: None + feedback: None, }; let mut conn = get_pg_pool().get().await?; @@ -435,7 +452,8 @@ pub async fn post_chat_handler( fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec, Vec)> { let mut response_messages = Vec::new(); // Use a Vec to maintain order, with a HashMap to track latest version of each message - let mut reasoning_map: std::collections::HashMap = std::collections::HashMap::new(); + let mut reasoning_map: std::collections::HashMap = + std::collections::HashMap::new(); let mut reasoning_order = Vec::new(); for container in containers { @@ -485,7 +503,7 @@ fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec>>, tracker: &ChunkTracker, ) -> Result> { - match message { AgentMessage::Assistant { id, @@ -985,10 +1006,10 @@ fn transform_text_message( let complete_text = tracker .get_complete_text(id.clone()) .unwrap_or_else(|| content.clone()); - + // Clear the tracker for this chunk tracker.clear_chunk(id.clone()); - + Ok(vec![BusterChatMessage::Text { id: id.clone(), message: Some(complete_text), @@ -1022,7 +1043,6 @@ fn transform_tool_message( } fn tool_create_plan(id: String, content: String) -> Result> { - let plan_markdown = match serde_json::from_str::(&content) { Ok(result) => result.plan_markdown, Err(e) => { @@ -1046,7 +1066,6 @@ fn tool_create_plan(id: String, content: String) -> Result Result> { - // Parse the CreateMetricFilesOutput from content let create_metrics_result = match serde_json::from_str::(&content) { Ok(result) => result, @@ -1102,7 +1121,6 @@ fn tool_create_metrics(id: String, content: String) -> Result Result> { - // Parse the ModifyFilesOutput from content let modify_metrics_result = match serde_json::from_str::(&content) { Ok(result) => result, @@ -1158,7 +1176,6 @@ fn tool_modify_metrics(id: String, content: String) -> Result Result> { - // Parse the CreateDashboardFilesOutput from content let create_dashboards_result = match serde_json::from_str::(&content) { @@ -1444,7 +1461,8 @@ fn transform_assistant_tool_message( let mut updated_files = std::collections::HashMap::new(); for (file_id, file_content) in file.files.iter() { - let chunk_id = format!("{}_{}", file.id, file_content.file_name); + let chunk_id = + format!("{}_{}", file.id, file_content.file_name); let complete_text = tracker .get_complete_text(chunk_id.clone()) .unwrap_or_else(|| { @@ -1471,10 +1489,12 @@ fn transform_assistant_tool_message( let mut updated_files = std::collections::HashMap::new(); for (file_id, file_content) in file.files.iter() { - let chunk_id = format!("{}_{}", file.id, file_content.file_name); + let chunk_id = + format!("{}_{}", file.id, file_content.file_name); if let Some(chunk) = &file_content.file.text_chunk { - let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone()); + let delta = + tracker.add_chunk(chunk_id.clone(), chunk.clone()); if !delta.is_empty() { let mut updated_content = file_content.clone();