Make the agent retry system a little more resilient

This commit is contained in:
dal 2025-04-26 10:47:58 -06:00
parent 1f6e757b1b
commit 8ad1801d9d
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 133 additions and 76 deletions

View File

@ -106,6 +106,7 @@ rayon = "1.10.0"
diesel_migrations = "2.0.0"
html-escape = "0.2.13"
tokio-cron-scheduler = "0.13.0"
tokio-retry = "0.3.0"
[profile.release]
debug = false

View File

@ -32,6 +32,7 @@ redis = { workspace = true }
reqwest = { workspace = true }
sqlx = { workspace = true }
stored_values = { path = "../stored_values" }
tokio-retry = { workspace = true }
# Development dependencies
[dev-dependencies]

View File

@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
use anyhow::Result;
use braintrust::{BraintrustClient, TraceBuilder};
use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
};
use once_cell::sync::Lazy;
use serde_json::Value;
use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use tracing::error;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition};
use tracing::{error, info, warn};
use uuid::Uuid;
// Type definition for tool registry to simplify complex type
@ -687,22 +688,34 @@ impl Agent {
..Default::default()
};
// Get the streaming response from the LLM
let mut stream_rx = match agent
// --- Retry Logic for Initial Stream Request ---
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms
let stream_rx_result = Retry::spawn(retry_strategy, || {
let agent_clone = agent.clone(); // Clone Arc for use in the closure
let request_clone = request.clone(); // Clone request for use in the closure
async move {
agent_clone
.llm_client
.stream_chat_completion(request.clone())
.stream_chat_completion(request_clone)
.await
{
}
})
.await;
// --- End Retry Logic ---
// Get the streaming response from the LLM
let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
Ok(rx) => rx,
Err(e) => {
// --- Added Error Handling ---
let error_message = format!("Error starting LLM stream: {:?}", e);
// --- Updated Error Handling for Retries ---
let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in span
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error starting stream: {:?}", e)
"error": format!("Error starting stream after retries: {:?}", e)
}));
// Log span non-blockingly (client handles the background processing)
@ -711,20 +724,27 @@ impl Agent {
}
}
}
// --- End Added Error Handling ---
// --- End Updated Error Handling ---
return Err(anyhow::anyhow!(error_message)); // Return immediately
}
};
// We store the parent span to use for creating individual tool spans
// This avoids creating a general assistant span that would never be completed
let parent_for_tool_spans = parent_span.clone();
// Process the streaming chunks
let mut buffer = MessageBuffer::new();
let mut _is_complete = false;
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
while let Some(chunk_result) = stream_rx.recv().await {
loop { // Replaced `while let` with `loop` and explicit timeout
match tokio::time::timeout(
Duration::from_secs(STREAM_TIMEOUT_SECS),
stream_rx.recv(),
)
.await
{
Ok(Some(chunk_result)) => { // Received a message within timeout
match chunk_result {
Ok(chunk) => {
if chunk.choices.is_empty() {
@ -752,7 +772,8 @@ impl Agent {
});
// Get or create the pending tool call
let pending_call = buffer.tool_calls.entry(id.clone()).or_default();
let pending_call =
buffer.tool_calls.entry(id.clone()).or_default();
// Update the pending call with the delta
pending_call.update_from_delta(tool_call);
@ -767,22 +788,22 @@ impl Agent {
// Check if this is the final chunk
if chunk.choices[0].finish_reason.is_some() {
_is_complete = true;
// Don't break here yet, let the loop condition handle it
}
}
Err(e) => {
// --- Added Error Handling ---
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e);
// --- Updated Error Handling for Retries ---
let error_message = format!(
"Error receiving chunk from LLM stream after multiple retries: {:?}",
e
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in parent span
if let Some(parent) = &parent_for_tool_spans {
// Log error in span
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create error info
let error_info = serde_json::json!({
"error": format!("Error in stream: {:?}", e)
});
// Log error as output to parent span
let error_span = parent.clone().with_output(error_info);
let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error in stream after retries: {:?}", e)
}));
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
@ -790,11 +811,45 @@ impl Agent {
}
}
}
// --- End Added Error Handling ---
// --- End Updated Error Handling ---
// Send error over broadcast channel before returning
let agent_error = AgentError(error_message.clone());
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send stream error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending error.");
}
return Err(anyhow::anyhow!(error_message)); // Return immediately
}
}
}
Ok(None) => { // Stream closed gracefully
break;
}
Err(_) => { // Timeout occurred
let timeout_msg = format!(
"LLM stream timed out after {} seconds of inactivity.",
STREAM_TIMEOUT_SECS
);
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
// Send timeout error over broadcast channel
let agent_error = AgentError(timeout_msg.clone());
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send timeout error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending timeout error.");
}
// We could return an error here, or just break and let the agent finish
// For now, let's break and proceed with whatever was buffered
break;
}
}
}
// Flush any remaining buffered content or tool calls before creating final message
buffer.flush(&agent).await?;
@ -932,13 +987,13 @@ impl Agent {
{
Ok(r) => r,
Err(e) => {
// --- Added Error Handling ---
// --- Updated Error Handling for Retries ---
let error_message = format!(
"Tool execution error for {}: {:?}",
tool_call.function.name, e
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
// Log error in tool span
// Log error in span
if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_info = serde_json::json!({
@ -957,7 +1012,7 @@ impl Agent {
}
}
}
// --- End Added Error Handling ---
// --- End Updated Error Handling ---
let error_message = format!(
"Tool execution error for {}: {:?}",
tool_call.function.name, e