From c691f904f935823fee8e340c58efda222d6f8d6e Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 11 Apr 2025 13:33:56 -0600 Subject: [PATCH] things feeling pretty good. --- api/libs/agents/src/agent.rs | 521 ++++++++++-------- .../agents/src/agents/buster_multi_agent.rs | 34 +- .../src/tools/categories/file_tools/common.rs | 22 +- api/libs/database/src/helpers/datasets.rs | 22 + api/libs/database/src/helpers/mod.rs | 3 +- .../handlers/src/chats/post_chat_handler.rs | 44 +- 6 files changed, 381 insertions(+), 265 deletions(-) create mode 100644 api/libs/database/src/helpers/datasets.rs diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 35f367d83..6d44fd0e0 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -7,11 +7,11 @@ use litellm::{ }; use once_cell::sync::Lazy; use serde_json::Value; +use std::time::{Duration, Instant}; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::{broadcast, RwLock}; use tracing::error; use uuid::Uuid; -use std::time::{Duration, Instant}; // Type definition for tool registry to simplify complex type // No longer needed, defined below @@ -19,16 +19,17 @@ use crate::models::AgentThread; // Global BraintrustClient instance static BRAINTRUST_CLIENT: Lazy>> = Lazy::new(|| { - match (std::env::var("BRAINTRUST_API_KEY"), std::env::var("BRAINTRUST_LOGGING_ID")) { - (Ok(_), Ok(buster_logging_id)) => { - match BraintrustClient::new(None, &buster_logging_id) { - Ok(client) => Some(client), - Err(e) => { - eprintln!("Failed to create Braintrust client: {}", e); - None - } + match ( + std::env::var("BRAINTRUST_API_KEY"), + std::env::var("BRAINTRUST_LOGGING_ID"), + ) { + (Ok(_), Ok(buster_logging_id)) => match BraintrustClient::new(None, &buster_logging_id) { + Ok(client) => Some(client), + Err(e) => { + eprintln!("Failed to create Braintrust client: {}", e); + None } - } + }, _ => None, } }); @@ -55,7 +56,6 @@ struct MessageBuffer { first_message_sent: bool, } - impl MessageBuffer { fn new() -> Self { Self { @@ -101,7 +101,11 @@ impl MessageBuffer { // Create and send the message let message = AgentMessage::assistant( self.message_id.clone(), - if self.content.is_empty() { None } else { Some(self.content.clone()) }, + if self.content.is_empty() { + None + } else { + Some(self.content.clone()) + }, tool_calls, MessageProgress::InProgress, Some(!self.first_message_sent), @@ -124,7 +128,6 @@ impl MessageBuffer { } } - // Helper struct to store the tool and its enablement condition struct RegisteredTool { executor: Box + Send + Sync>, @@ -141,7 +144,6 @@ struct DynamicPromptRule { // Update the ToolRegistry type alias is no longer needed, but we need the new type for the map type ToolsMap = Arc>>; - #[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. It maintains a registry of available tools @@ -211,11 +213,7 @@ impl Agent { } /// Create a new Agent that shares state and stream with an existing agent - pub fn from_existing( - existing_agent: &Agent, - name: String, - default_prompt: String, - ) -> Self { + pub fn from_existing(existing_agent: &Agent, name: String, default_prompt: String) -> Self { let llm_api_key = env::var("LLM_API_KEY").ok(); // Use ok() instead of expect let llm_base_url = env::var("LLM_BASE_URL").ok(); // Use ok() instead of expect @@ -308,7 +306,6 @@ impl Agent { } // --- End Helper state functions --- - /// Get the current thread being processed, if any pub async fn get_current_thread(&self) -> Option { self.current_thread.read().await.clone() @@ -368,7 +365,8 @@ impl Agent { let registered_tool = RegisteredTool { executor: Box::new(value_tool), // Box the closure only if it's Some - enablement_condition: enablement_condition.map(|f| Box::new(f) as Box) -> bool + Send + Sync>), + enablement_condition: enablement_condition + .map(|f| Box::new(f) as Box) -> bool + Send + Sync>), }; tools.insert(name, registered_tool); } @@ -389,13 +387,14 @@ impl Agent { let value_tool = tool.into_tool_call_executor(); let registered_tool = RegisteredTool { executor: Box::new(value_tool), - enablement_condition: condition.map(|f| Box::new(f) as Box) -> bool + Send + Sync>), + enablement_condition: condition.map(|f| { + Box::new(f) as Box) -> bool + Send + Sync> + }), }; tools_map.insert(name, registered_tool); } } - /// Process a thread of conversation, potentially executing tools and continuing /// the conversation recursively until a final response is reached. /// @@ -413,9 +412,9 @@ impl Agent { let mut final_message = None; while let Ok(msg) = rx.recv().await { match msg { - Ok(AgentMessage::Done) => break, // Stop collecting on Done message + Ok(AgentMessage::Done) => break, // Stop collecting on Done message Ok(m) => final_message = Some(m), // Store the latest non-Done message - Err(e) => return Err(e.into()), // Propagate errors + Err(e) => return Err(e.into()), // Propagate errors } } @@ -500,7 +499,9 @@ impl Agent { let (trace_builder, parent_span) = if trace_builder.is_none() && parent_span.is_none() { if let Some(client) = &*BRAINTRUST_CLIENT { // Find the most recent user message to use as our input content - let user_input_message = thread.messages.iter() + let user_input_message = thread + .messages + .iter() .filter(|msg| matches!(msg, AgentMessage::User { .. })) .last() .cloned(); @@ -525,7 +526,10 @@ impl Agent { // Add the user prompt text (not the full message) as input to the root span // Ensure we're passing ONLY the content text, not the full message object - let root_span = trace.root_span().clone().with_input(serde_json::json!(user_prompt_text)); + let root_span = trace + .root_span() + .clone() + .with_input(serde_json::json!(user_prompt_text)); // Add chat_id (session_id) as metadata to the root span let span = root_span.with_metadata("chat_id", self.session_id.to_string()); @@ -554,30 +558,37 @@ impl Agent { Some(self.name.clone()), ); if let Err(e) = self.get_stream_sender().await.send(Ok(message)) { - tracing::warn!("Channel send error when sending recursion limit message: {}", e); + tracing::warn!( + "Channel send error when sending recursion limit message: {}", + e + ); } self.close().await; // Ensure stream is closed return Ok(()); // Don't return error, just stop processing } - // --- Dynamic Prompt Selection --- + // --- Dynamic Prompt Selection --- let current_system_prompt = self.get_current_prompt().await; let system_message = AgentMessage::developer(current_system_prompt); // Prepare messages for LLM: Inject current system prompt and filter out old ones let mut llm_messages = vec![system_message]; llm_messages.extend( - thread.messages.iter() + thread + .messages + .iter() .filter(|msg| !matches!(msg, AgentMessage::Developer { .. })) - .cloned() + .cloned(), ); - // --- End Dynamic Prompt Selection --- + // --- End Dynamic Prompt Selection --- // Collect all enabled tools and their schemas let tools = self.get_enabled_tools().await; // Now uses the new logic // Get the most recent user message for logging (used only in error logging) - let _user_message = thread.messages.last() + let _user_message = thread + .messages + .last() .filter(|msg| matches!(msg, AgentMessage::User { .. })) .cloned(); @@ -594,11 +605,16 @@ impl Agent { session_id: thread.id.to_string(), trace_id: Uuid::new_v4().to_string(), }), + // reasoning_effort: Some("high".to_string()), ..Default::default() }; // Get the streaming response from the LLM - let mut stream_rx = match self.llm_client.stream_chat_completion(request.clone()).await { + let mut stream_rx = match self + .llm_client + .stream_chat_completion(request.clone()) + .await + { Ok(rx) => rx, Err(e) => { // Log error in span @@ -616,7 +632,7 @@ impl Agent { } let error_message = format!("Error starting stream: {:?}", e); return Err(anyhow::anyhow!(error_message)); - }, + } }; // We store the parent span to use for creating individual tool spans @@ -646,16 +662,16 @@ impl Agent { if let Some(tool_calls) = &delta.tool_calls { for tool_call in tool_calls { let id = tool_call.id.clone().unwrap_or_else(|| { - buffer.tool_calls + buffer + .tool_calls .keys() - .next().cloned() + .next() + .cloned() .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) }); // Get or create the pending tool call - let pending_call = buffer.tool_calls - .entry(id.clone()) - .or_default(); + let pending_call = buffer.tool_calls.entry(id.clone()).or_default(); // Update the pending call with the delta pending_call.update_from_delta(tool_call); @@ -692,18 +708,18 @@ impl Agent { } let error_message = format!("Error in stream: {:?}", e); return Err(anyhow::anyhow!(error_message)); - }, + } } } // Flush any remaining buffered content or tool calls before creating final message buffer.flush(self).await?; - // Create and send the final message let final_tool_calls: Option> = if !buffer.tool_calls.is_empty() { Some( - buffer.tool_calls + buffer + .tool_calls .values() .map(|p| p.clone().into_tool_call()) .collect(), @@ -714,7 +730,11 @@ impl Agent { let final_message = AgentMessage::assistant( buffer.message_id, - if buffer.content.is_empty() { None } else { Some(buffer.content) }, + if buffer.content.is_empty() { + None + } else { + Some(buffer.content) + }, final_tool_calls.clone(), MessageProgress::Complete, Some(false), // Never the first message at this stage @@ -723,69 +743,34 @@ impl Agent { // Broadcast the final assistant message // Ensure we don't block if the receiver dropped - if let Err(e) = self.get_stream_sender().await.send(Ok(final_message.clone())) { - tracing::debug!("Failed to send final assistant message (receiver likely dropped): {}", e); + if let Err(e) = self + .get_stream_sender() + .await + .send(Ok(final_message.clone())) + { + tracing::debug!( + "Failed to send final assistant message (receiver likely dropped): {}", + e + ); } - // Update thread with assistant message self.update_current_thread(final_message.clone()).await?; - // For a message without tool calls, create and log a new complete message span - // Otherwise, tool spans will be created individually for each tool call - if final_tool_calls.is_none() && trace_builder.is_some() { - if let (Some(trace), Some(parent)) = (&trace_builder, &parent_span) { - if let Some(client) = &*BRAINTRUST_CLIENT { - // Ensure we have the complete message content - // Make sure we clone the final message to avoid mutating it - let complete_final_message = final_message.clone(); + // Get the updated thread state AFTER adding the final assistant message + // This will be used for the potential recursive call later. + let mut updated_thread_for_recursion = self + .current_thread + .read() + .await + .as_ref() + .cloned() + .ok_or_else(|| { + anyhow::anyhow!("Failed to get updated thread state after adding assistant message") + })?; - // Create a fresh span for the text-only response - let span = trace.add_child_span("Assistant Response", "llm", parent).await?; - - // Add chat_id (session_id) as metadata to the span - let span = span.with_metadata("chat_id", self.session_id.to_string()); - - // Add the full request/response information - let span = span.with_input(serde_json::to_value(&request)?); - let span = span.with_output(serde_json::to_value(&complete_final_message)?); - - // Log span non-blockingly (client handles the background processing) - if let Err(log_err) = client.log_span(span).await { - error!("Failed to log assistant response span: {}", log_err); - } - } - } - } - // For messages with tool calls, we won't log the output here - // Instead, we'll create tool spans with this assistant span as parent - - // If this is an auto response without tool calls, it means we're done - if final_tool_calls.is_none() { - // Log the final output to the parent span - if let Some(parent_span) = &parent_span { - if let Some(client) = &*BRAINTRUST_CLIENT { - // Create a new span with the final message as output - let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?); - - // Log span non-blockingly (client handles the background processing) - if let Err(log_err) = client.log_span(final_span).await { - error!("Failed to log final output span: {}", log_err); - } - } - } - - // Finish the trace without consuming it - self.finish_trace(&trace_builder).await?; - - // // Send Done message and return - Done message is now sent by the caller task - // self.get_stream_sender() - // .await - // .send(Ok(AgentMessage::Done))?; - return Ok(()); - } - - // If the LLM wants to use tools, execute them and continue + // --- Tool Execution Logic --- + // If the LLM wants to use tools, execute them if let Some(tool_calls) = final_tool_calls { let mut results = Vec::new(); let agent_tools = self.tools.read().await; // Read tools once @@ -794,23 +779,28 @@ impl Agent { // Execute each requested tool let mut should_terminate = false; // Flag to indicate if loop should terminate after this tool for tool_call in tool_calls { - // Find the registered tool entry + // Find the registered tool entry if let Some(registered_tool) = agent_tools.get(&tool_call.function.name) { // Create a tool span that combines the assistant request with the tool execution - let tool_span = if let (Some(trace), Some(parent)) = (&trace_builder, &parent_for_tool_spans) { + let tool_span = if let (Some(trace), Some(parent)) = + (&trace_builder, &parent_for_tool_spans) + { if let Some(_client) = &*BRAINTRUST_CLIENT { // Create a span for the assistant + tool execution - let span = trace.add_child_span( - &format!("Assistant: {}", tool_call.function.name), - "tool", - parent - ).await?; + let span = trace + .add_child_span( + &format!("Assistant: {}", tool_call.function.name), + "tool", + parent, + ) + .await?; // Add chat_id (session_id) as metadata to the span let span = span.with_metadata("chat_id", self.session_id.to_string()); // Parse the parameters (unused in this context since we're using final_message) - let _params: Value = serde_json::from_str(&tool_call.function.arguments)?; + let _params: Value = + serde_json::from_str(&tool_call.function.arguments)?; // Use the assistant message as input to this span // This connects the assistant's request to the tool execution @@ -829,13 +819,16 @@ impl Agent { // Parse the parameters let params: Value = match serde_json::from_str(&tool_call.function.arguments) { - Ok(p) => p, - Err(e) => { - let err_msg = format!("Failed to parse tool arguments for {}: {}", tool_call.function.name, e); - error!("{}", err_msg); - // Optionally log to Braintrust span here - return Err(anyhow::anyhow!(err_msg)); - } + Ok(p) => p, + Err(e) => { + let err_msg = format!( + "Failed to parse tool arguments for {}: {}", + tool_call.function.name, e + ); + error!("{}", err_msg); + // Optionally log to Braintrust span here + return Err(anyhow::anyhow!(err_msg)); + } }; let _tool_input = serde_json::json!({ @@ -847,7 +840,11 @@ impl Agent { }); // Execute the tool using the executor from RegisteredTool - let result = match registered_tool.executor.execute(params, tool_call.id.clone()).await { + let result = match registered_tool + .executor + .execute(params, tool_call.id.clone()) + .await + { Ok(r) => r, Err(e) => { // Log error in tool span @@ -862,11 +859,17 @@ impl Agent { // Log span non-blockingly (client handles the background processing) if let Err(log_err) = client.log_span(error_span).await { - error!("Failed to log tool execution error span: {}", log_err); + error!( + "Failed to log tool execution error span: {}", + log_err + ); } } } - let error_message = format!("Tool execution error for {}: {:?}", tool_call.function.name, e); + let error_message = format!( + "Tool execution error for {}: {:?}", + tool_call.function.name, e + ); error!("{}", error_message); // Log locally return Err(anyhow::anyhow!(error_message)); } @@ -885,10 +888,18 @@ impl Agent { if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { // Only log completed messages - if matches!(tool_message, AgentMessage::Tool { progress: MessageProgress::Complete, .. }) { + if matches!( + tool_message, + AgentMessage::Tool { + progress: MessageProgress::Complete, + .. + } + ) { // Now that we have the tool result, add it as output and log the span // This creates a span showing assistant message -> tool execution -> tool result - let result_span = tool_span.clone().with_output(serde_json::to_value(&tool_message)?); + let result_span = tool_span + .clone() + .with_output(serde_json::to_value(&tool_message)?); // Log span non-blockingly (client handles the background processing) if let Err(log_err) = client.log_span(result_span).await { @@ -899,10 +910,16 @@ impl Agent { } // Broadcast the tool message as soon as we receive it - use try_send to avoid blocking - if let Err(e) = self.get_stream_sender().await.send(Ok(tool_message.clone())) { - tracing::debug!("Failed to send tool message (receiver likely dropped): {}", e); - } - + if let Err(e) = self + .get_stream_sender() + .await + .send(Ok(tool_message.clone())) + { + tracing::debug!( + "Failed to send tool message (receiver likely dropped): {}", + e + ); + } // Update thread with tool response BEFORE checking termination self.update_current_thread(tool_message.clone()).await?; @@ -911,76 +928,103 @@ impl Agent { // Check if this tool's name is in the terminating list if terminating_names.contains(&tool_call.function.name) { should_terminate = true; - tracing::info!("Tool '{}' triggered agent termination.", tool_call.function.name); + tracing::info!( + "Tool '{}' triggered agent termination.", + tool_call.function.name + ); break; // Exit the tool execution loop } } else { - // Handle case where the LLM hallucinated a tool name - let err_msg = format!("Attempted to call non-existent tool: {}", tool_call.function.name); - error!("{}", err_msg); - // Create a fake tool result indicating the error - let error_result = AgentMessage::tool( - None, - serde_json::json!({"error": err_msg}).to_string(), - tool_call.id.clone(), - Some(tool_call.function.name.clone()), - MessageProgress::Complete, - ); - // Broadcast the error message - if let Err(e) = self.get_stream_sender().await.send(Ok(error_result.clone())) { - tracing::debug!("Failed to send tool error message (receiver likely dropped): {}", e); - } - // Update thread and push the error result for the next LLM call - self.update_current_thread(error_result.clone()).await?; + // Handle case where the LLM hallucinated a tool name + let err_msg = format!( + "Attempted to call non-existent tool: {}", + tool_call.function.name + ); + error!("{}", err_msg); + // Create a fake tool result indicating the error + let error_result = AgentMessage::tool( + None, + serde_json::json!({"error": err_msg}).to_string(), + tool_call.id.clone(), + Some(tool_call.function.name.clone()), + MessageProgress::Complete, + ); + // Broadcast the error message + if let Err(e) = self + .get_stream_sender() + .await + .send(Ok(error_result.clone())) + { + tracing::debug!( + "Failed to send tool error message (receiver likely dropped): {}", + e + ); + } + // Update thread and push the error result for the next LLM call + self.update_current_thread(error_result.clone()).await?; // Continue processing other tool calls if any } } - // If a tool signaled termination, send Done and finish. + // If a tool signaled termination, finish trace, send Done and exit. if should_terminate { - // Finish the trace without consuming it + // Finish the trace without consuming it self.finish_trace(&trace_builder).await?; // Send Done message if let Err(e) = self.get_stream_sender().await.send(Ok(AgentMessage::Done)) { - tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e); - } + tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e); + } return Ok(()); // Exit the function, preventing recursion } - - // Create a new thread with the tool results and continue recursively - let mut new_thread = thread.clone(); - // Add the assistant message that contained the tool_calls to ensure correct history order - new_thread.messages.push(final_message); // Add the assistant message - // The assistant message that requested the tools is already added above - new_thread.messages.extend(results); - - // For recursive calls, we'll continue with the same trace - // We don't finish the trace here to keep all interactions in one trace - Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace_builder, parent_span)).await + // Add the tool results to the thread state for the recursive call + updated_thread_for_recursion.messages.extend(results); } else { - // Log the final output to the parent span (This case should ideally not be reached if final_tool_calls was None earlier) + // Log the final assistant response span only if NO tools were called + if let (Some(trace), Some(parent)) = (&trace_builder, &parent_span) { + if let Some(client) = &*BRAINTRUST_CLIENT { + // Ensure we have the complete message content + let complete_final_message = final_message.clone(); + + // Create a fresh span for the text-only response + let span = trace + .add_child_span("Assistant Response", "llm", parent) + .await?; + let span = span.with_metadata("chat_id", self.session_id.to_string()); + let span = span.with_input(serde_json::to_value(&request)?); // Log the request + let span = span.with_output(serde_json::to_value(&complete_final_message)?); // Log the response + + // Log span non-blockingly + if let Err(log_err) = client.log_span(span).await { + error!("Failed to log assistant response span: {}", log_err); + } + } + } + + // Also log the final output to the parent span if no tools were called if let Some(parent_span) = &parent_span { if let Some(client) = &*BRAINTRUST_CLIENT { - // Create a new span with the final message as output - let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?); - - // Log span non-blockingly (client handles the background processing) + let final_span = parent_span + .clone() + .with_output(serde_json::to_value(&final_message)?); if let Err(log_err) = client.log_span(final_span).await { error!("Failed to log final output span: {}", log_err); } } } - - // Finish the trace without consuming it - self.finish_trace(&trace_builder).await?; - - // // Send Done message and return - Done message is now sent by the caller task - // self.get_stream_sender() - // .await - // .send(Ok(AgentMessage::Done))?; - Ok(()) + // --- End Logging for Text-Only Response --- } + + // Continue the conversation recursively using the updated thread state, + // unless a terminating tool caused an early return above. + // This call happens regardless of whether tools were executed in this step. + Box::pin(self.process_thread_with_depth( + &updated_thread_for_recursion, + recursion_depth + 1, + trace_builder, + parent_span, + )) + .await } /// Get a receiver for the shutdown signal @@ -996,11 +1040,12 @@ impl Agent { } /// Get a read lock on the tools map (Exposes RegisteredTool now) - pub async fn get_tools_map(&self) -> tokio::sync::RwLockReadGuard<'_, HashMap> { + pub async fn get_tools_map( + &self, + ) -> tokio::sync::RwLockReadGuard<'_, HashMap> { self.tools.read().await } - /// Helper method to finish a trace without consuming the TraceBuilder /// This method is fully non-blocking and never affects application performance async fn finish_trace(&self, trace: &Option) -> Result<()> { @@ -1017,12 +1062,14 @@ impl Agent { // Create and log a completion span non-blockingly if let Some(client) = &*BRAINTRUST_CLIENT { // Create a new span for completion linked to the trace - let completion_span = client.create_span( - "Trace Completion", - "completion", - Some(root_span_id), // Link to the trace's root span - Some(root_span_id) // Set parent to also be the root span - ).with_metadata("chat_id", self.session_id.to_string()); + let completion_span = client + .create_span( + "Trace Completion", + "completion", + Some(root_span_id), // Link to the trace's root span + Some(root_span_id), // Set parent to also be the root span + ) + .with_metadata("chat_id", self.session_id.to_string()); // Log span non-blockingly (client handles the background processing) if let Err(e) = client.log_span(completion_span).await { @@ -1043,11 +1090,7 @@ impl Agent { /// Add a rule for dynamically selecting a system prompt. /// Rules are checked in the order they are added. The first matching rule's prompt is used. - pub async fn add_dynamic_prompt_rule( - &self, - condition: F, - prompt: String, - ) + pub async fn add_dynamic_prompt_rule(&self, condition: F, prompt: String) where F: Fn(&HashMap) -> bool + Send + Sync + 'static, { @@ -1183,13 +1226,8 @@ mod tests { tool_id: String, progress: MessageProgress, ) -> Result<()> { - let message = AgentMessage::tool( - None, - content, - tool_id, - Some(self.get_name()), - progress, - ); + let message = + AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress); self.agent.get_stream_sender().await.send(Ok(message))?; Ok(()) } @@ -1200,7 +1238,11 @@ mod tests { type Output = Value; type Params = Value; - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute( + &self, + params: Self::Params, + tool_call_id: String, + ) -> Result { self.send_progress( "Fetching weather data...".to_string(), tool_call_id.clone(), // Use the actual tool_call_id @@ -1210,7 +1252,6 @@ mod tests { let _params = params.as_object().unwrap(); - // Simulate a delay tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -1285,7 +1326,7 @@ mod tests { Ok(response) => { println!("Response (no tools): {:?}", response); response - }, + } Err(e) => panic!("Error processing thread: {:?}", e), }; } @@ -1311,7 +1352,9 @@ mod tests { let condition = |_state: &HashMap| true; // Always enabled // Add tool to agent - agent.add_tool(tool_name, weather_tool, Some(condition)).await; + agent + .add_tool(tool_name, weather_tool, Some(condition)) + .await; let thread = AgentThread::new( None, @@ -1323,9 +1366,9 @@ mod tests { let _response = match agent.process_thread(&thread).await { Ok(response) => { - println!("Response (with tools): {:?}", response); - response - }, + println!("Response (with tools): {:?}", response); + response + } Err(e) => panic!("Error processing thread: {:?}", e), }; } @@ -1350,7 +1393,9 @@ mod tests { let tool_name = weather_tool.get_name(); let condition = |_state: &HashMap| true; // Always enabled - agent.add_tool(tool_name, weather_tool, Some(condition)).await; + agent + .add_tool(tool_name, weather_tool, Some(condition)) + .await; let thread = AgentThread::new( None, @@ -1360,16 +1405,16 @@ mod tests { )], ); - let _response = match agent.process_thread(&thread).await { + let _response = match agent.process_thread(&thread).await { Ok(response) => { - println!("Response (multi-step): {:?}", response); - response - }, + println!("Response (multi-step): {:?}", response); + response + } Err(e) => panic!("Error processing thread: {:?}", e), }; } - #[tokio::test] + #[tokio::test] async fn test_agent_disabled_tool() { setup(); @@ -1389,51 +1434,66 @@ mod tests { let tool_name = weather_tool.get_name(); // Condition: only enabled if "weather_enabled" state is true let condition = |state: &HashMap| -> bool { - state.get("weather_enabled").and_then(|v| v.as_bool()).unwrap_or(false) + state + .get("weather_enabled") + .and_then(|v| v.as_bool()) + .unwrap_or(false) }; // Add tool with the condition - agent.add_tool(tool_name, weather_tool, Some(condition)).await; + agent + .add_tool(tool_name, weather_tool, Some(condition)) + .await; // --- Test case 1: Tool disabled --- let thread_disabled = AgentThread::new( None, Uuid::new_v4(), - vec![AgentMessage::user("What is the weather in Provo?".to_string())], + vec![AgentMessage::user( + "What is the weather in Provo?".to_string(), + )], ); // Ensure state doesn't enable the tool - agent.set_state_value("weather_enabled".to_string(), json!(false)).await; + agent + .set_state_value("weather_enabled".to_string(), json!(false)) + .await; let response_disabled = match agent.process_thread(&thread_disabled).await { - Ok(response) => response, - Err(e) => panic!("Error processing thread (disabled): {:?}", e), - }; - // Expect response without tool call - if let AgentMessage::Assistant { tool_calls: Some(_), .. } = response_disabled { - panic!("Tool call occurred even when disabled"); - } - println!("Response (disabled tool): {:?}", response_disabled); - + Ok(response) => response, + Err(e) => panic!("Error processing thread (disabled): {:?}", e), + }; + // Expect response without tool call + if let AgentMessage::Assistant { + tool_calls: Some(_), + .. + } = response_disabled + { + panic!("Tool call occurred even when disabled"); + } + println!("Response (disabled tool): {:?}", response_disabled); // --- Test case 2: Tool enabled --- let thread_enabled = AgentThread::new( None, Uuid::new_v4(), - vec![AgentMessage::user("What is the weather in Orem?".to_string())], + vec![AgentMessage::user( + "What is the weather in Orem?".to_string(), + )], ); // Set state to enable the tool - agent.set_state_value("weather_enabled".to_string(), json!(true)).await; + agent + .set_state_value("weather_enabled".to_string(), json!(true)) + .await; let _response_enabled = match agent.process_thread(&thread_enabled).await { - Ok(response) => response, - Err(e) => panic!("Error processing thread (enabled): {:?}", e), - }; - // Expect response *with* tool call (or final answer after tool call) - // We can't easily check the intermediate step here, but the test should run without panic - println!("Response (enabled tool): {:?}", _response_enabled); + Ok(response) => response, + Err(e) => panic!("Error processing thread (enabled): {:?}", e), + }; + // Expect response *with* tool call (or final answer after tool call) + // We can't easily check the intermediate step here, but the test should run without panic + println!("Response (enabled tool): {:?}", _response_enabled); } - #[tokio::test] async fn test_agent_state_management() { setup(); @@ -1460,10 +1520,11 @@ mod tests { assert_eq!(agent.get_state_bool("test_key").await, None); // Not a bool // Test setting boolean value - agent.set_state_value("bool_key".to_string(), json!(true)).await; + agent + .set_state_value("bool_key".to_string(), json!(true)) + .await; assert_eq!(agent.get_state_bool("bool_key").await, Some(true)); - // Test updating multiple values agent .update_state(|state| { @@ -1486,4 +1547,4 @@ mod tests { assert!(!agent.state_key_exists("test_key").await); assert_eq!(agent.get_state_bool("bool_key").await, None); } -} \ No newline at end of file +} diff --git a/api/libs/agents/src/agents/buster_multi_agent.rs b/api/libs/agents/src/agents/buster_multi_agent.rs index 796683544..383be7b91 100644 --- a/api/libs/agents/src/agents/buster_multi_agent.rs +++ b/api/libs/agents/src/agents/buster_multi_agent.rs @@ -1,4 +1,6 @@ use anyhow::Result; +use database::helpers::datasets::get_dataset_names_for_organization; +use database::organization::get_user_organization_id; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; @@ -187,11 +189,19 @@ impl BusterMultiAgent { session_id: Uuid, is_follow_up: bool, // Add flag to determine initial prompt ) -> Result { + let organization_id = match get_user_organization_id(&user_id).await { + Ok(Some(org_id)) => org_id, + Ok(None) => return Err(anyhow::anyhow!("User does not belong to any organization")), + Err(e) => return Err(e), + }; + + let dataset_names = get_dataset_names_for_organization(organization_id).await?; + // Select initial default prompt based on whether it's a follow-up let initial_default_prompt = if is_follow_up { - FOLLOW_UP_INTIALIZATION_PROMPT.to_string() + FOLLOW_UP_INTIALIZATION_PROMPT.replace("{DATASETS}", &dataset_names.join(", ")) } else { - INTIALIZATION_PROMPT.to_string() + INTIALIZATION_PROMPT.replace("{DATASETS}", &dataset_names.join(", ")) }; // Create agent, passing the selected initialization prompt as default @@ -199,7 +209,7 @@ impl BusterMultiAgent { "o3-mini".to_string(), user_id, session_id, - "buster_super_agent".to_string(), + "buster_multi_agent".to_string(), None, None, initial_default_prompt, // Use selected default prompt @@ -212,8 +222,8 @@ impl BusterMultiAgent { } // Define prompt switching conditions - let needs_plan_condition = |state: &HashMap| -> bool { - state.contains_key("data_context") && !state.contains_key("plan_available") + let needs_plan_condition = move |state: &HashMap| -> bool { + state.contains_key("data_context") && !state.contains_key("plan_available") && !is_follow_up }; let needs_analysis_condition = |state: &HashMap| -> bool { // Example: Trigger analysis prompt once plan is available and metrics/dashboards are not yet available @@ -245,13 +255,11 @@ impl BusterMultiAgent { )); // Re-apply prompt rules for the new agent instance - let needs_plan_condition = |state: &HashMap| -> bool { + let needs_plan_condition = move |state: &HashMap| -> bool { state.contains_key("data_context") && !state.contains_key("plan_available") }; let needs_analysis_condition = |state: &HashMap| -> bool { - state.contains_key("plan_available") - && !state.contains_key("metrics_available") - && !state.contains_key("dashboards_available") + state.contains_key("data_context") && state.contains_key("plan_available") }; agent .add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()) @@ -724,6 +732,14 @@ Always use your best judgement when selecting visualization types, and be confid --- +### Available Datasets +Datasets include: +{DATASETS} + +**Reminder**: Always use `search_data_catalog` to confirm specific data points or columns within these datasets — do not assume availability. + +--- + ## Workflow Examples - **Fully Supported Workflow** diff --git a/api/libs/agents/src/tools/categories/file_tools/common.rs b/api/libs/agents/src/tools/categories/file_tools/common.rs index 37ded697b..1ba49084f 100644 --- a/api/libs/agents/src/tools/categories/file_tools/common.rs +++ b/api/libs/agents/src/tools/categories/file_tools/common.rs @@ -189,7 +189,19 @@ properties: ### sql: type: string - description: SQL query using YAML pipe syntax (|) + description: | + SQL query using YAML pipe syntax (|) + + The SQL query should be formatted with proper indentation using the YAML pipe (|) syntax. + This ensures the multi-line SQL is properly parsed while preserving whitespace and newlines. + + Example: + sql: | + SELECT + column1, + column2 + FROM table + WHERE condition # CHART CONFIGURATION chartConfig: @@ -540,7 +552,9 @@ pub const DASHBOARD_YML_SCHEMA: &str = r##" # items: # - id: metric-uuid-2 # - id: metric-uuid-3 -# columnSizes: [6, 6] # Required - must sum to exactly 12 +# columnSizes: +# - 6 +# - 6 # # Rules: # 1. Each row can have up to 4 items @@ -548,7 +562,9 @@ pub const DASHBOARD_YML_SCHEMA: &str = r##" # 3. columnSizes is required and must specify the width for each item # 4. Sum of columnSizes in a row must be exactly 12 # 5. Each column size must be at least 3 -# 6. All arrays should follow the YML array syntax using `-` not `[` and `]` +# 6. All arrays should follow the YML array syntax using `-` +# 7. All arrays should NOT USE `[]` formatting. +# 8. don't use comments. the ones in the example are just for explanation # ---------------------------------------- type: object diff --git a/api/libs/database/src/helpers/datasets.rs b/api/libs/database/src/helpers/datasets.rs new file mode 100644 index 000000000..69bf6887b --- /dev/null +++ b/api/libs/database/src/helpers/datasets.rs @@ -0,0 +1,22 @@ +use anyhow::Result; +use diesel::prelude::*; +use diesel_async::RunQueryDsl; +use uuid::Uuid; + +use crate::pool::get_pg_pool; + +pub async fn get_dataset_names_for_organization(org_id: Uuid) -> Result, anyhow::Error> { + use crate::schema::datasets::dsl::*; + + let mut conn = get_pg_pool().get().await?; + + let results = datasets + .filter(organization_id.eq(org_id)) + .filter(deleted_at.is_null()) + .filter(yml_file.is_not_null()) + .select(name) + .load::(&mut conn) + .await?; + + Ok(results) +} diff --git a/api/libs/database/src/helpers/mod.rs b/api/libs/database/src/helpers/mod.rs index 80e3328bb..373a58ab2 100644 --- a/api/libs/database/src/helpers/mod.rs +++ b/api/libs/database/src/helpers/mod.rs @@ -3,4 +3,5 @@ pub mod dashboard_files; pub mod metric_files; pub mod chats; pub mod organization; -pub mod test_utils; \ No newline at end of file +pub mod test_utils; +pub mod datasets; \ No newline at end of file diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index 0ee8eaa4e..98294e32b 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -1279,30 +1279,30 @@ pub async fn transform_message( .map(|container| (container, ThreadEvent::GeneratingResponseMessage)), ); - // Add the "Finished reasoning" message if we're just starting - if initial { - let reasoning_message = BusterReasoningMessage::Text(BusterReasoningText { - id: Uuid::new_v4().to_string(), - reasoning_type: "text".to_string(), - title: "Finished reasoning".to_string(), - // Use total duration from start for this initial message - secondary_title: format!("{} seconds", start_time.elapsed().as_secs()), - message: None, - message_chunk: None, - status: Some("completed".to_string()), - }); - // Reset the completion time after showing the initial reasoning message - *last_reasoning_completion_time = Instant::now(); + // // Add the "Finished reasoning" message if we're just starting + // if initial { + // let reasoning_message = BusterReasoningMessage::Text(BusterReasoningText { + // id: Uuid::new_v4().to_string(), + // reasoning_type: "text".to_string(), + // title: "Finished reasoning".to_string(), + // // Use total duration from start for this initial message + // secondary_title: format!("{} seconds", start_time.elapsed().as_secs()), + // message: None, + // message_chunk: None, + // status: Some("completed".to_string()), + // }); + // // Reset the completion time after showing the initial reasoning message + // *last_reasoning_completion_time = Instant::now(); - let reasoning_container = - BusterContainer::ReasoningMessage(BusterReasoningMessageContainer { - reasoning: reasoning_message, - chat_id: *chat_id, - message_id: *message_id, - }); + // let reasoning_container = + // BusterContainer::ReasoningMessage(BusterReasoningMessageContainer { + // reasoning: reasoning_message, + // chat_id: *chat_id, + // message_id: *message_id, + // }); - containers.push((reasoning_container, ThreadEvent::GeneratingResponseMessage)); - } + // containers.push((reasoning_container, ThreadEvent::GeneratingResponseMessage)); + // } Ok(containers) } else if let Some(tool_calls) = tool_calls {