mirror of https://github.com/buster-so/buster.git
agent retry system a little more resilient
This commit is contained in:
parent
1f6e757b1b
commit
8ad1801d9d
|
@ -106,6 +106,7 @@ rayon = "1.10.0"
|
|||
diesel_migrations = "2.0.0"
|
||||
html-escape = "0.2.13"
|
||||
tokio-cron-scheduler = "0.13.0"
|
||||
tokio-retry = "0.3.0"
|
||||
|
||||
[profile.release]
|
||||
debug = false
|
||||
|
|
|
@ -32,6 +32,7 @@ redis = { workspace = true }
|
|||
reqwest = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
stored_values = { path = "../stored_values" }
|
||||
tokio-retry = { workspace = true }
|
||||
|
||||
# Development dependencies
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
|
|||
use anyhow::Result;
|
||||
use braintrust::{BraintrustClient, TraceBuilder};
|
||||
use litellm::{
|
||||
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde_json::Value;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::error;
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tokio_retry::{strategy::ExponentialBackoff, Retry, Condition};
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
// Type definition for tool registry to simplify complex type
|
||||
|
@ -661,7 +662,7 @@ impl Agent {
|
|||
// --- End Prepare LLM Messages ---
|
||||
|
||||
// Collect all enabled tools and their schemas
|
||||
let tools = agent.get_enabled_tools().await;
|
||||
let tools = agent.get_enabled_tools().await;
|
||||
|
||||
// Get user message for logging (unchanged)
|
||||
let _user_message = thread_ref
|
||||
|
@ -673,7 +674,7 @@ impl Agent {
|
|||
// Create the tool-enabled request
|
||||
let request = ChatCompletionRequest {
|
||||
model: mode_config.model, // Use the model from mode config
|
||||
messages: llm_messages,
|
||||
messages: llm_messages,
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
tool_choice: Some(ToolChoice::Required), // Or adjust based on mode?
|
||||
stream: Some(true), // Enable streaming
|
||||
|
@ -687,22 +688,34 @@ impl Agent {
|
|||
..Default::default()
|
||||
};
|
||||
|
||||
// --- Retry Logic for Initial Stream Request ---
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times. NOTE: tokio-retry raises the base to the n-th power, so the delays are ~100ms, ~10s, ~1000s (not ~100ms, ~200ms, ~400ms)
|
||||
|
||||
let stream_rx_result = Retry::spawn(retry_strategy, || {
|
||||
let agent_clone = agent.clone(); // Clone Arc for use in the closure
|
||||
let request_clone = request.clone(); // Clone request for use in the closure
|
||||
async move {
|
||||
agent_clone
|
||||
.llm_client
|
||||
.stream_chat_completion(request_clone)
|
||||
.await
|
||||
}
|
||||
})
|
||||
.await;
|
||||
// --- End Retry Logic ---
|
||||
|
||||
// Get the streaming response from the LLM
|
||||
let mut stream_rx = match agent
|
||||
.llm_client
|
||||
.stream_chat_completion(request.clone())
|
||||
.await
|
||||
{
|
||||
let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
// --- Added Error Handling ---
|
||||
let error_message = format!("Error starting LLM stream: {:?}", e);
|
||||
// --- Updated Error Handling for Retries ---
|
||||
let error_message = format!("Error starting LLM stream after multiple retries: {:?}", e);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log error in span
|
||||
if let Some(parent_span) = parent_span.clone() {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_span = parent_span.with_output(serde_json::json!({
|
||||
"error": format!("Error starting stream: {:?}", e)
|
||||
"error": format!("Error starting stream after retries: {:?}", e)
|
||||
}));
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
|
@ -711,87 +724,129 @@ impl Agent {
|
|||
}
|
||||
}
|
||||
}
|
||||
// --- End Added Error Handling ---
|
||||
// --- End Updated Error Handling ---
|
||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||
}
|
||||
};
|
||||
|
||||
// We store the parent span to use for creating individual tool spans
|
||||
// This avoids creating a general assistant span that would never be completed
|
||||
let parent_for_tool_spans = parent_span.clone();
|
||||
|
||||
// Process the streaming chunks
|
||||
let mut buffer = MessageBuffer::new();
|
||||
let mut _is_complete = false;
|
||||
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
|
||||
|
||||
while let Some(chunk_result) = stream_rx.recv().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if chunk.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
loop { // Replaced `while let` with `loop` and explicit timeout
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(STREAM_TIMEOUT_SECS),
|
||||
stream_rx.recv(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some(chunk_result)) => { // Received a message within timeout
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if chunk.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
buffer.message_id = Some(chunk.id.clone());
|
||||
let delta = &chunk.choices[0].delta;
|
||||
buffer.message_id = Some(chunk.id.clone());
|
||||
let delta = &chunk.choices[0].delta;
|
||||
|
||||
// Accumulate content if present
|
||||
if let Some(content) = &delta.content {
|
||||
buffer.content.push_str(content);
|
||||
}
|
||||
// Accumulate content if present
|
||||
if let Some(content) = &delta.content {
|
||||
buffer.content.push_str(content);
|
||||
}
|
||||
|
||||
// Process tool calls if present
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
buffer
|
||||
.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||
});
|
||||
// Process tool calls if present
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
buffer
|
||||
.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||
});
|
||||
|
||||
// Get or create the pending tool call
|
||||
let pending_call = buffer.tool_calls.entry(id.clone()).or_default();
|
||||
// Get or create the pending tool call
|
||||
let pending_call =
|
||||
buffer.tool_calls.entry(id.clone()).or_default();
|
||||
|
||||
// Update the pending call with the delta
|
||||
pending_call.update_from_delta(tool_call);
|
||||
}
|
||||
}
|
||||
// Update the pending call with the delta
|
||||
pending_call.update_from_delta(tool_call);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we should flush the buffer
|
||||
if buffer.should_flush() {
|
||||
buffer.flush(&agent).await?;
|
||||
}
|
||||
// Check if we should flush the buffer
|
||||
if buffer.should_flush() {
|
||||
buffer.flush(&agent).await?;
|
||||
}
|
||||
|
||||
// Check if this is the final chunk
|
||||
if chunk.choices[0].finish_reason.is_some() {
|
||||
_is_complete = true;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// --- Added Error Handling ---
|
||||
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log error in parent span
|
||||
if let Some(parent) = &parent_for_tool_spans {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
// Create error info
|
||||
let error_info = serde_json::json!({
|
||||
"error": format!("Error in stream: {:?}", e)
|
||||
});
|
||||
|
||||
// Log error as output to parent span
|
||||
let error_span = parent.clone().with_output(error_info);
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!("Failed to log stream error span: {}", log_err);
|
||||
// Check if this is the final chunk
|
||||
if chunk.choices[0].finish_reason.is_some() {
|
||||
_is_complete = true;
|
||||
// Don't break here; the loop exits once the stream closes (the `Ok(None)` arm below)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// --- Updated Error Handling for Retries ---
|
||||
let error_message = format!(
|
||||
"Error receiving chunk from LLM stream after multiple retries: {:?}",
|
||||
e
|
||||
);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log error in span
|
||||
if let Some(parent_span) = parent_span.clone() {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_span = parent_span.with_output(serde_json::json!({
|
||||
"error": format!("Error in stream after retries: {:?}", e)
|
||||
}));
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!("Failed to log stream error span: {}", log_err);
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- End Updated Error Handling ---
|
||||
// Send error over broadcast channel before returning
|
||||
let agent_error = AgentError(error_message.clone());
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||
warn!("Failed to send stream error over channel: {}", send_err);
|
||||
}
|
||||
} else {
|
||||
warn!("Stream sender not available for sending error.");
|
||||
}
|
||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||
}
|
||||
}
|
||||
// --- End Added Error Handling ---
|
||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||
}
|
||||
Ok(None) => { // Stream closed gracefully
|
||||
break;
|
||||
}
|
||||
Err(_) => { // Timeout occurred
|
||||
let timeout_msg = format!(
|
||||
"LLM stream timed out after {} seconds of inactivity.",
|
||||
STREAM_TIMEOUT_SECS
|
||||
);
|
||||
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
|
||||
|
||||
// Send timeout error over broadcast channel
|
||||
let agent_error = AgentError(timeout_msg.clone());
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||
warn!("Failed to send timeout error over channel: {}", send_err);
|
||||
}
|
||||
} else {
|
||||
warn!("Stream sender not available for sending timeout error.");
|
||||
}
|
||||
// We could return an error here, or just break and let the agent finish
|
||||
// For now, let's break and proceed with whatever was buffered
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -932,13 +987,13 @@ impl Agent {
|
|||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
// --- Added Error Handling ---
|
||||
// --- Updated Error Handling for Retries ---
|
||||
let error_message = format!(
|
||||
"Tool execution error for {}: {:?}",
|
||||
tool_call.function.name, e
|
||||
);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
|
||||
// Log error in tool span
|
||||
// Log error in span
|
||||
if let Some(tool_span) = &tool_span {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_info = serde_json::json!({
|
||||
|
@ -957,7 +1012,7 @@ impl Agent {
|
|||
}
|
||||
}
|
||||
}
|
||||
// --- End Added Error Handling ---
|
||||
// --- End Updated Error Handling ---
|
||||
let error_message = format!(
|
||||
"Tool execution error for {}: {:?}",
|
||||
tool_call.function.name, e
|
||||
|
|
Loading…
Reference in New Issue