good so far

This commit is contained in:
dal 2025-04-11 10:52:41 -06:00
parent 96a5e54354
commit cc1ca5a34f
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 43 additions and 17 deletions

View File

@ -269,6 +269,13 @@ impl BusterMultiAgent {
&self,
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
self.get_agent()
.set_state_value(
"user_prompt".to_string(),
Value::String(self.get_latest_user_message(thread).unwrap_or_default()),
)
.await;
// Start processing (prompt is handled dynamically within process_thread_with_depth)
let rx = self.stream_process_thread(thread).await?;
@ -279,6 +286,20 @@ impl BusterMultiAgent {
pub async fn shutdown(&self) -> Result<()> {
self.get_agent().shutdown().await
}
/// Gets the most recent user message from the agent thread
///
/// This function extracts the latest message with role "user" from the thread's messages.
/// Returns None if no user messages are found.
pub fn get_latest_user_message(&self, thread: &AgentThread) -> Option<String> {
// Iterate through messages in reverse order to find the most recent user message
for message in thread.messages.iter().rev() {
if let AgentMessage::User { content, .. } = message {
return Some(content.clone());
}
}
None
}
}
const INTIALIZATION_PROMPT: &str = r##"### Role & Task

View File

@ -15,7 +15,7 @@ use diesel_async::RunQueryDsl;
use futures::stream::{self, StreamExt};
use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use serde_json::Value;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
@ -237,7 +237,12 @@ impl ToolExecutor for SearchDataCatalogTool {
}
};
match filter_datasets_with_llm(&current_query, ranked, &user_id_clone, &session_id_clone).await {
let user_prompt = match self.agent.get_state_value("user_prompt").await {
Some(Value::String(prompt)) => prompt,
_ => current_query.clone(),
};
match filter_datasets_with_llm(&current_query, &user_prompt, ranked, &user_id_clone, &session_id_clone).await {
Ok(filtered) => Ok(filtered),
Err(e) => {
error!(error = %e, query = current_query, "LLM filtering failed for query");
@ -404,6 +409,7 @@ async fn rerank_datasets(
async fn filter_datasets_with_llm(
query: &str,
user_prompt: &str,
ranked_datasets: Vec<RankedDataset>,
user_id: &Uuid,
session_id: &Uuid,
@ -431,7 +437,7 @@ async fn filter_datasets_with_llm(
.collect::<Vec<_>>();
let prompt = LLM_FILTER_PROMPT
.replace("{user_request}", query)
.replace("{user_request}", user_prompt)
.replace("{query}", query)
.replace(
"{datasets_json}",

View File

@ -757,13 +757,13 @@ pub async fn post_chat_handler(
// Format the final reasoning duration
let formatted_final_reasoning_duration = if final_duration.as_secs() < 60 {
format!("Completed in {} seconds", final_duration.as_secs())
format!("Reasoned for {} seconds", final_duration.as_secs())
} else {
let minutes = final_duration.as_secs() / 60;
if minutes == 1 {
"Completed in 1 minute".to_string() // Singular minute
"Reasoned for 1 minute".to_string() // Singular minute
} else {
format!("Completed in {} min", minutes) // Plural minutes (abbreviated)
format!("Reasoned for {} min", minutes) // Plural minutes (abbreviated)
}
};
@ -1666,7 +1666,7 @@ fn tool_modify_metrics(id: String, content: String, delta_duration: Duration) ->
let buster_file = BusterReasoningMessage::File(BusterReasoningFile {
id,
message_type: "files".to_string(),
title: format!("Modified {} metric files", files_count),
title: if files_count == 1 { "Modified 1 metric file".to_string() } else { format!("Modified {} metric files", files_count) },
secondary_title: format!("{} seconds", delta_duration.as_secs()), // Use delta_duration
status: "completed".to_string(),
file_ids,
@ -1795,7 +1795,7 @@ fn tool_modify_dashboards(id: String, content: String, delta_duration: Duration)
let buster_file = BusterReasoningMessage::File(BusterReasoningFile {
id,
message_type: "files".to_string(),
title: format!("Modified {} dashboard files", files_count),
title: format!("Modified {} dashboard file{}", files_count, if files_count == 1 { "" } else { "s" }),
secondary_title: format!("{} seconds", delta_duration.as_secs()), // Use delta_duration
status: "completed".to_string(),
file_ids,
@ -1827,7 +1827,6 @@ fn tool_data_catalog_search(id: String, content: String, delta_duration: Duratio
// Remove internal duration calculation
// let duration = (data_catalog_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = data_catalog_result.results.len();
let input_queries = data_catalog_result.queries.join(", "); // Join queries for display
let thought_pill_containers = match proccess_data_catalog_search_results(data_catalog_result) {
Ok(containers) => containers,
@ -1841,7 +1840,7 @@ fn tool_data_catalog_search(id: String, content: String, delta_duration: Duratio
BusterReasoningMessage::Pill(BusterReasoningPill {
id: id.clone(),
thought_type: "pills".to_string(),
title: "Data Catalog Search Results".to_string(), // Updated title
title: format!("{} data catalog items found", result_count).to_string(),
secondary_title: format!("{} seconds", delta_duration.as_secs()),
pill_containers: Some(thought_pill_containers),
status: "completed".to_string(),
@ -1924,8 +1923,8 @@ fn transform_assistant_tool_message(
let generating_response_msg = BusterReasoningMessage::Text(BusterReasoningText {
id: Uuid::new_v4().to_string(), // Unique ID for this message
reasoning_type: "text".to_string(),
title: "Generating final response...".to_string(),
secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta for *this* message
title: "Finished reasoning".to_string(),
secondary_title: "".to_string(), // Use Delta for *this* message
message: None,
message_chunk: None,
status: Some("completed".to_string()),
@ -2003,8 +2002,8 @@ fn transform_assistant_tool_message(
all_results.push(ToolTransformResult::Reasoning(BusterReasoningMessage::Text(BusterReasoningText {
id: tool_id.clone(),
reasoning_type: "text".to_string(),
title: "Creating Plan".to_string(),
secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta
title: "Creating Plan...".to_string(),
secondary_title: "".to_string(), // Use Delta
message: None,
message_chunk: Some(delta),
status: Some("loading".to_string()),
@ -2019,7 +2018,7 @@ fn transform_assistant_tool_message(
all_results.push(ToolTransformResult::Reasoning(BusterReasoningMessage::Text(BusterReasoningText {
id: tool_id.clone(),
reasoning_type: "text".to_string(),
title: "Creating Plan".to_string(),
title: "Created a plan".to_string(),
secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta
message: Some(final_text), // Final text
message_chunk: None,
@ -2041,7 +2040,7 @@ fn transform_assistant_tool_message(
id: tool_id.clone(),
reasoning_type: "text".to_string(),
title: "Searching your data catalog...".to_string(),
secondary_title: format!("{} seconds", last_reasoning_completion_time.elapsed().as_secs()), // Use Delta
secondary_title: "".to_string(),
message: None,
message_chunk: None,
status: Some("loading".to_string()),
@ -2711,7 +2710,7 @@ fn generate_file_response_values(filtered_files: &[CompletedFileInfo]) -> Vec<Va
filter_version_id: None,
metadata: Some(vec![BusterChatResponseFileMetadata {
status: "completed".to_string(),
message: "Generated by Buster".to_string(),
message: format!("Created new {} file.", file_info.file_type),
timestamp: Some(Utc::now().timestamp()),
}]),
};