things feeling pretty good.

This commit is contained in:
dal 2025-04-11 13:33:56 -06:00
parent 47e8d55954
commit c691f904f9
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
6 changed files with 381 additions and 265 deletions

View File

@ -7,11 +7,11 @@ use litellm::{
};
use once_cell::sync::Lazy;
use serde_json::Value;
use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use tracing::error;
use uuid::Uuid;
use std::time::{Duration, Instant};
// Type definition for tool registry to simplify complex type
// No longer needed, defined below
@ -19,16 +19,17 @@ use crate::models::AgentThread;
// Global BraintrustClient instance
static BRAINTRUST_CLIENT: Lazy<Option<Arc<BraintrustClient>>> = Lazy::new(|| {
match (std::env::var("BRAINTRUST_API_KEY"), std::env::var("BRAINTRUST_LOGGING_ID")) {
(Ok(_), Ok(buster_logging_id)) => {
match BraintrustClient::new(None, &buster_logging_id) {
Ok(client) => Some(client),
Err(e) => {
eprintln!("Failed to create Braintrust client: {}", e);
None
}
match (
std::env::var("BRAINTRUST_API_KEY"),
std::env::var("BRAINTRUST_LOGGING_ID"),
) {
(Ok(_), Ok(buster_logging_id)) => match BraintrustClient::new(None, &buster_logging_id) {
Ok(client) => Some(client),
Err(e) => {
eprintln!("Failed to create Braintrust client: {}", e);
None
}
}
},
_ => None,
}
});
@ -55,7 +56,6 @@ struct MessageBuffer {
first_message_sent: bool,
}
impl MessageBuffer {
fn new() -> Self {
Self {
@ -101,7 +101,11 @@ impl MessageBuffer {
// Create and send the message
let message = AgentMessage::assistant(
self.message_id.clone(),
if self.content.is_empty() { None } else { Some(self.content.clone()) },
if self.content.is_empty() {
None
} else {
Some(self.content.clone())
},
tool_calls,
MessageProgress::InProgress,
Some(!self.first_message_sent),
@ -124,7 +128,6 @@ impl MessageBuffer {
}
}
// Helper struct to store the tool and its enablement condition
struct RegisteredTool {
executor: Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>,
@ -141,7 +144,6 @@ struct DynamicPromptRule {
// Update the ToolRegistry type alias is no longer needed, but we need the new type for the map
type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
#[derive(Clone)]
/// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools
@ -211,11 +213,7 @@ impl Agent {
}
/// Create a new Agent that shares state and stream with an existing agent
pub fn from_existing(
existing_agent: &Agent,
name: String,
default_prompt: String,
) -> Self {
pub fn from_existing(existing_agent: &Agent, name: String, default_prompt: String) -> Self {
let llm_api_key = env::var("LLM_API_KEY").ok(); // Use ok() instead of expect
let llm_base_url = env::var("LLM_BASE_URL").ok(); // Use ok() instead of expect
@ -308,7 +306,6 @@ impl Agent {
}
// --- End Helper state functions ---
/// Get the current thread being processed, if any
pub async fn get_current_thread(&self) -> Option<AgentThread> {
self.current_thread.read().await.clone()
@ -368,7 +365,8 @@ impl Agent {
let registered_tool = RegisteredTool {
executor: Box::new(value_tool),
// Box the closure only if it's Some
enablement_condition: enablement_condition.map(|f| Box::new(f) as Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>),
enablement_condition: enablement_condition
.map(|f| Box::new(f) as Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>),
};
tools.insert(name, registered_tool);
}
@ -389,13 +387,14 @@ impl Agent {
let value_tool = tool.into_tool_call_executor();
let registered_tool = RegisteredTool {
executor: Box::new(value_tool),
enablement_condition: condition.map(|f| Box::new(f) as Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>),
enablement_condition: condition.map(|f| {
Box::new(f) as Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>
}),
};
tools_map.insert(name, registered_tool);
}
}
/// Process a thread of conversation, potentially executing tools and continuing
/// the conversation recursively until a final response is reached.
///
@ -413,9 +412,9 @@ impl Agent {
let mut final_message = None;
while let Ok(msg) = rx.recv().await {
match msg {
Ok(AgentMessage::Done) => break, // Stop collecting on Done message
Ok(AgentMessage::Done) => break, // Stop collecting on Done message
Ok(m) => final_message = Some(m), // Store the latest non-Done message
Err(e) => return Err(e.into()), // Propagate errors
Err(e) => return Err(e.into()), // Propagate errors
}
}
@ -500,7 +499,9 @@ impl Agent {
let (trace_builder, parent_span) = if trace_builder.is_none() && parent_span.is_none() {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Find the most recent user message to use as our input content
let user_input_message = thread.messages.iter()
let user_input_message = thread
.messages
.iter()
.filter(|msg| matches!(msg, AgentMessage::User { .. }))
.last()
.cloned();
@ -525,7 +526,10 @@ impl Agent {
// Add the user prompt text (not the full message) as input to the root span
// Ensure we're passing ONLY the content text, not the full message object
let root_span = trace.root_span().clone().with_input(serde_json::json!(user_prompt_text));
let root_span = trace
.root_span()
.clone()
.with_input(serde_json::json!(user_prompt_text));
// Add chat_id (session_id) as metadata to the root span
let span = root_span.with_metadata("chat_id", self.session_id.to_string());
@ -554,30 +558,37 @@ impl Agent {
Some(self.name.clone()),
);
if let Err(e) = self.get_stream_sender().await.send(Ok(message)) {
tracing::warn!("Channel send error when sending recursion limit message: {}", e);
tracing::warn!(
"Channel send error when sending recursion limit message: {}",
e
);
}
self.close().await; // Ensure stream is closed
return Ok(()); // Don't return error, just stop processing
}
// --- Dynamic Prompt Selection ---
// --- Dynamic Prompt Selection ---
let current_system_prompt = self.get_current_prompt().await;
let system_message = AgentMessage::developer(current_system_prompt);
// Prepare messages for LLM: Inject current system prompt and filter out old ones
let mut llm_messages = vec![system_message];
llm_messages.extend(
thread.messages.iter()
thread
.messages
.iter()
.filter(|msg| !matches!(msg, AgentMessage::Developer { .. }))
.cloned()
.cloned(),
);
// --- End Dynamic Prompt Selection ---
// --- End Dynamic Prompt Selection ---
// Collect all enabled tools and their schemas
let tools = self.get_enabled_tools().await; // Now uses the new logic
// Get the most recent user message for logging (used only in error logging)
let _user_message = thread.messages.last()
let _user_message = thread
.messages
.last()
.filter(|msg| matches!(msg, AgentMessage::User { .. }))
.cloned();
@ -594,11 +605,16 @@ impl Agent {
session_id: thread.id.to_string(),
trace_id: Uuid::new_v4().to_string(),
}),
// reasoning_effort: Some("high".to_string()),
..Default::default()
};
// Get the streaming response from the LLM
let mut stream_rx = match self.llm_client.stream_chat_completion(request.clone()).await {
let mut stream_rx = match self
.llm_client
.stream_chat_completion(request.clone())
.await
{
Ok(rx) => rx,
Err(e) => {
// Log error in span
@ -616,7 +632,7 @@ impl Agent {
}
let error_message = format!("Error starting stream: {:?}", e);
return Err(anyhow::anyhow!(error_message));
},
}
};
// We store the parent span to use for creating individual tool spans
@ -646,16 +662,16 @@ impl Agent {
if let Some(tool_calls) = &delta.tool_calls {
for tool_call in tool_calls {
let id = tool_call.id.clone().unwrap_or_else(|| {
buffer.tool_calls
buffer
.tool_calls
.keys()
.next().cloned()
.next()
.cloned()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
});
// Get or create the pending tool call
let pending_call = buffer.tool_calls
.entry(id.clone())
.or_default();
let pending_call = buffer.tool_calls.entry(id.clone()).or_default();
// Update the pending call with the delta
pending_call.update_from_delta(tool_call);
@ -692,18 +708,18 @@ impl Agent {
}
let error_message = format!("Error in stream: {:?}", e);
return Err(anyhow::anyhow!(error_message));
},
}
}
}
// Flush any remaining buffered content or tool calls before creating final message
buffer.flush(self).await?;
// Create and send the final message
let final_tool_calls: Option<Vec<ToolCall>> = if !buffer.tool_calls.is_empty() {
Some(
buffer.tool_calls
buffer
.tool_calls
.values()
.map(|p| p.clone().into_tool_call())
.collect(),
@ -714,7 +730,11 @@ impl Agent {
let final_message = AgentMessage::assistant(
buffer.message_id,
if buffer.content.is_empty() { None } else { Some(buffer.content) },
if buffer.content.is_empty() {
None
} else {
Some(buffer.content)
},
final_tool_calls.clone(),
MessageProgress::Complete,
Some(false), // Never the first message at this stage
@ -723,69 +743,34 @@ impl Agent {
// Broadcast the final assistant message
// Ensure we don't block if the receiver dropped
if let Err(e) = self.get_stream_sender().await.send(Ok(final_message.clone())) {
tracing::debug!("Failed to send final assistant message (receiver likely dropped): {}", e);
if let Err(e) = self
.get_stream_sender()
.await
.send(Ok(final_message.clone()))
{
tracing::debug!(
"Failed to send final assistant message (receiver likely dropped): {}",
e
);
}
// Update thread with assistant message
self.update_current_thread(final_message.clone()).await?;
// For a message without tool calls, create and log a new complete message span
// Otherwise, tool spans will be created individually for each tool call
if final_tool_calls.is_none() && trace_builder.is_some() {
if let (Some(trace), Some(parent)) = (&trace_builder, &parent_span) {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Ensure we have the complete message content
// Make sure we clone the final message to avoid mutating it
let complete_final_message = final_message.clone();
// Get the updated thread state AFTER adding the final assistant message
// This will be used for the potential recursive call later.
let mut updated_thread_for_recursion = self
.current_thread
.read()
.await
.as_ref()
.cloned()
.ok_or_else(|| {
anyhow::anyhow!("Failed to get updated thread state after adding assistant message")
})?;
// Create a fresh span for the text-only response
let span = trace.add_child_span("Assistant Response", "llm", parent).await?;
// Add chat_id (session_id) as metadata to the span
let span = span.with_metadata("chat_id", self.session_id.to_string());
// Add the full request/response information
let span = span.with_input(serde_json::to_value(&request)?);
let span = span.with_output(serde_json::to_value(&complete_final_message)?);
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(span).await {
error!("Failed to log assistant response span: {}", log_err);
}
}
}
}
// For messages with tool calls, we won't log the output here
// Instead, we'll create tool spans with this assistant span as parent
// If this is an auto response without tool calls, it means we're done
if final_tool_calls.is_none() {
// Log the final output to the parent span
if let Some(parent_span) = &parent_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a new span with the final message as output
let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?);
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(final_span).await {
error!("Failed to log final output span: {}", log_err);
}
}
}
// Finish the trace without consuming it
self.finish_trace(&trace_builder).await?;
// // Send Done message and return - Done message is now sent by the caller task
// self.get_stream_sender()
// .await
// .send(Ok(AgentMessage::Done))?;
return Ok(());
}
// If the LLM wants to use tools, execute them and continue
// --- Tool Execution Logic ---
// If the LLM wants to use tools, execute them
if let Some(tool_calls) = final_tool_calls {
let mut results = Vec::new();
let agent_tools = self.tools.read().await; // Read tools once
@ -794,23 +779,28 @@ impl Agent {
// Execute each requested tool
let mut should_terminate = false; // Flag to indicate if loop should terminate after this tool
for tool_call in tool_calls {
// Find the registered tool entry
// Find the registered tool entry
if let Some(registered_tool) = agent_tools.get(&tool_call.function.name) {
// Create a tool span that combines the assistant request with the tool execution
let tool_span = if let (Some(trace), Some(parent)) = (&trace_builder, &parent_for_tool_spans) {
let tool_span = if let (Some(trace), Some(parent)) =
(&trace_builder, &parent_for_tool_spans)
{
if let Some(_client) = &*BRAINTRUST_CLIENT {
// Create a span for the assistant + tool execution
let span = trace.add_child_span(
&format!("Assistant: {}", tool_call.function.name),
"tool",
parent
).await?;
let span = trace
.add_child_span(
&format!("Assistant: {}", tool_call.function.name),
"tool",
parent,
)
.await?;
// Add chat_id (session_id) as metadata to the span
let span = span.with_metadata("chat_id", self.session_id.to_string());
// Parse the parameters (unused in this context since we're using final_message)
let _params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let _params: Value =
serde_json::from_str(&tool_call.function.arguments)?;
// Use the assistant message as input to this span
// This connects the assistant's request to the tool execution
@ -829,13 +819,16 @@ impl Agent {
// Parse the parameters
let params: Value = match serde_json::from_str(&tool_call.function.arguments) {
Ok(p) => p,
Err(e) => {
let err_msg = format!("Failed to parse tool arguments for {}: {}", tool_call.function.name, e);
error!("{}", err_msg);
// Optionally log to Braintrust span here
return Err(anyhow::anyhow!(err_msg));
}
Ok(p) => p,
Err(e) => {
let err_msg = format!(
"Failed to parse tool arguments for {}: {}",
tool_call.function.name, e
);
error!("{}", err_msg);
// Optionally log to Braintrust span here
return Err(anyhow::anyhow!(err_msg));
}
};
let _tool_input = serde_json::json!({
@ -847,7 +840,11 @@ impl Agent {
});
// Execute the tool using the executor from RegisteredTool
let result = match registered_tool.executor.execute(params, tool_call.id.clone()).await {
let result = match registered_tool
.executor
.execute(params, tool_call.id.clone())
.await
{
Ok(r) => r,
Err(e) => {
// Log error in tool span
@ -862,11 +859,17 @@ impl Agent {
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log tool execution error span: {}", log_err);
error!(
"Failed to log tool execution error span: {}",
log_err
);
}
}
}
let error_message = format!("Tool execution error for {}: {:?}", tool_call.function.name, e);
let error_message = format!(
"Tool execution error for {}: {:?}",
tool_call.function.name, e
);
error!("{}", error_message); // Log locally
return Err(anyhow::anyhow!(error_message));
}
@ -885,10 +888,18 @@ impl Agent {
if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Only log completed messages
if matches!(tool_message, AgentMessage::Tool { progress: MessageProgress::Complete, .. }) {
if matches!(
tool_message,
AgentMessage::Tool {
progress: MessageProgress::Complete,
..
}
) {
// Now that we have the tool result, add it as output and log the span
// This creates a span showing assistant message -> tool execution -> tool result
let result_span = tool_span.clone().with_output(serde_json::to_value(&tool_message)?);
let result_span = tool_span
.clone()
.with_output(serde_json::to_value(&tool_message)?);
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(result_span).await {
@ -899,10 +910,16 @@ impl Agent {
}
// Broadcast the tool message as soon as we receive it - use try_send to avoid blocking
if let Err(e) = self.get_stream_sender().await.send(Ok(tool_message.clone())) {
tracing::debug!("Failed to send tool message (receiver likely dropped): {}", e);
}
if let Err(e) = self
.get_stream_sender()
.await
.send(Ok(tool_message.clone()))
{
tracing::debug!(
"Failed to send tool message (receiver likely dropped): {}",
e
);
}
// Update thread with tool response BEFORE checking termination
self.update_current_thread(tool_message.clone()).await?;
@ -911,76 +928,103 @@ impl Agent {
// Check if this tool's name is in the terminating list
if terminating_names.contains(&tool_call.function.name) {
should_terminate = true;
tracing::info!("Tool '{}' triggered agent termination.", tool_call.function.name);
tracing::info!(
"Tool '{}' triggered agent termination.",
tool_call.function.name
);
break; // Exit the tool execution loop
}
} else {
// Handle case where the LLM hallucinated a tool name
let err_msg = format!("Attempted to call non-existent tool: {}", tool_call.function.name);
error!("{}", err_msg);
// Create a fake tool result indicating the error
let error_result = AgentMessage::tool(
None,
serde_json::json!({"error": err_msg}).to_string(),
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
);
// Broadcast the error message
if let Err(e) = self.get_stream_sender().await.send(Ok(error_result.clone())) {
tracing::debug!("Failed to send tool error message (receiver likely dropped): {}", e);
}
// Update thread and push the error result for the next LLM call
self.update_current_thread(error_result.clone()).await?;
// Handle case where the LLM hallucinated a tool name
let err_msg = format!(
"Attempted to call non-existent tool: {}",
tool_call.function.name
);
error!("{}", err_msg);
// Create a fake tool result indicating the error
let error_result = AgentMessage::tool(
None,
serde_json::json!({"error": err_msg}).to_string(),
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
);
// Broadcast the error message
if let Err(e) = self
.get_stream_sender()
.await
.send(Ok(error_result.clone()))
{
tracing::debug!(
"Failed to send tool error message (receiver likely dropped): {}",
e
);
}
// Update thread and push the error result for the next LLM call
self.update_current_thread(error_result.clone()).await?;
// Continue processing other tool calls if any
}
}
// If a tool signaled termination, send Done and finish.
// If a tool signaled termination, finish trace, send Done and exit.
if should_terminate {
// Finish the trace without consuming it
// Finish the trace without consuming it
self.finish_trace(&trace_builder).await?;
// Send Done message
if let Err(e) = self.get_stream_sender().await.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e);
}
tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e);
}
return Ok(()); // Exit the function, preventing recursion
}
// Create a new thread with the tool results and continue recursively
let mut new_thread = thread.clone();
// Add the assistant message that contained the tool_calls to ensure correct history order
new_thread.messages.push(final_message); // Add the assistant message
// The assistant message that requested the tools is already added above
new_thread.messages.extend(results);
// For recursive calls, we'll continue with the same trace
// We don't finish the trace here to keep all interactions in one trace
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace_builder, parent_span)).await
// Add the tool results to the thread state for the recursive call
updated_thread_for_recursion.messages.extend(results);
} else {
// Log the final output to the parent span (This case should ideally not be reached if final_tool_calls was None earlier)
// Log the final assistant response span only if NO tools were called
if let (Some(trace), Some(parent)) = (&trace_builder, &parent_span) {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Ensure we have the complete message content
let complete_final_message = final_message.clone();
// Create a fresh span for the text-only response
let span = trace
.add_child_span("Assistant Response", "llm", parent)
.await?;
let span = span.with_metadata("chat_id", self.session_id.to_string());
let span = span.with_input(serde_json::to_value(&request)?); // Log the request
let span = span.with_output(serde_json::to_value(&complete_final_message)?); // Log the response
// Log span non-blockingly
if let Err(log_err) = client.log_span(span).await {
error!("Failed to log assistant response span: {}", log_err);
}
}
}
// Also log the final output to the parent span if no tools were called
if let Some(parent_span) = &parent_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a new span with the final message as output
let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?);
// Log span non-blockingly (client handles the background processing)
let final_span = parent_span
.clone()
.with_output(serde_json::to_value(&final_message)?);
if let Err(log_err) = client.log_span(final_span).await {
error!("Failed to log final output span: {}", log_err);
}
}
}
// Finish the trace without consuming it
self.finish_trace(&trace_builder).await?;
// // Send Done message and return - Done message is now sent by the caller task
// self.get_stream_sender()
// .await
// .send(Ok(AgentMessage::Done))?;
Ok(())
// --- End Logging for Text-Only Response ---
}
// Continue the conversation recursively using the updated thread state,
// unless a terminating tool caused an early return above.
// This call happens regardless of whether tools were executed in this step.
Box::pin(self.process_thread_with_depth(
&updated_thread_for_recursion,
recursion_depth + 1,
trace_builder,
parent_span,
))
.await
}
/// Get a receiver for the shutdown signal
@ -996,11 +1040,12 @@ impl Agent {
}
/// Get a read lock on the tools map (Exposes RegisteredTool now)
pub async fn get_tools_map(&self) -> tokio::sync::RwLockReadGuard<'_, HashMap<String, RegisteredTool>> {
pub async fn get_tools_map(
&self,
) -> tokio::sync::RwLockReadGuard<'_, HashMap<String, RegisteredTool>> {
self.tools.read().await
}
/// Helper method to finish a trace without consuming the TraceBuilder
/// This method is fully non-blocking and never affects application performance
async fn finish_trace(&self, trace: &Option<TraceBuilder>) -> Result<()> {
@ -1017,12 +1062,14 @@ impl Agent {
// Create and log a completion span non-blockingly
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a new span for completion linked to the trace
let completion_span = client.create_span(
"Trace Completion",
"completion",
Some(root_span_id), // Link to the trace's root span
Some(root_span_id) // Set parent to also be the root span
).with_metadata("chat_id", self.session_id.to_string());
let completion_span = client
.create_span(
"Trace Completion",
"completion",
Some(root_span_id), // Link to the trace's root span
Some(root_span_id), // Set parent to also be the root span
)
.with_metadata("chat_id", self.session_id.to_string());
// Log span non-blockingly (client handles the background processing)
if let Err(e) = client.log_span(completion_span).await {
@ -1043,11 +1090,7 @@ impl Agent {
/// Add a rule for dynamically selecting a system prompt.
/// Rules are checked in the order they are added. The first matching rule's prompt is used.
pub async fn add_dynamic_prompt_rule<F>(
&self,
condition: F,
prompt: String,
)
pub async fn add_dynamic_prompt_rule<F>(&self, condition: F, prompt: String)
where
F: Fn(&HashMap<String, Value>) -> bool + Send + Sync + 'static,
{
@ -1183,13 +1226,8 @@ mod tests {
tool_id: String,
progress: MessageProgress,
) -> Result<()> {
let message = AgentMessage::tool(
None,
content,
tool_id,
Some(self.get_name()),
progress,
);
let message =
AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress);
self.agent.get_stream_sender().await.send(Ok(message))?;
Ok(())
}
@ -1200,7 +1238,11 @@ mod tests {
type Output = Value;
type Params = Value;
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(
&self,
params: Self::Params,
tool_call_id: String,
) -> Result<Self::Output> {
self.send_progress(
"Fetching weather data...".to_string(),
tool_call_id.clone(), // Use the actual tool_call_id
@ -1210,7 +1252,6 @@ mod tests {
let _params = params.as_object().unwrap();
// Simulate a delay
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
@ -1285,7 +1326,7 @@ mod tests {
Ok(response) => {
println!("Response (no tools): {:?}", response);
response
},
}
Err(e) => panic!("Error processing thread: {:?}", e),
};
}
@ -1311,7 +1352,9 @@ mod tests {
let condition = |_state: &HashMap<String, Value>| true; // Always enabled
// Add tool to agent
agent.add_tool(tool_name, weather_tool, Some(condition)).await;
agent
.add_tool(tool_name, weather_tool, Some(condition))
.await;
let thread = AgentThread::new(
None,
@ -1323,9 +1366,9 @@ mod tests {
let _response = match agent.process_thread(&thread).await {
Ok(response) => {
println!("Response (with tools): {:?}", response);
response
},
println!("Response (with tools): {:?}", response);
response
}
Err(e) => panic!("Error processing thread: {:?}", e),
};
}
@ -1350,7 +1393,9 @@ mod tests {
let tool_name = weather_tool.get_name();
let condition = |_state: &HashMap<String, Value>| true; // Always enabled
agent.add_tool(tool_name, weather_tool, Some(condition)).await;
agent
.add_tool(tool_name, weather_tool, Some(condition))
.await;
let thread = AgentThread::new(
None,
@ -1360,16 +1405,16 @@ mod tests {
)],
);
let _response = match agent.process_thread(&thread).await {
let _response = match agent.process_thread(&thread).await {
Ok(response) => {
println!("Response (multi-step): {:?}", response);
response
},
println!("Response (multi-step): {:?}", response);
response
}
Err(e) => panic!("Error processing thread: {:?}", e),
};
}
#[tokio::test]
#[tokio::test]
async fn test_agent_disabled_tool() {
setup();
@ -1389,51 +1434,66 @@ mod tests {
let tool_name = weather_tool.get_name();
// Condition: only enabled if "weather_enabled" state is true
let condition = |state: &HashMap<String, Value>| -> bool {
state.get("weather_enabled").and_then(|v| v.as_bool()).unwrap_or(false)
state
.get("weather_enabled")
.and_then(|v| v.as_bool())
.unwrap_or(false)
};
// Add tool with the condition
agent.add_tool(tool_name, weather_tool, Some(condition)).await;
agent
.add_tool(tool_name, weather_tool, Some(condition))
.await;
// --- Test case 1: Tool disabled ---
let thread_disabled = AgentThread::new(
None,
Uuid::new_v4(),
vec![AgentMessage::user("What is the weather in Provo?".to_string())],
vec![AgentMessage::user(
"What is the weather in Provo?".to_string(),
)],
);
// Ensure state doesn't enable the tool
agent.set_state_value("weather_enabled".to_string(), json!(false)).await;
agent
.set_state_value("weather_enabled".to_string(), json!(false))
.await;
let response_disabled = match agent.process_thread(&thread_disabled).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread (disabled): {:?}", e),
};
// Expect response without tool call
if let AgentMessage::Assistant { tool_calls: Some(_), .. } = response_disabled {
panic!("Tool call occurred even when disabled");
}
println!("Response (disabled tool): {:?}", response_disabled);
Ok(response) => response,
Err(e) => panic!("Error processing thread (disabled): {:?}", e),
};
// Expect response without tool call
if let AgentMessage::Assistant {
tool_calls: Some(_),
..
} = response_disabled
{
panic!("Tool call occurred even when disabled");
}
println!("Response (disabled tool): {:?}", response_disabled);
// --- Test case 2: Tool enabled ---
let thread_enabled = AgentThread::new(
None,
Uuid::new_v4(),
vec![AgentMessage::user("What is the weather in Orem?".to_string())],
vec![AgentMessage::user(
"What is the weather in Orem?".to_string(),
)],
);
// Set state to enable the tool
agent.set_state_value("weather_enabled".to_string(), json!(true)).await;
agent
.set_state_value("weather_enabled".to_string(), json!(true))
.await;
let _response_enabled = match agent.process_thread(&thread_enabled).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread (enabled): {:?}", e),
};
// Expect response *with* tool call (or final answer after tool call)
// We can't easily check the intermediate step here, but the test should run without panic
println!("Response (enabled tool): {:?}", _response_enabled);
Ok(response) => response,
Err(e) => panic!("Error processing thread (enabled): {:?}", e),
};
// Expect response *with* tool call (or final answer after tool call)
// We can't easily check the intermediate step here, but the test should run without panic
println!("Response (enabled tool): {:?}", _response_enabled);
}
#[tokio::test]
async fn test_agent_state_management() {
setup();
@ -1460,10 +1520,11 @@ mod tests {
assert_eq!(agent.get_state_bool("test_key").await, None); // Not a bool
// Test setting boolean value
agent.set_state_value("bool_key".to_string(), json!(true)).await;
agent
.set_state_value("bool_key".to_string(), json!(true))
.await;
assert_eq!(agent.get_state_bool("bool_key").await, Some(true));
// Test updating multiple values
agent
.update_state(|state| {
@ -1486,4 +1547,4 @@ mod tests {
assert!(!agent.state_key_exists("test_key").await);
assert_eq!(agent.get_state_bool("bool_key").await, None);
}
}
}

View File

@ -1,4 +1,6 @@
use anyhow::Result;
use database::helpers::datasets::get_dataset_names_for_organization;
use database::organization::get_user_organization_id;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
@ -187,11 +189,19 @@ impl BusterMultiAgent {
session_id: Uuid,
is_follow_up: bool, // Add flag to determine initial prompt
) -> Result<Self> {
let organization_id = match get_user_organization_id(&user_id).await {
Ok(Some(org_id)) => org_id,
Ok(None) => return Err(anyhow::anyhow!("User does not belong to any organization")),
Err(e) => return Err(e),
};
let dataset_names = get_dataset_names_for_organization(organization_id).await?;
// Select initial default prompt based on whether it's a follow-up
let initial_default_prompt = if is_follow_up {
FOLLOW_UP_INTIALIZATION_PROMPT.to_string()
FOLLOW_UP_INTIALIZATION_PROMPT.replace("{DATASETS}", &dataset_names.join(", "))
} else {
INTIALIZATION_PROMPT.to_string()
INTIALIZATION_PROMPT.replace("{DATASETS}", &dataset_names.join(", "))
};
// Create agent, passing the selected initialization prompt as default
@ -199,7 +209,7 @@ impl BusterMultiAgent {
"o3-mini".to_string(),
user_id,
session_id,
"buster_super_agent".to_string(),
"buster_multi_agent".to_string(),
None,
None,
initial_default_prompt, // Use selected default prompt
@ -212,8 +222,8 @@ impl BusterMultiAgent {
}
// Define prompt switching conditions
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
state.contains_key("data_context") && !state.contains_key("plan_available")
let needs_plan_condition = move |state: &HashMap<String, Value>| -> bool {
state.contains_key("data_context") && !state.contains_key("plan_available") && !is_follow_up
};
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
// Example: Trigger analysis prompt once plan is available and metrics/dashboards are not yet available
@ -245,13 +255,11 @@ impl BusterMultiAgent {
));
// Re-apply prompt rules for the new agent instance
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
let needs_plan_condition = move |state: &HashMap<String, Value>| -> bool {
state.contains_key("data_context") && !state.contains_key("plan_available")
};
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
state.contains_key("plan_available")
&& !state.contains_key("metrics_available")
&& !state.contains_key("dashboards_available")
state.contains_key("data_context") && state.contains_key("plan_available")
};
agent
.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string())
@ -724,6 +732,14 @@ Always use your best judgement when selecting visualization types, and be confid
---
### Available Datasets
Datasets include:
{DATASETS}
**Reminder**: Always use `search_data_catalog` to confirm that specific data points or columns exist within these datasets; do not assume availability.
---
## Workflow Examples
- **Fully Supported Workflow**

View File

@ -189,7 +189,19 @@ properties:
###
sql:
type: string
description: SQL query using YAML pipe syntax (|)
description: |
SQL query using YAML pipe syntax (|)
The SQL query should be formatted with proper indentation using the YAML pipe (|) syntax.
This ensures the multi-line SQL is properly parsed while preserving whitespace and newlines.
Example:
sql: |
SELECT
column1,
column2
FROM table
WHERE condition
# CHART CONFIGURATION
chartConfig:
@ -540,7 +552,9 @@ pub const DASHBOARD_YML_SCHEMA: &str = r##"
# items:
# - id: metric-uuid-2
# - id: metric-uuid-3
# columnSizes: [6, 6] # Required - must sum to exactly 12
# columnSizes:
# - 6
# - 6
#
# Rules:
# 1. Each row can have up to 4 items
@ -548,7 +562,9 @@ pub const DASHBOARD_YML_SCHEMA: &str = r##"
# 3. columnSizes is required and must specify the width for each item
# 4. Sum of columnSizes in a row must be exactly 12
# 5. Each column size must be at least 3
# 6. All arrays should follow the YML array syntax using `-` not `[` and `]`
# 6. All arrays should follow the YML array syntax using `-`
# 7. All arrays should NOT USE `[]` formatting.
# 8. Don't use comments; the ones in this example are for explanation only
# ----------------------------------------
type: object

View File

@ -0,0 +1,22 @@
use anyhow::Result;
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use uuid::Uuid;
use crate::pool::get_pg_pool;
pub async fn get_dataset_names_for_organization(org_id: Uuid) -> Result<Vec<String>, anyhow::Error> {
use crate::schema::datasets::dsl::*;
let mut conn = get_pg_pool().get().await?;
let results = datasets
.filter(organization_id.eq(org_id))
.filter(deleted_at.is_null())
.filter(yml_file.is_not_null())
.select(name)
.load::<String>(&mut conn)
.await?;
Ok(results)
}

View File

@ -3,4 +3,5 @@ pub mod dashboard_files;
pub mod metric_files;
pub mod chats;
pub mod organization;
pub mod test_utils;
pub mod test_utils;
pub mod datasets;

View File

@ -1279,30 +1279,30 @@ pub async fn transform_message(
.map(|container| (container, ThreadEvent::GeneratingResponseMessage)),
);
// Add the "Finished reasoning" message if we're just starting
if initial {
let reasoning_message = BusterReasoningMessage::Text(BusterReasoningText {
id: Uuid::new_v4().to_string(),
reasoning_type: "text".to_string(),
title: "Finished reasoning".to_string(),
// Use total duration from start for this initial message
secondary_title: format!("{} seconds", start_time.elapsed().as_secs()),
message: None,
message_chunk: None,
status: Some("completed".to_string()),
});
// Reset the completion time after showing the initial reasoning message
*last_reasoning_completion_time = Instant::now();
// // Add the "Finished reasoning" message if we're just starting
// if initial {
// let reasoning_message = BusterReasoningMessage::Text(BusterReasoningText {
// id: Uuid::new_v4().to_string(),
// reasoning_type: "text".to_string(),
// title: "Finished reasoning".to_string(),
// // Use total duration from start for this initial message
// secondary_title: format!("{} seconds", start_time.elapsed().as_secs()),
// message: None,
// message_chunk: None,
// status: Some("completed".to_string()),
// });
// // Reset the completion time after showing the initial reasoning message
// *last_reasoning_completion_time = Instant::now();
let reasoning_container =
BusterContainer::ReasoningMessage(BusterReasoningMessageContainer {
reasoning: reasoning_message,
chat_id: *chat_id,
message_id: *message_id,
});
// let reasoning_container =
// BusterContainer::ReasoningMessage(BusterReasoningMessageContainer {
// reasoning: reasoning_message,
// chat_id: *chat_id,
// message_id: *message_id,
// });
containers.push((reasoning_container, ThreadEvent::GeneratingResponseMessage));
}
// containers.push((reasoning_container, ThreadEvent::GeneratingResponseMessage));
// }
Ok(containers)
} else if let Some(tool_calls) = tool_calls {