From 8ad1801d9d2e17b64060498560c8b91b97dad8d7 Mon Sep 17 00:00:00 2001 From: dal Date: Sat, 26 Apr 2025 10:47:58 -0600 Subject: [PATCH] agent retry syste a little more resilient --- api/Cargo.toml | 1 + api/libs/agents/Cargo.toml | 1 + api/libs/agents/src/agent.rs | 207 ++++++++++++++++++++++------------- 3 files changed, 133 insertions(+), 76 deletions(-) diff --git a/api/Cargo.toml b/api/Cargo.toml index a8b3e26d1..3593b0f12 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -106,6 +106,7 @@ rayon = "1.10.0" diesel_migrations = "2.0.0" html-escape = "0.2.13" tokio-cron-scheduler = "0.13.0" +tokio-retry = "0.3.0" [profile.release] debug = false diff --git a/api/libs/agents/Cargo.toml b/api/libs/agents/Cargo.toml index b7775b0f0..397166cf8 100644 --- a/api/libs/agents/Cargo.toml +++ b/api/libs/agents/Cargo.toml @@ -32,6 +32,7 @@ redis = { workspace = true } reqwest = { workspace = true } sqlx = { workspace = true } stored_values = { path = "../stored_values" } +tokio-retry = { workspace = true } # Development dependencies [dev-dependencies] diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 9c3fdb98f..36bdea697 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor}; use anyhow::Result; use braintrust::{BraintrustClient, TraceBuilder}; use litellm::{ - AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, + AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, }; use once_cell::sync::Lazy; use serde_json::Value; use std::time::{Duration, Instant}; use std::{collections::HashMap, env, sync::Arc}; -use tokio::sync::{broadcast, RwLock}; -use tracing::error; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition}; +use tracing::{error, info, warn}; use uuid::Uuid; // Type definition for tool registry to simplify complex type @@ -661,7 +662,7 @@ impl Agent { // --- End Prepare LLM Messages --- // Collect all enabled tools and their schemas - let tools = agent.get_enabled_tools().await; + let tools = agent.get_enabled_tools().await; // Get user message for logging (unchanged) let _user_message = thread_ref @@ -673,7 +674,7 @@ impl Agent { // Create the tool-enabled request let request = ChatCompletionRequest { model: mode_config.model, // Use the model from mode config - messages: llm_messages, + messages: llm_messages, tools: if tools.is_empty() { None } else { Some(tools) }, tool_choice: Some(ToolChoice::Required), // Or adjust based on mode? stream: Some(true), // Enable streaming @@ -687,22 +688,34 @@ impl Agent { ..Default::default() }; + // --- Retry Logic for Initial Stream Request --- + let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms + + let stream_rx_result = Retry::spawn(retry_strategy, || { + let agent_clone = agent.clone(); // Clone Arc for use in the closure + let request_clone = request.clone(); // Clone request for use in the closure + async move { + agent_clone + .llm_client + .stream_chat_completion(request_clone) + .await + } + }) + .await; + // --- End Retry Logic --- + // Get the streaming response from the LLM - let mut stream_rx = match agent - .llm_client - .stream_chat_completion(request.clone()) - .await - { + let mut stream_rx: mpsc::Receiver> = match stream_rx_result { Ok(rx) => rx, Err(e) => { - // --- Added Error Handling --- - let error_message = format!("Error starting LLM stream: {:?}", e); + // --- Updated Error Handling for Retries --- + let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); // Log error in span if let Some(parent_span) = parent_span.clone() { if let Some(client) = &*BRAINTRUST_CLIENT { let error_span = parent_span.with_output(serde_json::json!({ - "error": format!("Error starting stream: {:?}", e) + "error": format!("Error starting stream after retries: {:?}", e) })); // Log span non-blockingly (client handles the background processing) @@ -711,87 +724,129 @@ impl Agent { } } } - // --- End Added Error Handling --- + // --- End Updated Error Handling --- return Err(anyhow::anyhow!(error_message)); // Return immediately } }; // We store the parent span to use for creating individual tool spans - // This avoids creating a general assistant span that would never be completed let parent_for_tool_spans = parent_span.clone(); // Process the streaming chunks let mut buffer = MessageBuffer::new(); let mut _is_complete = false; + const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity - while let Some(chunk_result) = stream_rx.recv().await { - match chunk_result { - Ok(chunk) => { - if chunk.choices.is_empty() { - continue; - } + loop { // Replaced `while let` with `loop` and explicit timeout + match tokio::time::timeout( + Duration::from_secs(STREAM_TIMEOUT_SECS), + stream_rx.recv(), + ) + .await + { + Ok(Some(chunk_result)) => { // Received a message within timeout + match chunk_result { + Ok(chunk) => { + if chunk.choices.is_empty() { + continue; + } - buffer.message_id = Some(chunk.id.clone()); - let delta = &chunk.choices[0].delta; + buffer.message_id = Some(chunk.id.clone()); + let delta = &chunk.choices[0].delta; - // Accumulate content if present - if let Some(content) = &delta.content { - buffer.content.push_str(content); - } + // Accumulate content if present + if let Some(content) = &delta.content { + buffer.content.push_str(content); + } - // Process tool calls if present - if let Some(tool_calls) = &delta.tool_calls { - for tool_call in tool_calls { - let id = tool_call.id.clone().unwrap_or_else(|| { - buffer - .tool_calls - .keys() - .next() - .cloned() - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) - }); + // Process tool calls if present + if let Some(tool_calls) = &delta.tool_calls { + for tool_call in tool_calls { + let id = tool_call.id.clone().unwrap_or_else(|| { + buffer + .tool_calls + .keys() + .next() + .cloned() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) + }); - // Get or create the pending tool call - let pending_call = buffer.tool_calls.entry(id.clone()).or_default(); + // Get or create the pending tool call + let pending_call = + buffer.tool_calls.entry(id.clone()).or_default(); - // Update the pending call with the delta - pending_call.update_from_delta(tool_call); - } - } + // Update the pending call with the delta + pending_call.update_from_delta(tool_call); + } + } - // Check if we should flush the buffer - if buffer.should_flush() { - buffer.flush(&agent).await?; - } + // Check if we should flush the buffer + if buffer.should_flush() { + buffer.flush(&agent).await?; + } - // Check if this is the final chunk - if chunk.choices[0].finish_reason.is_some() { - _is_complete = true; - } - } - Err(e) => { - // --- Added Error Handling --- - let error_message = format!("Error receiving chunk from LLM stream: {:?}", e); - tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); - // Log error in parent span - if let Some(parent) = &parent_for_tool_spans { - if let Some(client) = &*BRAINTRUST_CLIENT { - // Create error info - let error_info = serde_json::json!({ - "error": format!("Error in stream: {:?}", e) - }); - - // Log error as output to parent span - let error_span = parent.clone().with_output(error_info); - - // Log span non-blockingly (client handles the background processing) - if let Err(log_err) = client.log_span(error_span).await { - error!("Failed to log stream error span: {}", log_err); + // Check if this is the final chunk + if chunk.choices[0].finish_reason.is_some() { + _is_complete = true; + // Don't break here yet, let the loop condition handle it } } + Err(e) => { + // --- Updated Error Handling for Retries --- + let error_message = format!( + "Error receiving chunk from LLM stream after multiple retries: {:?}", + e + ); + tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); + // Log error in span + if let Some(parent_span) = parent_span.clone() { + if let Some(client) = &*BRAINTRUST_CLIENT { + let error_span = parent_span.with_output(serde_json::json!({ + "error": format!("Error in stream after retries: {:?}", e) + })); + + // Log span non-blockingly (client handles the background processing) + if let Err(log_err) = client.log_span(error_span).await { + error!("Failed to log stream error span: {}", log_err); + } + } + } + // --- End Updated Error Handling --- + // Send error over broadcast channel before returning + let agent_error = AgentError(error_message.clone()); + if let Ok(sender) = agent.get_stream_sender().await { + if let Err(send_err) = sender.send(Err(agent_error)) { + warn!("Failed to send stream error over channel: {}", send_err); + } + } else { + warn!("Stream sender not available for sending error."); + } + return Err(anyhow::anyhow!(error_message)); // Return immediately + } } - // --- End Added Error Handling --- - return Err(anyhow::anyhow!(error_message)); // Return immediately + } + Ok(None) => { // Stream closed gracefully + break; + } + Err(_) => { // Timeout occurred + let timeout_msg = format!( + "LLM stream timed out after {} seconds of inactivity.", + STREAM_TIMEOUT_SECS + ); + warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg); + + // Send timeout error over broadcast channel + let agent_error = AgentError(timeout_msg.clone()); + if let Ok(sender) = agent.get_stream_sender().await { + if let Err(send_err) = sender.send(Err(agent_error)) { + warn!("Failed to send timeout error over channel: {}", send_err); + } + } else { + warn!("Stream sender not available for sending timeout error."); + } + // We could return an error here, or just break and let the agent finish + // For now, let's break and proceed with whatever was buffered + break; } } } @@ -932,13 +987,13 @@ impl Agent { { Ok(r) => r, Err(e) => { - // --- Added Error Handling --- + // --- Updated Error Handling for Retries --- let error_message = format!( "Tool execution error for {}: {:?}", tool_call.function.name, e ); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message); - // Log error in tool span + // Log error in span if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { let error_info = serde_json::json!({ @@ -957,7 +1012,7 @@ impl Agent { } } } - // --- End Added Error Handling --- + // --- End Updated Error Handling --- let error_message = format!( "Tool execution error for {}: {:?}", tool_call.function.name, e