From cc1ca5a34f8e4e31c25bb047eb7dcb100e31b9b7 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 11 Apr 2025 10:52:41 -0600 Subject: [PATCH] good so far --- .../agents/src/agents/buster_multi_agent.rs | 21 +++++++++++++++ .../file_tools/search_data_catalog.rs | 12 ++++++--- .../handlers/src/chats/post_chat_handler.rs | 27 +++++++++---------- 3 files changed, 43 insertions(+), 17 deletions(-) diff --git a/api/libs/agents/src/agents/buster_multi_agent.rs b/api/libs/agents/src/agents/buster_multi_agent.rs index dff0a585e..796683544 100644 --- a/api/libs/agents/src/agents/buster_multi_agent.rs +++ b/api/libs/agents/src/agents/buster_multi_agent.rs @@ -269,6 +269,13 @@ impl BusterMultiAgent { &self, thread: &mut AgentThread, ) -> Result>> { + self.get_agent() + .set_state_value( + "user_prompt".to_string(), + Value::String(self.get_latest_user_message(thread).unwrap_or_default()), + ) + .await; + // Start processing (prompt is handled dynamically within process_thread_with_depth) let rx = self.stream_process_thread(thread).await?; @@ -279,6 +286,20 @@ impl BusterMultiAgent { pub async fn shutdown(&self) -> Result<()> { self.get_agent().shutdown().await } + + /// Gets the most recent user message from the agent thread + /// + /// This function extracts the latest message with role "user" from the thread's messages. + /// Returns None if no user messages are found. + pub fn get_latest_user_message(&self, thread: &AgentThread) -> Option { + // Iterate through messages in reverse order to find the most recent user message + for message in thread.messages.iter().rev() { + if let AgentMessage::User { content, .. } = message { + return Some(content.clone()); + } + } + None + } } const INTIALIZATION_PROMPT: &str = r##"### Role & Task diff --git a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs index c6036458c..6217f65c7 100644 --- a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs +++ b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs @@ -15,7 +15,7 @@ use diesel_async::RunQueryDsl; use futures::stream::{self, StreamExt}; use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat}; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::Value; use tracing::{debug, error, info, warn}; use uuid::Uuid; @@ -237,7 +237,12 @@ impl ToolExecutor for SearchDataCatalogTool { } }; - match filter_datasets_with_llm(¤t_query, ranked, &user_id_clone, &session_id_clone).await { + let user_prompt = match self.agent.get_state_value("user_prompt").await { + Some(Value::String(prompt)) => prompt, + _ => current_query.clone(), + }; + + match filter_datasets_with_llm(¤t_query, &user_prompt, ranked, &user_id_clone, &session_id_clone).await { Ok(filtered) => Ok(filtered), Err(e) => { error!(error = %e, query = current_query, "LLM filtering failed for query"); @@ -404,6 +409,7 @@ async fn rerank_datasets( async fn filter_datasets_with_llm( query: &str, + user_prompt: &str, ranked_datasets: Vec, user_id: &Uuid, session_id: &Uuid, @@ -431,7 +437,7 @@ async fn filter_datasets_with_llm( .collect::>(); let prompt = LLM_FILTER_PROMPT - .replace("{user_request}", query) + .replace("{user_request}", user_prompt) .replace("{query}", query) .replace( "{datasets_json}", diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index cadddd018..0ee8eaa4e 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -757,13 +757,13 @@ pub async fn post_chat_handler( // Format the final reasoning duration let formatted_final_reasoning_duration = if final_duration.as_secs() < 60 { - format!("Completed in {} seconds", final_duration.as_secs()) + format!("Reasoned for {} seconds", final_duration.as_secs()) } else { let minutes = final_duration.as_secs() / 60; if minutes == 1 { - "Completed in 1 minute".to_string() // Singular minute + "Reasoned for 1 minute".to_string() // Singular minute } else { - format!("Completed in {} min", minutes) // Plural minutes (abbreviated) + format!("Reasoned for {} min", minutes) // Plural minutes (abbreviated) } }; @@ -1666,7 +1666,7 @@ fn tool_modify_metrics(id: String, content: String, delta_duration: Duration) -> let buster_file = BusterReasoningMessage::File(BusterReasoningFile { id, message_type: "files".to_string(), - title: format!("Modified {} metric files", files_count), + title: if files_count == 1 { "Modified 1 metric file".to_string() } else { format!("Modified {} metric files", files_count) }, secondary_title: format!("{} seconds", delta_duration.as_secs()), // Use delta_duration status: "completed".to_string(), file_ids, @@ -1795,7 +1795,7 @@ fn tool_modify_dashboards(id: String, content: String, delta_duration: Duration) let buster_file = BusterReasoningMessage::File(BusterReasoningFile { id, message_type: "files".to_string(), - title: format!("Modified {} dashboard files", files_count), + title: format!("Modified {} dashboard file{}", files_count, if files_count == 1 { "" } else { "s" }), secondary_title: format!("{} seconds", delta_duration.as_secs()), // Use delta_duration status: "completed".to_string(), file_ids, @@ -1827,7 +1827,6 @@ fn tool_data_catalog_search(id: String, content: String, delta_duration: Duratio // Remove internal duration calculation // let duration = (data_catalog_result.duration as f64 / 1000.0 * 10.0).round() / 10.0; let result_count = data_catalog_result.results.len(); - let input_queries = data_catalog_result.queries.join(", "); // Join queries for display let thought_pill_containers = match proccess_data_catalog_search_results(data_catalog_result) { Ok(containers) => containers, @@ -1841,7 +1840,7 @@ fn tool_data_catalog_search(id: String, content: String, delta_duration: Duratio BusterReasoningMessage::Pill(BusterReasoningPill { id: id.clone(), thought_type: "pills".to_string(), - title: "Data Catalog Search Results".to_string(), // Updated title + title: format!("{} data catalog items found", result_count).to_string(), secondary_title: format!("{} seconds", delta_duration.as_secs()), pill_containers: Some(thought_pill_containers), status: "completed".to_string(), @@ -1924,8 +1923,8 @@ fn transform_assistant_tool_message( let generating_response_msg = BusterReasoningMessage::Text(BusterReasoningText { id: Uuid::new_v4().to_string(), // Unique ID for this message reasoning_type: "text".to_string(), - title: "Generating final response...".to_string(), - secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta for *this* message + title: "Finished reasoning".to_string(), + secondary_title: "".to_string(), // Use Delta for *this* message message: None, message_chunk: None, status: Some("completed".to_string()), @@ -2003,8 +2002,8 @@ fn transform_assistant_tool_message( all_results.push(ToolTransformResult::Reasoning(BusterReasoningMessage::Text(BusterReasoningText { id: tool_id.clone(), reasoning_type: "text".to_string(), - title: "Creating Plan".to_string(), - secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta + title: "Creating Plan...".to_string(), + secondary_title: "".to_string(), // Use Delta message: None, message_chunk: Some(delta), status: Some("loading".to_string()), @@ -2019,7 +2018,7 @@ fn transform_assistant_tool_message( all_results.push(ToolTransformResult::Reasoning(BusterReasoningMessage::Text(BusterReasoningText { id: tool_id.clone(), reasoning_type: "text".to_string(), - title: "Creating Plan".to_string(), + title: "Created a plan".to_string(), secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta message: Some(final_text), // Final text message_chunk: None, @@ -2041,7 +2040,7 @@ fn transform_assistant_tool_message( id: tool_id.clone(), reasoning_type: "text".to_string(), title: "Searching your data catalog...".to_string(), - secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta + secondary_title: "".to_string(), message: None, message_chunk: None, status: Some("loading".to_string()), @@ -2711,7 +2710,7 @@ fn generate_file_response_values(filtered_files: &[CompletedFileInfo]) -> Vec