From 8ad1801d9d2e17b64060498560c8b91b97dad8d7 Mon Sep 17 00:00:00 2001 From: dal Date: Sat, 26 Apr 2025 10:47:58 -0600 Subject: [PATCH 1/3] agent retry system a little more resilient --- api/Cargo.toml | 1 + api/libs/agents/Cargo.toml | 1 + api/libs/agents/src/agent.rs | 207 ++++++++++++++++++++++------------- 3 files changed, 133 insertions(+), 76 deletions(-) diff --git a/api/Cargo.toml b/api/Cargo.toml index a8b3e26d1..3593b0f12 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -106,6 +106,7 @@ rayon = "1.10.0" diesel_migrations = "2.0.0" html-escape = "0.2.13" tokio-cron-scheduler = "0.13.0" +tokio-retry = "0.3.0" [profile.release] debug = false diff --git a/api/libs/agents/Cargo.toml b/api/libs/agents/Cargo.toml index b7775b0f0..397166cf8 100644 --- a/api/libs/agents/Cargo.toml +++ b/api/libs/agents/Cargo.toml @@ -32,6 +32,7 @@ redis = { workspace = true } reqwest = { workspace = true } sqlx = { workspace = true } stored_values = { path = "../stored_values" } +tokio-retry = { workspace = true } # Development dependencies [dev-dependencies] diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 9c3fdb98f..36bdea697 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor}; use anyhow::Result; use braintrust::{BraintrustClient, TraceBuilder}; use litellm::{ - AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, + AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, }; use once_cell::sync::Lazy; use serde_json::Value; use std::time::{Duration, Instant}; use std::{collections::HashMap, env, sync::Arc}; -use tokio::sync::{broadcast, RwLock}; -use tracing::error; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition}; +use tracing::{error, info, warn}; use 
uuid::Uuid; // Type definition for tool registry to simplify complex type @@ -661,7 +662,7 @@ impl Agent { // --- End Prepare LLM Messages --- // Collect all enabled tools and their schemas - let tools = agent.get_enabled_tools().await; + let tools = agent.get_enabled_tools().await; // Get user message for logging (unchanged) let _user_message = thread_ref @@ -673,7 +674,7 @@ impl Agent { // Create the tool-enabled request let request = ChatCompletionRequest { model: mode_config.model, // Use the model from mode config - messages: llm_messages, + messages: llm_messages, tools: if tools.is_empty() { None } else { Some(tools) }, tool_choice: Some(ToolChoice::Required), // Or adjust based on mode? stream: Some(true), // Enable streaming @@ -687,22 +688,34 @@ impl Agent { ..Default::default() }; + // --- Retry Logic for Initial Stream Request --- + let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms + + let stream_rx_result = Retry::spawn(retry_strategy, || { + let agent_clone = agent.clone(); // Clone Arc for use in the closure + let request_clone = request.clone(); // Clone request for use in the closure + async move { + agent_clone + .llm_client + .stream_chat_completion(request_clone) + .await + } + }) + .await; + // --- End Retry Logic --- + // Get the streaming response from the LLM - let mut stream_rx = match agent - .llm_client - .stream_chat_completion(request.clone()) - .await - { + let mut stream_rx: mpsc::Receiver> = match stream_rx_result { Ok(rx) => rx, Err(e) => { - // --- Added Error Handling --- - let error_message = format!("Error starting LLM stream: {:?}", e); + // --- Updated Error Handling for Retries --- + let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); // Log error in span if let Some(parent_span) = parent_span.clone() { if let 
Some(client) = &*BRAINTRUST_CLIENT { let error_span = parent_span.with_output(serde_json::json!({ - "error": format!("Error starting stream: {:?}", e) + "error": format!("Error starting stream after retries: {:?}", e) })); // Log span non-blockingly (client handles the background processing) @@ -711,87 +724,129 @@ impl Agent { } } } - // --- End Added Error Handling --- + // --- End Updated Error Handling --- return Err(anyhow::anyhow!(error_message)); // Return immediately } }; // We store the parent span to use for creating individual tool spans - // This avoids creating a general assistant span that would never be completed let parent_for_tool_spans = parent_span.clone(); // Process the streaming chunks let mut buffer = MessageBuffer::new(); let mut _is_complete = false; + const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity - while let Some(chunk_result) = stream_rx.recv().await { - match chunk_result { - Ok(chunk) => { - if chunk.choices.is_empty() { - continue; - } + loop { // Replaced `while let` with `loop` and explicit timeout + match tokio::time::timeout( + Duration::from_secs(STREAM_TIMEOUT_SECS), + stream_rx.recv(), + ) + .await + { + Ok(Some(chunk_result)) => { // Received a message within timeout + match chunk_result { + Ok(chunk) => { + if chunk.choices.is_empty() { + continue; + } - buffer.message_id = Some(chunk.id.clone()); - let delta = &chunk.choices[0].delta; + buffer.message_id = Some(chunk.id.clone()); + let delta = &chunk.choices[0].delta; - // Accumulate content if present - if let Some(content) = &delta.content { - buffer.content.push_str(content); - } + // Accumulate content if present + if let Some(content) = &delta.content { + buffer.content.push_str(content); + } - // Process tool calls if present - if let Some(tool_calls) = &delta.tool_calls { - for tool_call in tool_calls { - let id = tool_call.id.clone().unwrap_or_else(|| { - buffer - .tool_calls - .keys() - .next() - .cloned() - .unwrap_or_else(|| 
uuid::Uuid::new_v4().to_string()) - }); + // Process tool calls if present + if let Some(tool_calls) = &delta.tool_calls { + for tool_call in tool_calls { + let id = tool_call.id.clone().unwrap_or_else(|| { + buffer + .tool_calls + .keys() + .next() + .cloned() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) + }); - // Get or create the pending tool call - let pending_call = buffer.tool_calls.entry(id.clone()).or_default(); + // Get or create the pending tool call + let pending_call = + buffer.tool_calls.entry(id.clone()).or_default(); - // Update the pending call with the delta - pending_call.update_from_delta(tool_call); - } - } + // Update the pending call with the delta + pending_call.update_from_delta(tool_call); + } + } - // Check if we should flush the buffer - if buffer.should_flush() { - buffer.flush(&agent).await?; - } + // Check if we should flush the buffer + if buffer.should_flush() { + buffer.flush(&agent).await?; + } - // Check if this is the final chunk - if chunk.choices[0].finish_reason.is_some() { - _is_complete = true; - } - } - Err(e) => { - // --- Added Error Handling --- - let error_message = format!("Error receiving chunk from LLM stream: {:?}", e); - tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); - // Log error in parent span - if let Some(parent) = &parent_for_tool_spans { - if let Some(client) = &*BRAINTRUST_CLIENT { - // Create error info - let error_info = serde_json::json!({ - "error": format!("Error in stream: {:?}", e) - }); - - // Log error as output to parent span - let error_span = parent.clone().with_output(error_info); - - // Log span non-blockingly (client handles the background processing) - if let Err(log_err) = client.log_span(error_span).await { - error!("Failed to log stream error span: {}", log_err); + // Check if this is the final chunk + if chunk.choices[0].finish_reason.is_some() { + _is_complete = true; + // Don't break here yet, let the 
loop condition handle it } } + Err(e) => { + // --- Updated Error Handling for Retries --- + let error_message = format!( + "Error receiving chunk from LLM stream after multiple retries: {:?}", + e + ); + tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); + // Log error in span + if let Some(parent_span) = parent_span.clone() { + if let Some(client) = &*BRAINTRUST_CLIENT { + let error_span = parent_span.with_output(serde_json::json!({ + "error": format!("Error in stream after retries: {:?}", e) + })); + + // Log span non-blockingly (client handles the background processing) + if let Err(log_err) = client.log_span(error_span).await { + error!("Failed to log stream error span: {}", log_err); + } + } + } + // --- End Updated Error Handling --- + // Send error over broadcast channel before returning + let agent_error = AgentError(error_message.clone()); + if let Ok(sender) = agent.get_stream_sender().await { + if let Err(send_err) = sender.send(Err(agent_error)) { + warn!("Failed to send stream error over channel: {}", send_err); + } + } else { + warn!("Stream sender not available for sending error."); + } + return Err(anyhow::anyhow!(error_message)); // Return immediately + } } - // --- End Added Error Handling --- - return Err(anyhow::anyhow!(error_message)); // Return immediately + } + Ok(None) => { // Stream closed gracefully + break; + } + Err(_) => { // Timeout occurred + let timeout_msg = format!( + "LLM stream timed out after {} seconds of inactivity.", + STREAM_TIMEOUT_SECS + ); + warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg); + + // Send timeout error over broadcast channel + let agent_error = AgentError(timeout_msg.clone()); + if let Ok(sender) = agent.get_stream_sender().await { + if let Err(send_err) = sender.send(Err(agent_error)) { + warn!("Failed to send timeout error over channel: {}", send_err); + } + } else { + warn!("Stream 
sender not available for sending timeout error."); + } + // We could return an error here, or just break and let the agent finish + // For now, let's break and proceed with whatever was buffered + break; } } } @@ -932,13 +987,13 @@ impl Agent { { Ok(r) => r, Err(e) => { - // --- Added Error Handling --- + // --- Updated Error Handling for Retries --- let error_message = format!( "Tool execution error for {}: {:?}", tool_call.function.name, e ); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message); - // Log error in tool span + // Log error in span if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { let error_info = serde_json::json!({ @@ -957,7 +1012,7 @@ impl Agent { } } } - // --- End Added Error Handling --- + // --- End Updated Error Handling --- let error_message = format!( "Tool execution error for {}: {:?}", tool_call.function.name, e From 279bc7c6c50766be63fb0dd0fd6c22ba52e428cd Mon Sep 17 00:00:00 2001 From: dal Date: Sat, 26 Apr 2025 10:59:37 -0600 Subject: [PATCH 2/3] improved error handling on the agent workflow for catching sender errors and others. 
--- api/libs/agents/Cargo.toml | 1 + api/libs/agents/src/agent.rs | 340 ++++++++++++++++++++++++----------- 2 files changed, 235 insertions(+), 106 deletions(-) diff --git a/api/libs/agents/Cargo.toml b/api/libs/agents/Cargo.toml index 397166cf8..37a8e7e76 100644 --- a/api/libs/agents/Cargo.toml +++ b/api/libs/agents/Cargo.toml @@ -33,6 +33,7 @@ reqwest = { workspace = true } sqlx = { workspace = true } stored_values = { path = "../stored_values" } tokio-retry = { workspace = true } +thiserror = { workspace = true } # Development dependencies [dev-dependencies] diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 36bdea697..2c904fddb 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -2,16 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor}; use anyhow::Result; use braintrust::{BraintrustClient, TraceBuilder}; use litellm::{ - AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, - MessageProgress, Metadata, Tool, ToolCall, ToolChoice, + AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, + LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, }; use once_cell::sync::Lazy; use serde_json::Value; use std::time::{Duration, Instant}; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::{broadcast, mpsc, RwLock}; -use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition}; -use tracing::{error, info, warn}; +use tokio_retry::{strategy::ExponentialBackoff, Retry}; +use tracing::{error, warn}; use uuid::Uuid; // Type definition for tool registry to simplify complex type @@ -38,6 +38,7 @@ static BRAINTRUST_CLIENT: Lazy>> = Lazy::new(|| { } }); +// --- Reverted AgentError Struct --- #[derive(Debug, Clone)] pub struct AgentError(pub String); @@ -48,6 +49,7 @@ impl std::fmt::Display for AgentError { write!(f, "{}", self.0) } } +// --- End Reverted AgentError Struct --- type MessageResult = 
Result; @@ -147,13 +149,16 @@ struct RegisteredTool { // Update the ToolRegistry type alias is no longer needed, but we need the new type for the map type ToolsMap = Arc>>; -// --- Define ModeProvider Trait --- +// --- Define ModeProvider Trait --- #[async_trait::async_trait] pub trait ModeProvider { // Fetches the complete configuration for a given agent state - async fn get_configuration_for_state(&self, state: &HashMap) -> Result; + async fn get_configuration_for_state( + &self, + state: &HashMap, + ) -> Result; } -// --- End ModeProvider Trait --- +// --- End ModeProvider Trait --- #[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM @@ -184,7 +189,7 @@ pub struct Agent { /// This will be managed by the ModeProvider now. terminating_tool_names: Arc>>, /// Provider for mode-specific logic (prompt, model, tools, termination) - mode_provider: Arc, + mode_provider: Arc, } impl Agent { @@ -217,7 +222,7 @@ impl Agent { shutdown_tx: Arc::new(RwLock::new(shutdown_tx)), name, terminating_tool_names: Arc::new(RwLock::new(Vec::new())), // Initialize empty list - mode_provider, // Store the provider + mode_provider, // Store the provider } } @@ -244,7 +249,7 @@ impl Agent { shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown name, terminating_tool_names: Arc::new(RwLock::new(Vec::new())), // Sub-agent starts with empty term tools? - mode_provider: Arc::clone(&mode_provider), // Share provider + mode_provider: Arc::clone(&mode_provider), // Share provider } } @@ -274,10 +279,14 @@ impl Agent { /// Get a new receiver for the broadcast channel. /// Returns an error if the stream channel has been closed or was not initialized. 
- pub async fn get_stream_receiver(&self) -> Result, AgentError> { + pub async fn get_stream_receiver( + &self, + ) -> Result, AgentError> { match self.stream_tx.read().await.as_ref() { Some(tx) => Ok(tx.subscribe()), - None => Err(AgentError("Stream channel is closed or not initialized.".to_string())) + None => Err(AgentError( + "Stream channel is closed or not initialized.".to_string(), + )), // Use string error } } @@ -286,7 +295,9 @@ impl Agent { pub async fn get_stream_sender(&self) -> Result, AgentError> { match self.stream_tx.read().await.as_ref() { Some(tx) => Ok(tx.clone()), - None => Err(AgentError("Stream channel is closed or not initialized.".to_string())) + None => Err(AgentError( + "Stream channel is closed or not initialized.".to_string(), + )), // Use string error } } @@ -478,12 +489,14 @@ impl Agent { tokio::select! { result = Agent::process_thread_with_depth(agent_arc_clone, thread_clone.clone(), &thread_clone, 0, None, None) => { if let Err(e) = result { + // Log the error let err_msg = format!("Error processing thread: {:?}", e); - error!("{}", err_msg); // Log the error + error!("{}", err_msg); // Use the clone created before select! 
// Handle the Result from get_stream_sender + let agent_error = AgentError(err_msg); // Use reverted struct if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await { - if let Err(send_err) = sender.send(Err(AgentError(err_msg.clone()))) { + if let Err(send_err) = sender.send(Err(agent_error)) { tracing::warn!("Failed to send error message to stream: {}", send_err); } } else { @@ -533,7 +546,11 @@ impl Agent { }); // Handle the Result from get_stream_receiver - agent_for_ok.get_stream_receiver().await.map_err(|e| e.into()) + // Add mapping back for the outer function signature + agent_for_ok + .get_stream_receiver() + .await + .map_err(anyhow::Error::from) } async fn process_thread_with_depth( @@ -604,9 +621,11 @@ impl Agent { // Limit recursion to a maximum of 15 times if recursion_depth >= 15 { + let max_depth_msg = format!("Maximum recursion depth ({}) reached.", recursion_depth); + warn!("{}", max_depth_msg); let message = AgentMessage::assistant( Some("max_recursion_depth_message".to_string()), - Some("I apologize, but I've reached the maximum number of actions (15). 
Please try breaking your request into smaller parts.".to_string()), + Some(max_depth_msg.clone()), // Send the message string None, MessageProgress::Complete, None, @@ -614,41 +633,55 @@ impl Agent { ); // Handle the Result from get_stream_sender if let Ok(sender) = agent.get_stream_sender().await { + // Send the Ok message first if let Err(e) = sender.send(Ok(message)) { - tracing::warn!( - "Channel send error when sending recursion limit message: {}", + warn!( + "Channel send error when sending max recursion depth message: {}", + e + ); + } + // Send the error itself over the channel + if let Err(e) = sender.send(Err(AgentError(max_depth_msg))) { + // Send string error + warn!( + "Channel send error when sending max recursion depth error: {}", e ); } } else { - tracing::warn!("Stream sender not available when sending recursion limit message."); + warn!("Stream sender not available when sending max recursion depth info."); } agent.close().await; // Ensure stream is closed - return Ok(()); // Don't return error, just stop processing + return Ok(()); // Stop processing gracefully, error sent via channel } - // --- Fetch and Apply Mode Configuration --- + // --- Fetch and Apply Mode Configuration --- let state = agent.get_state().await; - let mode_config = agent.mode_provider.get_configuration_for_state(&state).await?; + let mode_config = agent + .mode_provider + .get_configuration_for_state(&state) + .await?; // Apply Tool Loading via the closure provided by the mode agent.clear_tools().await; // Clear previous mode's tools (mode_config.tool_loader)(&agent).await?; // Explicitly cast self // Apply Terminating Tools for this mode - { // Scope for write lock + { + // Scope for write lock let mut term_tools_lock = agent.terminating_tool_names.write().await; term_tools_lock.clear(); term_tools_lock.extend(mode_config.terminating_tools); } // --- End Mode Configuration Application --- - // --- Prepare LLM Messages --- + // --- Prepare LLM Messages --- // Use prompt from 
mode_config let system_message = AgentMessage::developer(mode_config.prompt); let mut llm_messages = vec![system_message]; llm_messages.extend( - agent.current_thread // Use self.current_thread which is updated + agent + .current_thread // Use self.current_thread which is updated .read() .await .as_ref() @@ -656,7 +689,7 @@ impl Agent { .messages // Filter out previous Developer messages if desired, or keep history clean .iter() - .filter(|msg| !matches!(msg, AgentMessage::Developer { .. })) + .filter(|msg| !matches!(msg, AgentMessage::Developer { .. })) .cloned(), ); // --- End Prepare LLM Messages --- @@ -677,7 +710,7 @@ impl Agent { messages: llm_messages, tools: if tools.is_empty() { None } else { Some(tools) }, tool_choice: Some(ToolChoice::Required), // Or adjust based on mode? - stream: Some(true), // Enable streaming + stream: Some(true), // Enable streaming metadata: Some(Metadata { generation_name: "agent".to_string(), user_id: thread_ref.user_id.to_string(), @@ -691,41 +724,88 @@ impl Agent { // --- Retry Logic for Initial Stream Request --- let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms + // Define a condition for retrying: only on network-related errors + let retry_condition = |e: &anyhow::Error| -> bool { + if let Some(req_err) = e.downcast_ref::() { + // Retry on specific transient errors + req_err.is_timeout() || req_err.is_connect() || req_err.is_request() + } else { + false // Don't retry if it's not a reqwest network error + } + }; + + // The retry operation now wraps the actual result or a permanent error in an outer Ok + // Retriable errors are returned as the Err variant for Retry::spawn let stream_rx_result = Retry::spawn(retry_strategy, || { - let agent_clone = agent.clone(); // Clone Arc for use in the closure - let request_clone = request.clone(); // Clone request for use in the closure + // Clone necessary data for the closure + let agent_clone = agent.clone(); + let 
request_clone = request.clone(); + let retry_condition_clone = retry_condition; // Clone the condition closure async move { - agent_clone + match agent_clone .llm_client .stream_chat_completion(request_clone) .await + { + Ok(rx) => Ok(Ok(rx)), // Outer Ok, Inner Ok: Success + Err(e) => { + if retry_condition_clone(&e) { + // Check if error is retriable + Err(e) // Outer Err: Signal retry + } else { + // Outer Ok, Inner Err: Permanent failure, stop retrying + Ok(Err(e)) + } + } + } } }) .await; // --- End Retry Logic --- // Get the streaming response from the LLM + // Handle the nested result from the retry logic let mut stream_rx: mpsc::Receiver> = match stream_rx_result { - Ok(rx) => rx, - Err(e) => { - // --- Updated Error Handling for Retries --- - let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e); + Ok(Ok(rx)) => rx, // Success case + Ok(Err(permanent_error)) => { + // Permanent error case (non-retriable) + let error_message = format!( + "Error starting LLM stream (non-retriable): {:?}", + permanent_error + ); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); - // Log error in span + // Log etc. as before... 
if let Some(parent_span) = parent_span.clone() { if let Some(client) = &*BRAINTRUST_CLIENT { let error_span = parent_span.with_output(serde_json::json!({ - "error": format!("Error starting stream after retries: {:?}", e) + "error": error_message })); - - // Log span non-blockingly (client handles the background processing) if let Err(log_err) = client.log_span(error_span).await { error!("Failed to log error span: {}", log_err); } } } - // --- End Updated Error Handling --- - return Err(anyhow::anyhow!(error_message)); // Return immediately + return Err(permanent_error); // Return the permanent error + } + Err(last_retriable_error) => { + // Error after retries exhausted + let error_message = format!( + "Error starting LLM stream after multiple retries: {:?}", + last_retriable_error + ); + tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); + // Log etc. as before... + if let Some(parent_span) = parent_span.clone() { + if let Some(client) = &*BRAINTRUST_CLIENT { + let error_span = parent_span.with_output(serde_json::json!({ + "error": error_message + })); + if let Err(log_err) = client.log_span(error_span).await { + error!("Failed to log error span: {}", log_err); + } + } + } + return Err(last_retriable_error); // Return the last retriable error } }; @@ -737,14 +817,13 @@ impl Agent { let mut _is_complete = false; const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity - loop { // Replaced `while let` with `loop` and explicit timeout - match tokio::time::timeout( - Duration::from_secs(STREAM_TIMEOUT_SECS), - stream_rx.recv(), - ) - .await + loop { + // Replaced `while let` with `loop` and explicit timeout + match tokio::time::timeout(Duration::from_secs(STREAM_TIMEOUT_SECS), stream_rx.recv()) + .await { - Ok(Some(chunk_result)) => { // Received a message within timeout + Ok(Some(chunk_result)) => { + // Received a message within timeout match chunk_result { Ok(chunk) => { if 
chunk.choices.is_empty() { @@ -792,17 +871,16 @@ impl Agent { } } Err(e) => { - // --- Updated Error Handling for Retries --- - let error_message = format!( - "Error receiving chunk from LLM stream after multiple retries: {:?}", - e - ); + // Format the error string + let error_message = + format!("Error receiving chunk from LLM stream: {:?}", e); + tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); // Log error in span if let Some(parent_span) = parent_span.clone() { if let Some(client) = &*BRAINTRUST_CLIENT { let error_span = parent_span.with_output(serde_json::json!({ - "error": format!("Error in stream after retries: {:?}", e) + "error": error_message // Use formatted string })); // Log span non-blockingly (client handles the background processing) @@ -812,31 +890,36 @@ impl Agent { } } // --- End Updated Error Handling --- - // Send error over broadcast channel before returning - let agent_error = AgentError(error_message.clone()); + // Send string error over broadcast channel before returning + let agent_error = AgentError(error_message.clone()); // Create string error if let Ok(sender) = agent.get_stream_sender().await { + // clone() is now valid for AgentError(String) if let Err(send_err) = sender.send(Err(agent_error)) { warn!("Failed to send stream error over channel: {}", send_err); } } else { warn!("Stream sender not available for sending error."); } - return Err(anyhow::anyhow!(error_message)); // Return immediately + // Return anyhow::Error as before + return Err(anyhow::anyhow!(error_message)); } } } - Ok(None) => { // Stream closed gracefully + Ok(None) => { + // Stream closed gracefully break; } - Err(_) => { // Timeout occurred + Err(_) => { + // Timeout occurred + // Format the timeout message let timeout_msg = format!( "LLM stream timed out after {} seconds of inactivity.", STREAM_TIMEOUT_SECS ); warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = 
%agent.user_id, "{}", timeout_msg); - // Send timeout error over broadcast channel - let agent_error = AgentError(timeout_msg.clone()); + // Send string timeout error over broadcast channel + let agent_error = AgentError(timeout_msg.clone()); // Create string error if let Ok(sender) = agent.get_stream_sender().await { if let Err(send_err) = sender.send(Err(agent_error)) { warn!("Failed to send timeout error over channel: {}", send_err); @@ -884,7 +967,7 @@ impl Agent { // Ensure we don't block if the receiver dropped // Handle the Result from get_stream_sender if let Ok(sender) = agent.get_stream_sender().await { - if let Err(e) = sender.send(Ok(final_message.clone())) { + if let Err(e) = sender.send(Ok(final_message.clone())) { tracing::debug!( "Failed to send final assistant message (receiver likely dropped): {}", e @@ -906,8 +989,8 @@ impl Agent { .as_ref() .cloned() .ok_or_else(|| { - anyhow::anyhow!("Failed to get updated thread state after adding assistant message") - })?; + anyhow::anyhow!("Failed to get updated thread state after adding assistant message") + })?; // --- Tool Execution Logic --- // If the LLM wants to use tools, execute them @@ -966,7 +1049,7 @@ impl Agent { tool_call.function.name, e ); error!("{}", err_msg); - // Optionally log to Braintrust span here + // Return anyhow::Error as before return Err(anyhow::anyhow!(err_msg)); } }; @@ -979,59 +1062,90 @@ impl Agent { "id": tool_call.id }); - // Execute the tool using the executor from RegisteredTool - let result = match registered_tool - .executor - .execute(params, tool_call.id.clone()) - .await - { - Ok(r) => r, + // --- Tool Execution with Timeout --- + const TOOL_TIMEOUT_SECS: u64 = 60; // Timeout for tool execution + let tool_execution_result = tokio::time::timeout( + Duration::from_secs(TOOL_TIMEOUT_SECS), + registered_tool + .executor + .execute(params, tool_call.id.clone()), + ) + .await; + + // Process tool execution result (timeout or actual result/error) + let result: Result = 
match tool_execution_result { + Ok(Ok(r)) => Ok(r), // Tool executed successfully within timeout + Ok(Err(e)) => Err(e), // Tool returned an error within timeout + Err(_) => { + // Tool execution timed out + let timeout_msg = format!( + "Tool '{}' timed out after {} seconds.", + tool_call.function.name, TOOL_TIMEOUT_SECS + ); + warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", timeout_msg); + // Return an error indicating timeout, wrapped in anyhow + Err(anyhow::anyhow!(format!( + "Tool '{}' timed out after {} seconds.", + tool_call.function.name, TOOL_TIMEOUT_SECS + ))) + } + }; + + // Handle the result (success, error, or timeout error) + let tool_message = match result { + Ok(r) => { + // Tool succeeded + let result_str = serde_json::to_string(&r)?; + AgentMessage::tool( + None, + result_str.clone(), + tool_call.id.clone(), + Some(tool_call.function.name.clone()), + MessageProgress::Complete, + ) + } Err(e) => { - // --- Updated Error Handling for Retries --- + // Tool failed (either execution error or timeout) + // Error `e` is already anyhow::Error here let error_message = format!( - "Tool execution error for {}: {:?}", + "Tool execution failed for {}: {:?}", tool_call.function.name, e ); - tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message); - // Log error in span + + // Log error differently for timeout vs execution error if needed + error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message); + + // Log error in Braintrust span if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { let error_info = serde_json::json!({ - "error": format!("Tool execution error: {:?}", e) + "error": error_message // Generic failure message }); - - // Create a new span with the error 
output let error_span = tool_span.clone().with_output(error_info); - - // Log span non-blockingly (client handles the background processing) if let Err(log_err) = client.log_span(error_span).await { error!( - "Failed to log tool execution error span: {}", + "Failed to log tool execution failure span: {}", log_err ); } } } - // --- End Updated Error Handling --- - let error_message = format!( - "Tool execution error for {}: {:?}", - tool_call.function.name, e - ); - error!("{}", error_message); // Log locally - return Err(anyhow::anyhow!(error_message)); // Return immediately + // --- End Braintrust Logging --- + + // Create an error tool message to send back to the LLM + AgentMessage::tool( + None, + serde_json::json!({ "error": error_message }).to_string(), // Send descriptive error string + tool_call.id.clone(), + Some(tool_call.function.name.clone()), + MessageProgress::Complete, + ) + // Note: We are NOT returning the error here, instead we send + // the error back as a tool result message to the LLM. 
} }; - let result_str = serde_json::to_string(&result)?; - let tool_message = AgentMessage::tool( - None, - result_str.clone(), - tool_call.id.clone(), - Some(tool_call.function.name.clone()), - MessageProgress::Complete, - ); - - // Log the combined assistant+tool span with the tool result as output + // Log the combined assistant+tool span with the tool result/error as output if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { // Only log completed messages @@ -1089,10 +1203,11 @@ impl Agent { tool_call.function.name ); error!("{}", err_msg); - // Create a fake tool result indicating the error + + // Create a fake tool result indicating the error (string based) let error_result = AgentMessage::tool( None, - serde_json::json!({"error": err_msg}).to_string(), + serde_json::json!({ "error": err_msg.clone() }).to_string(), // Use the string message tool_call.id.clone(), Some(tool_call.function.name.clone()), MessageProgress::Complete, @@ -1100,14 +1215,24 @@ impl Agent { // Broadcast the error message // Handle the Result from get_stream_sender if let Ok(sender) = agent.get_stream_sender().await { - if let Err(e) = sender.send(Ok(error_result.clone())) { + if let Err(e) = sender.send(Ok(error_result.clone())) { tracing::debug!( "Failed to send tool error message (receiver likely dropped): {}", e ); } + // Also send the specific error type over the channel + if let Err(e) = sender.send(Err(AgentError(err_msg))) { + // Send string error + tracing::warn!( + "Failed to send tool not found error over channel: {}", + e + ); + } } else { - tracing::debug!("Stream sender not available when sending tool error message."); + tracing::debug!( + "Stream sender not available when sending tool error message." 
+ ); } // Update thread and push the error result for the next LLM call agent.update_current_thread(error_result.clone()).await?; @@ -1344,7 +1469,10 @@ mod tests { #[async_trait::async_trait] impl ModeProvider for MockModeProvider { - async fn get_configuration_for_state(&self, _state: &HashMap) -> Result { + async fn get_configuration_for_state( + &self, + _state: &HashMap, + ) -> Result { // Return a default/empty configuration for testing basic agent functions Ok(ModeConfiguration { prompt: "Test Prompt".to_string(), @@ -1455,7 +1583,7 @@ mod tests { "test_agent_no_tools".to_string(), env::var("LLM_API_KEY").ok(), env::var("LLM_BASE_URL").ok(), - mock_provider, + mock_provider, )); let thread = AgentThread::new( @@ -1574,7 +1702,7 @@ mod tests { "test_agent_disabled".to_string(), env::var("LLM_API_KEY").ok(), env::var("LLM_BASE_URL").ok(), - mock_provider, + mock_provider, )); // Create weather tool From e0f57274c2a9b6404597e80a03b2f222354bad93 Mon Sep 17 00:00:00 2001 From: dal Date: Mon, 28 Apr 2025 08:07:24 -0600 Subject: [PATCH 3/3] evals on gemini 2.5 pro --- api/libs/agents/src/agent.rs | 2 +- api/libs/agents/src/agents/buster_multi_agent.rs | 2 +- api/libs/agents/src/agents/modes/analysis.rs | 2 +- api/libs/agents/src/agents/modes/data_catalog_search.rs | 2 +- api/libs/agents/src/agents/modes/follow_up_initialization.rs | 2 +- api/libs/agents/src/agents/modes/initialization.rs | 4 ++-- api/libs/agents/src/agents/modes/mod.rs | 2 +- api/libs/agents/src/agents/modes/planning.rs | 2 +- .../src/tools/categories/file_tools/filter_dashboards.rs | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 2c904fddb..053834de0 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -1448,7 +1448,6 @@ pub trait AgentExt { (*self.get_agent_arc()).get_current_thread().await } } - #[cfg(test)] mod tests { use super::*; @@ -1826,3 +1825,4 @@ mod tests { 
assert_eq!(agent.get_state_bool("bool_key").await, None); } } + diff --git a/api/libs/agents/src/agents/buster_multi_agent.rs b/api/libs/agents/src/agents/buster_multi_agent.rs index f4f04d974..ffa2eae43 100644 --- a/api/libs/agents/src/agents/buster_multi_agent.rs +++ b/api/libs/agents/src/agents/buster_multi_agent.rs @@ -150,7 +150,7 @@ impl BusterMultiAgent { // Create agent, passing the provider let agent = Arc::new(Agent::new( - "o4-mini".to_string(), // Initial model (can be overridden by first mode) + "gemini-2.5-pro-exp-03-25".to_string(), // Initial model (can be overridden by first mode) user_id, session_id, "buster_multi_agent".to_string(), diff --git a/api/libs/agents/src/agents/modes/analysis.rs b/api/libs/agents/src/agents/modes/analysis.rs index 7d58f9a15..0f2b9e8fb 100644 --- a/api/libs/agents/src/agents/modes/analysis.rs +++ b/api/libs/agents/src/agents/modes/analysis.rs @@ -30,7 +30,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration { // Note: This prompt doesn't use {DATASETS} // 2. Define the model for this mode (Using default based on original MODEL = None) - let model = "o4-mini".to_string(); + let model = "gemini-2.5-pro-exp-03-25".to_string(); // 3. Define the tool loader closure let tool_loader: Box< diff --git a/api/libs/agents/src/agents/modes/data_catalog_search.rs b/api/libs/agents/src/agents/modes/data_catalog_search.rs index 5c881553a..432ba3169 100644 --- a/api/libs/agents/src/agents/modes/data_catalog_search.rs +++ b/api/libs/agents/src/agents/modes/data_catalog_search.rs @@ -32,7 +32,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration { // Note: This prompt doesn't use {TODAYS_DATE} // 2. Define the model for this mode - let model = "o4-mini".to_string(); // Use o4-mini as requested + let model = "gemini-2.5-pro-exp-03-25".to_string(); // Use gemini-2.5-pro-exp-03-25 as requested // 3. 
Define the tool loader closure let tool_loader: Box) -> Pin> + Send>> + Send + Sync> = diff --git a/api/libs/agents/src/agents/modes/follow_up_initialization.rs b/api/libs/agents/src/agents/modes/follow_up_initialization.rs index bf09711ff..39af16fc7 100644 --- a/api/libs/agents/src/agents/modes/follow_up_initialization.rs +++ b/api/libs/agents/src/agents/modes/follow_up_initialization.rs @@ -42,7 +42,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration { .replace("{TODAYS_DATE}", &agent_data.todays_date); // 2. Define the model for this mode (Using a default, adjust if needed) - let model = "o4-mini".to_string(); // Assuming default based on original MODEL = None + let model = "gemini-2.5-pro-exp-03-25".to_string(); // Assuming default based on original MODEL = None // 3. Define the tool loader closure let tool_loader: Box) -> Pin> + Send>> + Send + Sync> = diff --git a/api/libs/agents/src/agents/modes/initialization.rs b/api/libs/agents/src/agents/modes/initialization.rs index 8ae993499..0c4ead2be 100644 --- a/api/libs/agents/src/agents/modes/initialization.rs +++ b/api/libs/agents/src/agents/modes/initialization.rs @@ -26,8 +26,8 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration { // 2. Define the model for this mode (Using a default, adjust if needed) // Since the original MODEL was None, we might use the agent's default - // or specify a standard one like "o4-mini". Let's use "o4-mini". - let model = "o4-mini".to_string(); + // or specify a standard one like "gemini-2.5-pro-exp-03-25". Let's use "gemini-2.5-pro-exp-03-25". + let model = "gemini-2.5-pro-exp-03-25".to_string(); // 3. 
Define the tool loader closure let tool_loader: Box< diff --git a/api/libs/agents/src/agents/modes/mod.rs b/api/libs/agents/src/agents/modes/mod.rs index a4527b7cc..476915d5b 100644 --- a/api/libs/agents/src/agents/modes/mod.rs +++ b/api/libs/agents/src/agents/modes/mod.rs @@ -31,7 +31,7 @@ pub struct ModeAgentData { pub struct ModeConfiguration { /// The system prompt to use for the LLM call in this mode. pub prompt: String, - /// The specific LLM model identifier (e.g., "o4-mini") to use for this mode. + /// The specific LLM model identifier (e.g., "gemini-2.5-pro-exp-03-25") to use for this mode. pub model: String, /// An async function/closure responsible for clearing existing tools /// and loading the specific tools required for this mode onto the agent. diff --git a/api/libs/agents/src/agents/modes/planning.rs b/api/libs/agents/src/agents/modes/planning.rs index 0472b621f..52935e08b 100644 --- a/api/libs/agents/src/agents/modes/planning.rs +++ b/api/libs/agents/src/agents/modes/planning.rs @@ -28,7 +28,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration { .replace("{DATASETS}", &agent_data.dataset_with_descriptions.join("\n\n")); // 2. Define the model for this mode (Using default based on original MODEL = None) - let model = "o4-mini".to_string(); + let model = "gemini-2.5-pro-exp-03-25".to_string(); // 3. Define the tool loader closure let tool_loader: Box< diff --git a/api/libs/agents/src/tools/categories/file_tools/filter_dashboards.rs b/api/libs/agents/src/tools/categories/file_tools/filter_dashboards.rs index 12530854a..8fd00805c 100644 --- a/api/libs/agents/src/tools/categories/file_tools/filter_dashboards.rs +++ b/api/libs/agents/src/tools/categories/file_tools/filter_dashboards.rs @@ -673,7 +673,7 @@ mod tests { fn test_tool_parameter_validation() { let tool = FilterDashboardsTool { agent: Arc::new(Agent::new( - "o4-mini".to_string(), + "gemini-2.5-pro-exp-03-25".to_string(), HashMap::new(), Uuid::new_v4(), Uuid::new_v4(),