agent retry system a little more resilient

This commit is contained in:
dal 2025-04-26 10:47:58 -06:00
parent 1f6e757b1b
commit 8ad1801d9d
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 133 additions and 76 deletions

View File

@ -106,6 +106,7 @@ rayon = "1.10.0"
diesel_migrations = "2.0.0" diesel_migrations = "2.0.0"
html-escape = "0.2.13" html-escape = "0.2.13"
tokio-cron-scheduler = "0.13.0" tokio-cron-scheduler = "0.13.0"
tokio-retry = "0.3.0"
[profile.release] [profile.release]
debug = false debug = false

View File

@ -32,6 +32,7 @@ redis = { workspace = true }
reqwest = { workspace = true } reqwest = { workspace = true }
sqlx = { workspace = true } sqlx = { workspace = true }
stored_values = { path = "../stored_values" } stored_values = { path = "../stored_values" }
tokio-retry = { workspace = true }
# Development dependencies # Development dependencies
[dev-dependencies] [dev-dependencies]

View File

@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
use anyhow::Result; use anyhow::Result;
use braintrust::{BraintrustClient, TraceBuilder}; use braintrust::{BraintrustClient, TraceBuilder};
use litellm::{ use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice, MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
}; };
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde_json::Value; use serde_json::Value;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc}; use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::error; use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition};
use tracing::{error, info, warn};
use uuid::Uuid; use uuid::Uuid;
// Type definition for tool registry to simplify complex type // Type definition for tool registry to simplify complex type
@ -687,22 +688,34 @@ impl Agent {
..Default::default() ..Default::default()
}; };
// Get the streaming response from the LLM // --- Retry Logic for Initial Stream Request ---
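let mut stream_rx = match agent let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Up to 3 retries. NOTE(review): tokio-retry grows the delay as base^n, so from_millis(100) yields ~100ms, ~10s, ~1000s (not ~100/200/400ms) — confirm intended backoff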
let stream_rx_result = Retry::spawn(retry_strategy, || {
let agent_clone = agent.clone(); // Clone Arc for use in the closure
let request_clone = request.clone(); // Clone request for use in the closure
async move {
agent_clone
.llm_client .llm_client
.stream_chat_completion(request.clone()) .stream_chat_completion(request_clone)
.await .await
{ }
})
.await;
// --- End Retry Logic ---
// Get the streaming response from the LLM
let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
Ok(rx) => rx, Ok(rx) => rx,
Err(e) => { Err(e) => {
// --- Added Error Handling --- // --- Updated Error Handling for Retries ---
let error_message = format!("Error starting LLM stream: {:?}", e); let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in span // Log error in span
if let Some(parent_span) = parent_span.clone() { if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT { if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({ let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error starting stream: {:?}", e) "error": format!("Error starting stream after retries: {:?}", e)
})); }));
// Log span non-blockingly (client handles the background processing) // Log span non-blockingly (client handles the background processing)
@ -711,20 +724,27 @@ impl Agent {
} }
} }
} }
// --- End Added Error Handling --- // --- End Updated Error Handling ---
return Err(anyhow::anyhow!(error_message)); // Return immediately return Err(anyhow::anyhow!(error_message)); // Return immediately
} }
}; };
// We store the parent span to use for creating individual tool spans // We store the parent span to use for creating individual tool spans
// This avoids creating a general assistant span that would never be completed
let parent_for_tool_spans = parent_span.clone(); let parent_for_tool_spans = parent_span.clone();
// Process the streaming chunks // Process the streaming chunks
let mut buffer = MessageBuffer::new(); let mut buffer = MessageBuffer::new();
let mut _is_complete = false; let mut _is_complete = false;
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
while let Some(chunk_result) = stream_rx.recv().await { loop { // Replaced `while let` with `loop` and explicit timeout
match tokio::time::timeout(
Duration::from_secs(STREAM_TIMEOUT_SECS),
stream_rx.recv(),
)
.await
{
Ok(Some(chunk_result)) => { // Received a message within timeout
match chunk_result { match chunk_result {
Ok(chunk) => { Ok(chunk) => {
if chunk.choices.is_empty() { if chunk.choices.is_empty() {
@ -752,7 +772,8 @@ impl Agent {
}); });
// Get or create the pending tool call // Get or create the pending tool call
let pending_call = buffer.tool_calls.entry(id.clone()).or_default(); let pending_call =
buffer.tool_calls.entry(id.clone()).or_default();
// Update the pending call with the delta // Update the pending call with the delta
pending_call.update_from_delta(tool_call); pending_call.update_from_delta(tool_call);
@ -767,22 +788,22 @@ impl Agent {
// Check if this is the final chunk // Check if this is the final chunk
if chunk.choices[0].finish_reason.is_some() { if chunk.choices[0].finish_reason.is_some() {
_is_complete = true; _is_complete = true;
// Don't break here; the loop exits when the stream closes (Ok(None)) or the inactivity timeout fires
} }
} }
Err(e) => { Err(e) => {
// --- Added Error Handling --- // --- Updated Error Handling for Retries ---
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e); let error_message = format!(
"Error receiving chunk from LLM stream after multiple retries: {:?}",
e
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in parent span // Log error in span
if let Some(parent) = &parent_for_tool_spans { if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT { if let Some(client) = &*BRAINTRUST_CLIENT {
// Create error info let error_span = parent_span.with_output(serde_json::json!({
let error_info = serde_json::json!({ "error": format!("Error in stream after retries: {:?}", e)
"error": format!("Error in stream: {:?}", e) }));
});
// Log error as output to parent span
let error_span = parent.clone().with_output(error_info);
// Log span non-blockingly (client handles the background processing) // Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await { if let Err(log_err) = client.log_span(error_span).await {
@ -790,11 +811,45 @@ impl Agent {
} }
} }
} }
// --- End Added Error Handling --- // --- End Updated Error Handling ---
// Send error over broadcast channel before returning
let agent_error = AgentError(error_message.clone());
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send stream error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending error.");
}
return Err(anyhow::anyhow!(error_message)); // Return immediately return Err(anyhow::anyhow!(error_message)); // Return immediately
} }
} }
} }
Ok(None) => { // Stream closed gracefully
break;
}
Err(_) => { // Timeout occurred
let timeout_msg = format!(
"LLM stream timed out after {} seconds of inactivity.",
STREAM_TIMEOUT_SECS
);
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
// Send timeout error over broadcast channel
let agent_error = AgentError(timeout_msg.clone());
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send timeout error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending timeout error.");
}
// We could return an error here, or just break and let the agent finish
// For now, let's break and proceed with whatever was buffered
break;
}
}
}
// Flush any remaining buffered content or tool calls before creating final message // Flush any remaining buffered content or tool calls before creating final message
buffer.flush(&agent).await?; buffer.flush(&agent).await?;
@ -932,13 +987,13 @@ impl Agent {
{ {
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
// --- Added Error Handling --- // --- Updated Error Handling for Retries ---
let error_message = format!( let error_message = format!(
"Tool execution error for {}: {:?}", "Tool execution error for {}: {:?}",
tool_call.function.name, e tool_call.function.name, e
); );
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message); tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
// Log error in tool span // Log error in span
if let Some(tool_span) = &tool_span { if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT { if let Some(client) = &*BRAINTRUST_CLIENT {
let error_info = serde_json::json!({ let error_info = serde_json::json!({
@ -957,7 +1012,7 @@ impl Agent {
} }
} }
} }
// --- End Added Error Handling --- // --- End Updated Error Handling ---
let error_message = format!( let error_message = format!(
"Tool execution error for {}: {:?}", "Tool execution error for {}: {:?}",
tool_call.function.name, e tool_call.function.name, e