Make agent retry system a little more resilient

This commit is contained in:
dal 2025-04-26 10:47:58 -06:00
parent 1f6e757b1b
commit 8ad1801d9d
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 133 additions and 76 deletions

View File

@ -106,6 +106,7 @@ rayon = "1.10.0"
diesel_migrations = "2.0.0" diesel_migrations = "2.0.0"
html-escape = "0.2.13" html-escape = "0.2.13"
tokio-cron-scheduler = "0.13.0" tokio-cron-scheduler = "0.13.0"
tokio-retry = "0.3.0"
[profile.release] [profile.release]
debug = false debug = false

View File

@ -32,6 +32,7 @@ redis = { workspace = true }
reqwest = { workspace = true } reqwest = { workspace = true }
sqlx = { workspace = true } sqlx = { workspace = true }
stored_values = { path = "../stored_values" } stored_values = { path = "../stored_values" }
tokio-retry = { workspace = true }
# Development dependencies # Development dependencies
[dev-dependencies] [dev-dependencies]

View File

@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
use anyhow::Result; use anyhow::Result;
use braintrust::{BraintrustClient, TraceBuilder}; use braintrust::{BraintrustClient, TraceBuilder};
use litellm::{ use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice, MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
}; };
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde_json::Value; use serde_json::Value;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc}; use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::error; use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition};
use tracing::{error, info, warn};
use uuid::Uuid; use uuid::Uuid;
// Type definition for tool registry to simplify complex type // Type definition for tool registry to simplify complex type
@ -661,7 +662,7 @@ impl Agent {
// --- End Prepare LLM Messages --- // --- End Prepare LLM Messages ---
// Collect all enabled tools and their schemas // Collect all enabled tools and their schemas
let tools = agent.get_enabled_tools().await; let tools = agent.get_enabled_tools().await;
// Get user message for logging (unchanged) // Get user message for logging (unchanged)
let _user_message = thread_ref let _user_message = thread_ref
@ -673,7 +674,7 @@ impl Agent {
// Create the tool-enabled request // Create the tool-enabled request
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: mode_config.model, // Use the model from mode config model: mode_config.model, // Use the model from mode config
messages: llm_messages, messages: llm_messages,
tools: if tools.is_empty() { None } else { Some(tools) }, tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Required), // Or adjust based on mode? tool_choice: Some(ToolChoice::Required), // Or adjust based on mode?
stream: Some(true), // Enable streaming stream: Some(true), // Enable streaming
@ -687,22 +688,34 @@ impl Agent {
..Default::default() ..Default::default()
}; };
// --- Retry Logic for Initial Stream Request ---
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms
let stream_rx_result = Retry::spawn(retry_strategy, || {
let agent_clone = agent.clone(); // Clone Arc for use in the closure
let request_clone = request.clone(); // Clone request for use in the closure
async move {
agent_clone
.llm_client
.stream_chat_completion(request_clone)
.await
}
})
.await;
// --- End Retry Logic ---
// Get the streaming response from the LLM // Get the streaming response from the LLM
let mut stream_rx = match agent let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
.llm_client
.stream_chat_completion(request.clone())
.await
{
Ok(rx) => rx, Ok(rx) => rx,
Err(e) => { Err(e) => {
// --- Added Error Handling --- // --- Updated Error Handling for Retries ---
let error_message = format!("Error starting LLM stream: {:?}", e); let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in span // Log error in span
if let Some(parent_span) = parent_span.clone() { if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT { if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({ let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error starting stream: {:?}", e) "error": format!("Error starting stream after retries: {:?}", e)
})); }));
// Log span non-blockingly (client handles the background processing) // Log span non-blockingly (client handles the background processing)
@ -711,87 +724,129 @@ impl Agent {
} }
} }
} }
// --- End Added Error Handling --- // --- End Updated Error Handling ---
return Err(anyhow::anyhow!(error_message)); // Return immediately return Err(anyhow::anyhow!(error_message)); // Return immediately
} }
}; };
// We store the parent span to use for creating individual tool spans // We store the parent span to use for creating individual tool spans
// This avoids creating a general assistant span that would never be completed
let parent_for_tool_spans = parent_span.clone(); let parent_for_tool_spans = parent_span.clone();
// Process the streaming chunks // Process the streaming chunks
let mut buffer = MessageBuffer::new(); let mut buffer = MessageBuffer::new();
let mut _is_complete = false; let mut _is_complete = false;
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
while let Some(chunk_result) = stream_rx.recv().await { loop { // Replaced `while let` with `loop` and explicit timeout
match chunk_result { match tokio::time::timeout(
Ok(chunk) => { Duration::from_secs(STREAM_TIMEOUT_SECS),
if chunk.choices.is_empty() { stream_rx.recv(),
continue; )
} .await
{
Ok(Some(chunk_result)) => { // Received a message within timeout
match chunk_result {
Ok(chunk) => {
if chunk.choices.is_empty() {
continue;
}
buffer.message_id = Some(chunk.id.clone()); buffer.message_id = Some(chunk.id.clone());
let delta = &chunk.choices[0].delta; let delta = &chunk.choices[0].delta;
// Accumulate content if present // Accumulate content if present
if let Some(content) = &delta.content { if let Some(content) = &delta.content {
buffer.content.push_str(content); buffer.content.push_str(content);
} }
// Process tool calls if present // Process tool calls if present
if let Some(tool_calls) = &delta.tool_calls { if let Some(tool_calls) = &delta.tool_calls {
for tool_call in tool_calls { for tool_call in tool_calls {
let id = tool_call.id.clone().unwrap_or_else(|| { let id = tool_call.id.clone().unwrap_or_else(|| {
buffer buffer
.tool_calls .tool_calls
.keys() .keys()
.next() .next()
.cloned() .cloned()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
}); });
// Get or create the pending tool call // Get or create the pending tool call
let pending_call = buffer.tool_calls.entry(id.clone()).or_default(); let pending_call =
buffer.tool_calls.entry(id.clone()).or_default();
// Update the pending call with the delta // Update the pending call with the delta
pending_call.update_from_delta(tool_call); pending_call.update_from_delta(tool_call);
} }
} }
// Check if we should flush the buffer // Check if we should flush the buffer
if buffer.should_flush() { if buffer.should_flush() {
buffer.flush(&agent).await?; buffer.flush(&agent).await?;
} }
// Check if this is the final chunk // Check if this is the final chunk
if chunk.choices[0].finish_reason.is_some() { if chunk.choices[0].finish_reason.is_some() {
_is_complete = true; _is_complete = true;
} // Don't break here yet, let the loop condition handle it
}
Err(e) => {
// --- Added Error Handling ---
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in parent span
if let Some(parent) = &parent_for_tool_spans {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create error info
let error_info = serde_json::json!({
"error": format!("Error in stream: {:?}", e)
});
// Log error as output to parent span
let error_span = parent.clone().with_output(error_info);
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log stream error span: {}", log_err);
} }
} }
Err(e) => {
// --- Updated Error Handling for Retries ---
let error_message = format!(
"Error receiving chunk from LLM stream after multiple retries: {:?}",
e
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in span
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error in stream after retries: {:?}", e)
}));
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log stream error span: {}", log_err);
}
}
}
// --- End Updated Error Handling ---
// Send error over broadcast channel before returning
let agent_error = AgentError(error_message.clone());
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send stream error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending error.");
}
return Err(anyhow::anyhow!(error_message)); // Return immediately
}
} }
// --- End Added Error Handling --- }
return Err(anyhow::anyhow!(error_message)); // Return immediately Ok(None) => { // Stream closed gracefully
break;
}
Err(_) => { // Timeout occurred
let timeout_msg = format!(
"LLM stream timed out after {} seconds of inactivity.",
STREAM_TIMEOUT_SECS
);
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
// Send timeout error over broadcast channel
let agent_error = AgentError(timeout_msg.clone());
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send timeout error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending timeout error.");
}
// We could return an error here, or just break and let the agent finish
// For now, let's break and proceed with whatever was buffered
break;
} }
} }
} }
@ -932,13 +987,13 @@ impl Agent {
{ {
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
// --- Added Error Handling --- // --- Updated Error Handling for Retries ---
let error_message = format!( let error_message = format!(
"Tool execution error for {}: {:?}", "Tool execution error for {}: {:?}",
tool_call.function.name, e tool_call.function.name, e
); );
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
// Log error in tool span // Log error in span
if let Some(tool_span) = &tool_span { if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT { if let Some(client) = &*BRAINTRUST_CLIENT {
let error_info = serde_json::json!({ let error_info = serde_json::json!({
@ -957,7 +1012,7 @@ impl Agent {
} }
} }
} }
// --- End Added Error Handling --- // --- End Updated Error Handling ---
let error_message = format!( let error_message = format!(
"Tool execution error for {}: {:?}", "Tool execution error for {}: {:?}",
tool_call.function.name, e tool_call.function.name, e