diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index 0ab1afbc4..180be42bb 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::Result; use serde::Serialize; use uuid::Uuid; @@ -5,6 +7,7 @@ use uuid::Uuid; use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall}; use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput; +use crate::utils::tools::file_tools::search_files::SearchFilesOutput; #[derive(Debug, Serialize)] #[serde(untagged)] @@ -48,7 +51,7 @@ pub struct BusterThoughtPill { } pub fn transform_message(message: Message) -> Result { - match message { + let buster_message = match message { Message::Assistant { id, content, @@ -63,13 +66,6 @@ pub fn transform_message(message: Message) -> Result { if let (Some(name), Some(tool_calls)) = (name, tool_calls) { return transform_assistant_tool_message(id, name, tool_calls, progress); } - - Ok(BusterThreadMessage::ChatMessage(BusterChatMessage { - id, - message_type: "text".to_string(), - message: None, - message_chunk: Some(content), - })) } Message::Tool { id, @@ -78,21 +74,12 @@ pub fn transform_message(message: Message) -> Result { name, progress, } => { - if let (Some(name), Some(content)) = (name, content) { + if let Some(name) = name { return transform_tool_message(id, name, content, progress); } - - Ok(BusterThreadMessage::Thought(BusterThought { - id: tool_call_id.clone(), - thought_type: "text".to_string(), - thought_title: "".to_string(), - thought_secondary_title: "".to_string(), - thought_pills: None, - status: "".to_string(), - })) } _ => Err(anyhow::anyhow!("Unsupported message type")), - } + }; } fn transform_text_message( @@ -151,11 +138,11 @@ fn transform_assistant_tool_message( progress: Option, ) -> Result { match name.as_str() { - "data_catalog_search" => assistant_data_catalog_search(id, content, progress), - "stored_values_search" => assistant_stored_values_search(id, content, progress), - "file_search" => assistant_file_search(id, content, progress), - "create_file" => assistant_create_file(id, content, progress), - "modify_file" => assistant_modify_file(id, content, progress), + "data_catalog_search" => assistant_data_catalog_search(id, progress), + "stored_values_search" => assistant_stored_values_search(id, progress), + "file_search" => assistant_file_search(id, progress), + "create_file" => assistant_create_file(id, "".to_string(), progress), + "modify_file" => assistant_modify_file(id, "".to_string(), progress), } } @@ -204,6 +191,12 @@ fn tool_data_catalog_search( } }; + let duration = (data_catalog_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0; + + let result_count = data_catalog_result.results.len(); + + let query_params = data_catalog_result.query_params.clone(); + let thought_pill_containters = match proccess_data_catalog_search_results(data_catalog_result) { Ok(object) => object, @@ -215,10 +208,6 @@ fn tool_data_catalog_search( } }; - let duration = (data_catalog_result.duration as f64 / 1000.0 * 10.0).round() / 10.0; - - let result_count = data_catalog_result.results.len(); - let buster_thought = if result_count > 0 { BusterThreadMessage::Thought(BusterThought { id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), @@ -234,7 +223,17 @@ fn tool_data_catalog_search( thought_type: "thought".to_string(), thought_title: "No data catalog items found".to_string(), thought_secondary_title: format!("{} seconds", duration), - thought_pills: vec![], + thought_pills: Some(vec![BusterThoughtPillContainer { + title: "No results found".to_string(), + thought_pills: query_params + .iter() + .map(|param| BusterThoughtPill { + id: "".to_string(), + text: param.clone(), + thought_file_type: "empty".to_string(), + }) + .collect(), + }]), status: "completed".to_string(), }) }; @@ -257,55 +256,41 @@ fn proccess_data_catalog_search_results( ) -> Result> { if results.results.is_empty() { return Ok(vec![BusterThoughtPillContainer { - title: "No datasets found".to_string(), + title: "No results found".to_string(), thought_pills: vec![], }]); } - let mut dataset_results = vec![]; - let mut terms_results = vec![]; - let mut verified_metrics = vec![]; + let mut file_results: HashMap> = HashMap::new(); for result in results.results { - match result.name.as_str() { - "dataset" => dataset_results.push(BusterThoughtPill { + file_results + .entry(result.name.clone()) + .or_insert_with(Vec::new) + .push(BusterThoughtPill { id: result.id.to_string(), - text: result.name, - thought_file_type: "dataset".to_string(), - }), - "term" => terms_results.push(BusterThoughtPill { - id: result.id.to_string(), - text: result.name, - thought_file_type: "term".to_string(), - }), - "verified_metric" => verified_metrics.push(BusterThoughtPill { - id: result.id.to_string(), - text: result.name, - thought_file_type: "verified_metric".to_string(), - }), - _ => (), - } + text: result.name.clone(), + thought_file_type: result.name, + }); } - let dataset_count = dataset_results.len(); - let term_count = terms_results.len(); - let verified_metric_count = verified_metrics.len(); + let buster_thought_pill_containers = file_results + .into_iter() + .map(|(title, thought_pills)| { + let count = thought_pills.len(); + BusterThoughtPillContainer { + title: format!( + "{count} {} found", + title.chars().next().unwrap().to_uppercase().to_string() + &title[1..] + ), + thought_pills, + } + }) + .collect(); - Ok(vec![ - BusterThoughtPillContainer { - title: format!("Datasets ({})", dataset_count), - thought_pills: dataset_results, - }, - BusterThoughtPillContainer { - title: format!("Terms ({})", term_count), - thought_pills: terms_results, - }, - BusterThoughtPillContainer { - title: format!("Verified Metrics ({})", verified_metric_count), - thought_pills: verified_metrics, - }, - ]) + Ok(buster_thought_pill_containers) } + fn assistant_stored_values_search( id: Option, progress: Option, @@ -315,7 +300,7 @@ fn assistant_stored_values_search( MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought { id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), thought_type: "thought".to_string(), - thought_title: "Searching your stored values...".to_string(), + thought_title: "Searching for relevant values...".to_string(), thought_secondary_title: "".to_string(), thought_pills: None, status: "loading".to_string(), @@ -331,6 +316,7 @@ fn assistant_stored_values_search( } } +// TODO: Implmentation for stored values search. fn tool_stored_values_search( id: Option, content: String, @@ -366,7 +352,7 @@ fn assistant_file_search( MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought { id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), thought_type: "thought".to_string(), - thought_title: "Searching your files...".to_string(), + thought_title: "Searching across your assets...".to_string(), thought_secondary_title: "".to_string(), thought_pills: None, status: "loading".to_string(), @@ -385,14 +371,99 @@ fn tool_file_search( content: String, progress: Option, ) -> Result { - Ok(BusterThreadMessage::Thought(BusterThought { - id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), - thought_type: "thought".to_string(), - thought_title: "".to_string(), - thought_secondary_title: "".to_string(), - thought_pills: None, - status: "completed".to_string(), - })) + if let Some(progress) = progress { + let file_search_result = match serde_json::from_str::(&content) { + Ok(result) => result, + Err(e) => { + return Err(anyhow::anyhow!("Failed to parse file search result: {}", e)); + } + }; + + let query_params = file_search_result.query_params.clone(); + let duration = (file_search_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0; + let result_count = file_search_result.files.len(); + + let thought_pill_containers = match process_file_search_results(file_search_result) { + Ok(containers) => containers, + Err(e) => { + return Err(anyhow::anyhow!( + "Failed to process file search results: {}", + e + )); + } + }; + + let buster_thought = if result_count > 0 { + BusterThreadMessage::Thought(BusterThought { + id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), + thought_type: "thought".to_string(), + thought_title: format!("Found {} assets", result_count), + thought_secondary_title: format!("{} seconds", duration), + thought_pills: Some(thought_pill_containers), + status: "completed".to_string(), + }) + } else { + BusterThreadMessage::Thought(BusterThought { + id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), + thought_type: "thought".to_string(), + thought_title: "No assets found".to_string(), + thought_secondary_title: format!("{} seconds", duration), + thought_pills: Some(vec![BusterThoughtPillContainer { + title: "No assets found".to_string(), + thought_pills: query_params + .iter() + .map(|param| BusterThoughtPill { + id: "".to_string(), + text: param.clone(), + thought_file_type: "empty".to_string(), + }) + .collect(), + }]), + status: "completed".to_string(), + }) + }; + + match progress { + MessageProgress::Complete => Ok(buster_thought), + _ => Err(anyhow::anyhow!("Tool file search only supports complete.")), + } + } else { + Err(anyhow::anyhow!("Tool file search requires progress.")) + } +} + +fn process_file_search_results( + results: SearchFilesOutput, +) -> Result> { + if results.files.is_empty() { + return Ok(vec![BusterThoughtPillContainer { + title: "No assets found".to_string(), + thought_pills: vec![], + }]); + } + + let mut file_results: HashMap> = HashMap::new(); + + for result in results.files { + file_results + .entry(result.file_type.clone()) + .or_insert_with(Vec::new) + .push(BusterThoughtPill { + id: result.id.to_string(), + text: result.name, + thought_file_type: result.file_type, + }); + } + + let buster_thought_pill_containers = file_results + .into_iter() + .map(|(title, thought_pills)| BusterThoughtPillContainer { + title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..], + thought_pills, + }) + .collect(); + + Ok(buster_thought_pill_containers) } fn assistant_open_file( diff --git a/api/src/utils/tools/file_tools/search_data_catalog.rs b/api/src/utils/tools/file_tools/search_data_catalog.rs index bedf068c1..971f6242a 100644 --- a/api/src/utils/tools/file_tools/search_data_catalog.rs +++ b/api/src/utils/tools/file_tools/search_data_catalog.rs @@ -25,7 +25,7 @@ struct SearchDataCatalogParams { query_params: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct SearchDataCatalogOutput { pub message: String, pub query_params: Vec, diff --git a/api/src/utils/tools/file_tools/search_files.rs b/api/src/utils/tools/file_tools/search_files.rs index 281bb708f..831ea919d 100644 --- a/api/src/utils/tools/file_tools/search_files.rs +++ b/api/src/utils/tools/file_tools/search_files.rs @@ -29,20 +29,20 @@ struct SearchFilesParams { query_params: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct SearchFilesOutput { - message: String, - query_params: Vec, - duration: i64, - files: Vec, + pub message: String, + pub query_params: Vec, + pub duration: i64, + pub files: Vec, } #[derive(Debug, Serialize, Deserialize)] -struct FileSearchResult { - id: Uuid, - name: String, - file_type: String, - updated_at: DateTime, +pub struct FileSearchResult { + pub id: Uuid, + pub name: String, + pub file_type: String, + pub updated_at: DateTime, } const FILE_SEARCH_PROMPT: &str = r#"