mirror of https://github.com/buster-so/buster.git
agent retry syste a little more resilient
This commit is contained in:
parent
1f6e757b1b
commit
8ad1801d9d
|
@ -106,6 +106,7 @@ rayon = "1.10.0"
|
||||||
diesel_migrations = "2.0.0"
|
diesel_migrations = "2.0.0"
|
||||||
html-escape = "0.2.13"
|
html-escape = "0.2.13"
|
||||||
tokio-cron-scheduler = "0.13.0"
|
tokio-cron-scheduler = "0.13.0"
|
||||||
|
tokio-retry = "0.3.0"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = false
|
debug = false
|
||||||
|
|
|
@ -32,6 +32,7 @@ redis = { workspace = true }
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
sqlx = { workspace = true }
|
sqlx = { workspace = true }
|
||||||
stored_values = { path = "../stored_values" }
|
stored_values = { path = "../stored_values" }
|
||||||
|
tokio-retry = { workspace = true }
|
||||||
|
|
||||||
# Development dependencies
|
# Development dependencies
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use braintrust::{BraintrustClient, TraceBuilder};
|
use braintrust::{BraintrustClient, TraceBuilder};
|
||||||
use litellm::{
|
use litellm::{
|
||||||
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||||
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||||
};
|
};
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{collections::HashMap, env, sync::Arc};
|
use std::{collections::HashMap, env, sync::Arc};
|
||||||
use tokio::sync::{broadcast, RwLock};
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
use tracing::error;
|
use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition};
|
||||||
|
use tracing::{error, info, warn};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
// Type definition for tool registry to simplify complex type
|
// Type definition for tool registry to simplify complex type
|
||||||
|
@ -687,22 +688,34 @@ impl Agent {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// --- Retry Logic for Initial Stream Request ---
|
||||||
|
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms
|
||||||
|
|
||||||
|
let stream_rx_result = Retry::spawn(retry_strategy, || {
|
||||||
|
let agent_clone = agent.clone(); // Clone Arc for use in the closure
|
||||||
|
let request_clone = request.clone(); // Clone request for use in the closure
|
||||||
|
async move {
|
||||||
|
agent_clone
|
||||||
|
.llm_client
|
||||||
|
.stream_chat_completion(request_clone)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
// --- End Retry Logic ---
|
||||||
|
|
||||||
// Get the streaming response from the LLM
|
// Get the streaming response from the LLM
|
||||||
let mut stream_rx = match agent
|
let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
|
||||||
.llm_client
|
|
||||||
.stream_chat_completion(request.clone())
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(rx) => rx,
|
Ok(rx) => rx,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// --- Added Error Handling ---
|
// --- Updated Error Handling for Retries ---
|
||||||
let error_message = format!("Error starting LLM stream: {:?}", e);
|
let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e);
|
||||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||||
// Log error in span
|
// Log error in span
|
||||||
if let Some(parent_span) = parent_span.clone() {
|
if let Some(parent_span) = parent_span.clone() {
|
||||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||||
let error_span = parent_span.with_output(serde_json::json!({
|
let error_span = parent_span.with_output(serde_json::json!({
|
||||||
"error": format!("Error starting stream: {:?}", e)
|
"error": format!("Error starting stream after retries: {:?}", e)
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Log span non-blockingly (client handles the background processing)
|
// Log span non-blockingly (client handles the background processing)
|
||||||
|
@ -711,87 +724,129 @@ impl Agent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// --- End Added Error Handling ---
|
// --- End Updated Error Handling ---
|
||||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// We store the parent span to use for creating individual tool spans
|
// We store the parent span to use for creating individual tool spans
|
||||||
// This avoids creating a general assistant span that would never be completed
|
|
||||||
let parent_for_tool_spans = parent_span.clone();
|
let parent_for_tool_spans = parent_span.clone();
|
||||||
|
|
||||||
// Process the streaming chunks
|
// Process the streaming chunks
|
||||||
let mut buffer = MessageBuffer::new();
|
let mut buffer = MessageBuffer::new();
|
||||||
let mut _is_complete = false;
|
let mut _is_complete = false;
|
||||||
|
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
|
||||||
|
|
||||||
while let Some(chunk_result) = stream_rx.recv().await {
|
loop { // Replaced `while let` with `loop` and explicit timeout
|
||||||
match chunk_result {
|
match tokio::time::timeout(
|
||||||
Ok(chunk) => {
|
Duration::from_secs(STREAM_TIMEOUT_SECS),
|
||||||
if chunk.choices.is_empty() {
|
stream_rx.recv(),
|
||||||
continue;
|
)
|
||||||
}
|
.await
|
||||||
|
{
|
||||||
|
Ok(Some(chunk_result)) => { // Received a message within timeout
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
if chunk.choices.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
buffer.message_id = Some(chunk.id.clone());
|
buffer.message_id = Some(chunk.id.clone());
|
||||||
let delta = &chunk.choices[0].delta;
|
let delta = &chunk.choices[0].delta;
|
||||||
|
|
||||||
// Accumulate content if present
|
// Accumulate content if present
|
||||||
if let Some(content) = &delta.content {
|
if let Some(content) = &delta.content {
|
||||||
buffer.content.push_str(content);
|
buffer.content.push_str(content);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process tool calls if present
|
// Process tool calls if present
|
||||||
if let Some(tool_calls) = &delta.tool_calls {
|
if let Some(tool_calls) = &delta.tool_calls {
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||||
buffer
|
buffer
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.keys()
|
.keys()
|
||||||
.next()
|
.next()
|
||||||
.cloned()
|
.cloned()
|
||||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||||
});
|
});
|
||||||
|
|
||||||
// Get or create the pending tool call
|
// Get or create the pending tool call
|
||||||
let pending_call = buffer.tool_calls.entry(id.clone()).or_default();
|
let pending_call =
|
||||||
|
buffer.tool_calls.entry(id.clone()).or_default();
|
||||||
|
|
||||||
// Update the pending call with the delta
|
// Update the pending call with the delta
|
||||||
pending_call.update_from_delta(tool_call);
|
pending_call.update_from_delta(tool_call);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we should flush the buffer
|
// Check if we should flush the buffer
|
||||||
if buffer.should_flush() {
|
if buffer.should_flush() {
|
||||||
buffer.flush(&agent).await?;
|
buffer.flush(&agent).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is the final chunk
|
// Check if this is the final chunk
|
||||||
if chunk.choices[0].finish_reason.is_some() {
|
if chunk.choices[0].finish_reason.is_some() {
|
||||||
_is_complete = true;
|
_is_complete = true;
|
||||||
}
|
// Don't break here yet, let the loop condition handle it
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
// --- Added Error Handling ---
|
|
||||||
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e);
|
|
||||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
|
||||||
// Log error in parent span
|
|
||||||
if let Some(parent) = &parent_for_tool_spans {
|
|
||||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
|
||||||
// Create error info
|
|
||||||
let error_info = serde_json::json!({
|
|
||||||
"error": format!("Error in stream: {:?}", e)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Log error as output to parent span
|
|
||||||
let error_span = parent.clone().with_output(error_info);
|
|
||||||
|
|
||||||
// Log span non-blockingly (client handles the background processing)
|
|
||||||
if let Err(log_err) = client.log_span(error_span).await {
|
|
||||||
error!("Failed to log stream error span: {}", log_err);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// --- Updated Error Handling for Retries ---
|
||||||
|
let error_message = format!(
|
||||||
|
"Error receiving chunk from LLM stream after multiple retries: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||||
|
// Log error in span
|
||||||
|
if let Some(parent_span) = parent_span.clone() {
|
||||||
|
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||||
|
let error_span = parent_span.with_output(serde_json::json!({
|
||||||
|
"error": format!("Error in stream after retries: {:?}", e)
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Log span non-blockingly (client handles the background processing)
|
||||||
|
if let Err(log_err) = client.log_span(error_span).await {
|
||||||
|
error!("Failed to log stream error span: {}", log_err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// --- End Updated Error Handling ---
|
||||||
|
// Send error over broadcast channel before returning
|
||||||
|
let agent_error = AgentError(error_message.clone());
|
||||||
|
if let Ok(sender) = agent.get_stream_sender().await {
|
||||||
|
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||||
|
warn!("Failed to send stream error over channel: {}", send_err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Stream sender not available for sending error.");
|
||||||
|
}
|
||||||
|
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// --- End Added Error Handling ---
|
}
|
||||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
Ok(None) => { // Stream closed gracefully
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(_) => { // Timeout occurred
|
||||||
|
let timeout_msg = format!(
|
||||||
|
"LLM stream timed out after {} seconds of inactivity.",
|
||||||
|
STREAM_TIMEOUT_SECS
|
||||||
|
);
|
||||||
|
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
|
||||||
|
|
||||||
|
// Send timeout error over broadcast channel
|
||||||
|
let agent_error = AgentError(timeout_msg.clone());
|
||||||
|
if let Ok(sender) = agent.get_stream_sender().await {
|
||||||
|
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||||
|
warn!("Failed to send timeout error over channel: {}", send_err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Stream sender not available for sending timeout error.");
|
||||||
|
}
|
||||||
|
// We could return an error here, or just break and let the agent finish
|
||||||
|
// For now, let's break and proceed with whatever was buffered
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -932,13 +987,13 @@ impl Agent {
|
||||||
{
|
{
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// --- Added Error Handling ---
|
// --- Updated Error Handling for Retries ---
|
||||||
let error_message = format!(
|
let error_message = format!(
|
||||||
"Tool execution error for {}: {:?}",
|
"Tool execution error for {}: {:?}",
|
||||||
tool_call.function.name, e
|
tool_call.function.name, e
|
||||||
);
|
);
|
||||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
|
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
|
||||||
// Log error in tool span
|
// Log error in span
|
||||||
if let Some(tool_span) = &tool_span {
|
if let Some(tool_span) = &tool_span {
|
||||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||||
let error_info = serde_json::json!({
|
let error_info = serde_json::json!({
|
||||||
|
@ -957,7 +1012,7 @@ impl Agent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// --- End Added Error Handling ---
|
// --- End Updated Error Handling ---
|
||||||
let error_message = format!(
|
let error_message = format!(
|
||||||
"Tool execution error for {}: {:?}",
|
"Tool execution error for {}: {:?}",
|
||||||
tool_call.function.name, e
|
tool_call.function.name, e
|
||||||
|
|
Loading…
Reference in New Issue