diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index 180be42bb..33c55df50 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -6,6 +6,8 @@ use uuid::Uuid; use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall}; +use crate::utils::tools::file_tools::file_types::file::FileEnum; +use crate::utils::tools::file_tools::open_files::OpenFilesOutput; use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput; use crate::utils::tools::file_tools::search_files::SearchFilesOutput; @@ -51,7 +53,7 @@ pub struct BusterThoughtPill { } pub fn transform_message(message: Message) -> Result { - let buster_message = match message { + match message { Message::Assistant { id, content, @@ -66,6 +68,8 @@ pub fn transform_message(message: Message) -> Result { if let (Some(name), Some(tool_calls)) = (name, tool_calls) { return transform_assistant_tool_message(id, name, tool_calls, progress); } + + Err(anyhow::anyhow!("Assistant message missing required fields")) } Message::Tool { id, @@ -77,9 +81,11 @@ pub fn transform_message(message: Message) -> Result { if let Some(name) = name { return transform_tool_message(id, name, content, progress); } + + Err(anyhow::anyhow!("Tool message missing name field")) } _ => Err(anyhow::anyhow!("Unsupported message type")), - }; + } } fn transform_text_message( @@ -122,11 +128,12 @@ fn transform_tool_message( progress: Option, ) -> Result { match name.as_str() { - "data_catalog_search" => assistant_data_catalog_search(id, progress), - "stored_values_search" => assistant_stored_values_search(id, progress), - "file_search" => assistant_file_search(id, progress), - "create_file" => assistant_create_file(id, content, progress), - "modify_file" => assistant_modify_file(id, content, progress), + "data_catalog_search" => tool_data_catalog_search(id, content, progress), + "stored_values_search" => tool_stored_values_search(id, content, progress), + "file_search" => tool_file_search(id, content, progress), + "create_file" => tool_create_file(id, content, progress), + "modify_file" => tool_modify_file(id, content, progress), + "open_files" => tool_open_files(id, content, progress), _ => Err(anyhow::anyhow!("Unsupported tool name")), } } @@ -141,8 +148,10 @@ fn transform_assistant_tool_message( "data_catalog_search" => assistant_data_catalog_search(id, progress), "stored_values_search" => assistant_stored_values_search(id, progress), "file_search" => assistant_file_search(id, progress), - "create_file" => assistant_create_file(id, "".to_string(), progress), - "modify_file" => assistant_modify_file(id, "".to_string(), progress), + "create_file" => assistant_create_file(id, progress), + "modify_file" => assistant_modify_file(id, progress), + "open_files" => assistant_open_files(id, progress), + _ => Err(anyhow::anyhow!("Unsupported tool name")), } } @@ -466,7 +475,7 @@ fn process_file_search_results( Ok(buster_thought_pill_containers) } -fn assistant_open_file( +fn assistant_open_files( id: Option, progress: Option, ) -> Result { @@ -489,22 +498,69 @@ fn assistant_open_file( } } -fn tool_open_file( +fn tool_open_files( id: Option, content: String, progress: Option, ) -> Result { - Ok(BusterThreadMessage::ChatMessage(BusterChatMessage { - id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), - message_type: "text".to_string(), - message: Some(content), - message_chunk: None, - })) + if let Some(progress) = progress { + let open_files_result = match serde_json::from_str::(&content) { + Ok(result) => result, + Err(e) => { + return Err(anyhow::anyhow!("Failed to parse open files result: {}", e)); + } + }; + + let duration = (open_files_result.duration as f64 / 1000.0 * 10.0).round() / 10.0; + let result_count = open_files_result.results.len(); + + let mut file_results: HashMap> = HashMap::new(); + + for result in open_files_result.results { + let file_type = match result { + FileEnum::Dashboard(_) => "dashboard", + FileEnum::Metric(_) => "metric", + } + .to_string(); + + file_results + .entry(file_type.clone()) + .or_insert_with(Vec::new) + .push(BusterThoughtPill { + id: Uuid::new_v4().to_string(), + text: open_files_result.message.clone(), + thought_file_type: file_type, + }); + } + + let thought_pill_containers = file_results + .into_iter() + .map(|(title, thought_pills)| BusterThoughtPillContainer { + title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..], + thought_pills, + }) + .collect::>(); + + let buster_thought = BusterThreadMessage::Thought(BusterThought { + id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), + thought_type: "thought".to_string(), + thought_title: format!("Looked through {} assets", result_count), + thought_secondary_title: format!("{} seconds", duration), + thought_pills: Some(thought_pill_containers), + status: "completed".to_string(), + }); + + match progress { + MessageProgress::Complete => Ok(buster_thought), + _ => Err(anyhow::anyhow!("Tool open file only supports complete.")), + } + } else { + Err(anyhow::anyhow!("Tool open file requires progress.")) + } } fn assistant_create_file( id: Option, - content: String, progress: Option, ) -> Result { if let Some(progress) = progress { @@ -528,7 +584,6 @@ fn assistant_create_file( fn assistant_modify_file( id: Option, - content: String, progress: Option, ) -> Result { if let Some(progress) = progress { @@ -550,6 +605,72 @@ fn assistant_modify_file( } } +fn tool_create_file( + id: Option, + content: String, + progress: Option, +) -> Result { + if let Some(progress) = progress { + let duration = 0.1; // File creation is typically very fast + + let buster_thought = BusterThreadMessage::Thought(BusterThought { + id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), + thought_type: "thought".to_string(), + thought_title: "Created new file".to_string(), + thought_secondary_title: format!("{} seconds", duration), + thought_pills: Some(vec![BusterThoughtPillContainer { + title: "Created".to_string(), + thought_pills: vec![BusterThoughtPill { + id: Uuid::new_v4().to_string(), + text: content, + thought_file_type: "file".to_string(), + }], + }]), + status: "completed".to_string(), + }); + + match progress { + MessageProgress::Complete => Ok(buster_thought), + _ => Err(anyhow::anyhow!("Tool create file only supports complete.")), + } + } else { + Err(anyhow::anyhow!("Tool create file requires progress.")) + } +} + +fn tool_modify_file( + id: Option, + content: String, + progress: Option, +) -> Result { + if let Some(progress) = progress { + let duration = 0.1; // File modification is typically very fast + + let buster_thought = BusterThreadMessage::Thought(BusterThought { + id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), + thought_type: "thought".to_string(), + thought_title: "Modified file".to_string(), + thought_secondary_title: format!("{} seconds", duration), + thought_pills: Some(vec![BusterThoughtPillContainer { + title: "Modified".to_string(), + thought_pills: vec![BusterThoughtPill { + id: Uuid::new_v4().to_string(), + text: content, + thought_file_type: "file".to_string(), + }], + }]), + status: "completed".to_string(), + }); + + match progress { + MessageProgress::Complete => Ok(buster_thought), + _ => Err(anyhow::anyhow!("Tool modify file only supports complete.")), + } + } else { + Err(anyhow::anyhow!("Tool modify file requires progress.")) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/api/src/utils/tools/file_tools/open_files.rs b/api/src/utils/tools/file_tools/open_files.rs index 60c1fe163..b23a92215 100644 --- a/api/src/utils/tools/file_tools/open_files.rs +++ b/api/src/utils/tools/file_tools/open_files.rs @@ -35,11 +35,11 @@ struct OpenFilesParams { files: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct OpenFilesOutput { - message: String, - duration: i64, - results: Vec, + pub message: String, + pub duration: i64, + pub results: Vec, } pub struct OpenFilesTool; diff --git a/api/src/utils/tools/file_tools/search_data_catalog.rs b/api/src/utils/tools/file_tools/search_data_catalog.rs index 971f6242a..e2f0c7b77 100644 --- a/api/src/utils/tools/file_tools/search_data_catalog.rs +++ b/api/src/utils/tools/file_tools/search_data_catalog.rs @@ -34,7 +34,7 @@ pub struct SearchDataCatalogOutput { } #[derive(Debug, Serialize, Deserialize)] -struct DatasetSearchResult { +pub struct DatasetSearchResult { pub id: Uuid, pub name: String, pub yml_content: String,