mirror of https://github.com/buster-so/buster.git
Merge pull request #239 from buster-so/dal/agent-resiliency
Dal/agent resiliency
This commit is contained in:
commit
d06d3deadb
|
@ -106,6 +106,7 @@ rayon = "1.10.0"
|
|||
diesel_migrations = "2.0.0"
|
||||
html-escape = "0.2.13"
|
||||
tokio-cron-scheduler = "0.13.0"
|
||||
tokio-retry = "0.3.0"
|
||||
|
||||
[profile.release]
|
||||
debug = false
|
||||
|
|
|
@ -32,6 +32,8 @@ redis = { workspace = true }
|
|||
reqwest = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
stored_values = { path = "../stored_values" }
|
||||
tokio-retry = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
# Development dependencies
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
|
|||
use anyhow::Result;
|
||||
use braintrust::{BraintrustClient, TraceBuilder};
|
||||
use litellm::{
|
||||
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||
AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall,
|
||||
LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde_json::Value;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::error;
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tokio_retry::{strategy::ExponentialBackoff, Retry};
|
||||
use tracing::{error, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
// Type definition for tool registry to simplify complex type
|
||||
|
@ -37,6 +38,7 @@ static BRAINTRUST_CLIENT: Lazy<Option<Arc<BraintrustClient>>> = Lazy::new(|| {
|
|||
}
|
||||
});
|
||||
|
||||
// --- Reverted AgentError Struct ---
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentError(pub String);
|
||||
|
||||
|
@ -47,6 +49,7 @@ impl std::fmt::Display for AgentError {
|
|||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
// --- End Reverted AgentError Struct ---
|
||||
|
||||
type MessageResult = Result<AgentMessage, AgentError>;
|
||||
|
||||
|
@ -146,13 +149,16 @@ struct RegisteredTool {
|
|||
// The ToolRegistry type alias is no longer needed, but we still need this type for the tools map
|
||||
type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
|
||||
|
||||
// --- Define ModeProvider Trait ---
|
||||
// --- Define ModeProvider Trait ---
|
||||
#[async_trait::async_trait]
|
||||
pub trait ModeProvider {
|
||||
// Fetches the complete configuration for a given agent state
|
||||
async fn get_configuration_for_state(&self, state: &HashMap<String, Value>) -> Result<ModeConfiguration>;
|
||||
async fn get_configuration_for_state(
|
||||
&self,
|
||||
state: &HashMap<String, Value>,
|
||||
) -> Result<ModeConfiguration>;
|
||||
}
|
||||
// --- End ModeProvider Trait ---
|
||||
// --- End ModeProvider Trait ---
|
||||
|
||||
#[derive(Clone)]
|
||||
/// The Agent struct is responsible for managing conversations with the LLM
|
||||
|
@ -183,7 +189,7 @@ pub struct Agent {
|
|||
/// This will be managed by the ModeProvider now.
|
||||
terminating_tool_names: Arc<RwLock<Vec<String>>>,
|
||||
/// Provider for mode-specific logic (prompt, model, tools, termination)
|
||||
mode_provider: Arc<dyn ModeProvider + Send + Sync>,
|
||||
mode_provider: Arc<dyn ModeProvider + Send + Sync>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
|
@ -216,7 +222,7 @@ impl Agent {
|
|||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||
name,
|
||||
terminating_tool_names: Arc::new(RwLock::new(Vec::new())), // Initialize empty list
|
||||
mode_provider, // Store the provider
|
||||
mode_provider, // Store the provider
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,7 +249,7 @@ impl Agent {
|
|||
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown
|
||||
name,
|
||||
terminating_tool_names: Arc::new(RwLock::new(Vec::new())), // Sub-agent starts with empty term tools?
|
||||
mode_provider: Arc::clone(&mode_provider), // Share provider
|
||||
mode_provider: Arc::clone(&mode_provider), // Share provider
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -273,10 +279,14 @@ impl Agent {
|
|||
|
||||
/// Get a new receiver for the broadcast channel.
|
||||
/// Returns an error if the stream channel has been closed or was not initialized.
|
||||
pub async fn get_stream_receiver(&self) -> Result<broadcast::Receiver<MessageResult>, AgentError> {
|
||||
pub async fn get_stream_receiver(
|
||||
&self,
|
||||
) -> Result<broadcast::Receiver<MessageResult>, AgentError> {
|
||||
match self.stream_tx.read().await.as_ref() {
|
||||
Some(tx) => Ok(tx.subscribe()),
|
||||
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
|
||||
None => Err(AgentError(
|
||||
"Stream channel is closed or not initialized.".to_string(),
|
||||
)), // Use string error
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -285,7 +295,9 @@ impl Agent {
|
|||
pub async fn get_stream_sender(&self) -> Result<broadcast::Sender<MessageResult>, AgentError> {
|
||||
match self.stream_tx.read().await.as_ref() {
|
||||
Some(tx) => Ok(tx.clone()),
|
||||
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
|
||||
None => Err(AgentError(
|
||||
"Stream channel is closed or not initialized.".to_string(),
|
||||
)), // Use string error
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -477,12 +489,14 @@ impl Agent {
|
|||
tokio::select! {
|
||||
result = Agent::process_thread_with_depth(agent_arc_clone, thread_clone.clone(), &thread_clone, 0, None, None) => {
|
||||
if let Err(e) = result {
|
||||
// Log the error
|
||||
let err_msg = format!("Error processing thread: {:?}", e);
|
||||
error!("{}", err_msg); // Log the error
|
||||
error!("{}", err_msg);
|
||||
// Use the clone created before select!
|
||||
// Handle the Result from get_stream_sender
|
||||
let agent_error = AgentError(err_msg); // Use reverted struct
|
||||
if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await {
|
||||
if let Err(send_err) = sender.send(Err(AgentError(err_msg.clone()))) {
|
||||
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||
tracing::warn!("Failed to send error message to stream: {}", send_err);
|
||||
}
|
||||
} else {
|
||||
|
@ -532,7 +546,11 @@ impl Agent {
|
|||
});
|
||||
|
||||
// Handle the Result from get_stream_receiver
|
||||
agent_for_ok.get_stream_receiver().await.map_err(|e| e.into())
|
||||
// Add mapping back for the outer function signature
|
||||
agent_for_ok
|
||||
.get_stream_receiver()
|
||||
.await
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
async fn process_thread_with_depth(
|
||||
|
@ -603,9 +621,11 @@ impl Agent {
|
|||
|
||||
// Limit recursion to a maximum of 15 times
|
||||
if recursion_depth >= 15 {
|
||||
let max_depth_msg = format!("Maximum recursion depth ({}) reached.", recursion_depth);
|
||||
warn!("{}", max_depth_msg);
|
||||
let message = AgentMessage::assistant(
|
||||
Some("max_recursion_depth_message".to_string()),
|
||||
Some("I apologize, but I've reached the maximum number of actions (15). Please try breaking your request into smaller parts.".to_string()),
|
||||
Some(max_depth_msg.clone()), // Send the message string
|
||||
None,
|
||||
MessageProgress::Complete,
|
||||
None,
|
||||
|
@ -613,41 +633,55 @@ impl Agent {
|
|||
);
|
||||
// Handle the Result from get_stream_sender
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
// Send the Ok message first
|
||||
if let Err(e) = sender.send(Ok(message)) {
|
||||
tracing::warn!(
|
||||
"Channel send error when sending recursion limit message: {}",
|
||||
warn!(
|
||||
"Channel send error when sending max recursion depth message: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
// Send the error itself over the channel
|
||||
if let Err(e) = sender.send(Err(AgentError(max_depth_msg))) {
|
||||
// Send string error
|
||||
warn!(
|
||||
"Channel send error when sending max recursion depth error: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Stream sender not available when sending recursion limit message.");
|
||||
warn!("Stream sender not available when sending max recursion depth info.");
|
||||
}
|
||||
agent.close().await; // Ensure stream is closed
|
||||
return Ok(()); // Don't return error, just stop processing
|
||||
return Ok(()); // Stop processing gracefully, error sent via channel
|
||||
}
|
||||
|
||||
// --- Fetch and Apply Mode Configuration ---
|
||||
// --- Fetch and Apply Mode Configuration ---
|
||||
let state = agent.get_state().await;
|
||||
let mode_config = agent.mode_provider.get_configuration_for_state(&state).await?;
|
||||
let mode_config = agent
|
||||
.mode_provider
|
||||
.get_configuration_for_state(&state)
|
||||
.await?;
|
||||
|
||||
// Apply Tool Loading via the closure provided by the mode
|
||||
agent.clear_tools().await; // Clear previous mode's tools
|
||||
(mode_config.tool_loader)(&agent).await?; // Invoke the mode's tool-loader closure on this agent
|
||||
|
||||
// Apply Terminating Tools for this mode
|
||||
{ // Scope for write lock
|
||||
{
|
||||
// Scope for write lock
|
||||
let mut term_tools_lock = agent.terminating_tool_names.write().await;
|
||||
term_tools_lock.clear();
|
||||
term_tools_lock.extend(mode_config.terminating_tools);
|
||||
}
|
||||
// --- End Mode Configuration Application ---
|
||||
|
||||
// --- Prepare LLM Messages ---
|
||||
// --- Prepare LLM Messages ---
|
||||
// Use prompt from mode_config
|
||||
let system_message = AgentMessage::developer(mode_config.prompt);
|
||||
let mut llm_messages = vec![system_message];
|
||||
llm_messages.extend(
|
||||
agent.current_thread // Use self.current_thread which is updated
|
||||
agent
|
||||
.current_thread // Use self.current_thread which is updated
|
||||
.read()
|
||||
.await
|
||||
.as_ref()
|
||||
|
@ -655,13 +689,13 @@ impl Agent {
|
|||
.messages
|
||||
// Filter out previous Developer messages if desired, or keep history clean
|
||||
.iter()
|
||||
.filter(|msg| !matches!(msg, AgentMessage::Developer { .. }))
|
||||
.filter(|msg| !matches!(msg, AgentMessage::Developer { .. }))
|
||||
.cloned(),
|
||||
);
|
||||
// --- End Prepare LLM Messages ---
|
||||
|
||||
// Collect all enabled tools and their schemas
|
||||
let tools = agent.get_enabled_tools().await;
|
||||
let tools = agent.get_enabled_tools().await;
|
||||
|
||||
// Get user message for logging (unchanged)
|
||||
let _user_message = thread_ref
|
||||
|
@ -673,10 +707,10 @@ impl Agent {
|
|||
// Create the tool-enabled request
|
||||
let request = ChatCompletionRequest {
|
||||
model: mode_config.model, // Use the model from mode config
|
||||
messages: llm_messages,
|
||||
messages: llm_messages,
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
tool_choice: Some(ToolChoice::Required), // Or adjust based on mode?
|
||||
stream: Some(true), // Enable streaming
|
||||
stream: Some(true), // Enable streaming
|
||||
metadata: Some(Metadata {
|
||||
generation_name: "agent".to_string(),
|
||||
user_id: thread_ref.user_id.to_string(),
|
||||
|
@ -687,111 +721,215 @@ impl Agent {
|
|||
..Default::default()
|
||||
};
|
||||
|
||||
// --- Retry Logic for Initial Stream Request ---
|
||||
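let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Up to 3 retries. NOTE(review): tokio-retry raises the base to the n-th power, so these delays are ~100ms, ~10s, ~1000s (not ~100ms, ~200ms, ~400ms) — confirm the intended backoff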
|
||||
|
||||
// Define a condition for retrying: only on network-related errors
|
||||
let retry_condition = |e: &anyhow::Error| -> bool {
|
||||
if let Some(req_err) = e.downcast_ref::<reqwest::Error>() {
|
||||
// Retry on specific transient errors
|
||||
req_err.is_timeout() || req_err.is_connect() || req_err.is_request()
|
||||
} else {
|
||||
false // Don't retry if it's not a reqwest network error
|
||||
}
|
||||
};
|
||||
|
||||
// The retry operation now wraps the actual result or a permanent error in an outer Ok
|
||||
// Retriable errors are returned as the Err variant for Retry::spawn
|
||||
let stream_rx_result = Retry::spawn(retry_strategy, || {
|
||||
// Clone necessary data for the closure
|
||||
let agent_clone = agent.clone();
|
||||
let request_clone = request.clone();
|
||||
let retry_condition_clone = retry_condition; // Clone the condition closure
|
||||
async move {
|
||||
match agent_clone
|
||||
.llm_client
|
||||
.stream_chat_completion(request_clone)
|
||||
.await
|
||||
{
|
||||
Ok(rx) => Ok(Ok(rx)), // Outer Ok, Inner Ok: Success
|
||||
Err(e) => {
|
||||
if retry_condition_clone(&e) {
|
||||
// Check if error is retriable
|
||||
Err(e) // Outer Err: Signal retry
|
||||
} else {
|
||||
// Outer Ok, Inner Err: Permanent failure, stop retrying
|
||||
Ok(Err(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
// --- End Retry Logic ---
|
||||
|
||||
// Get the streaming response from the LLM
|
||||
let mut stream_rx = match agent
|
||||
.llm_client
|
||||
.stream_chat_completion(request.clone())
|
||||
.await
|
||||
{
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
// --- Added Error Handling ---
|
||||
let error_message = format!("Error starting LLM stream: {:?}", e);
|
||||
// Handle the nested result from the retry logic
|
||||
let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
|
||||
Ok(Ok(rx)) => rx, // Success case
|
||||
Ok(Err(permanent_error)) => {
|
||||
// Permanent error case (non-retriable)
|
||||
let error_message = format!(
|
||||
"Error starting LLM stream (non-retriable): {:?}",
|
||||
permanent_error
|
||||
);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log error in span
|
||||
// Log etc. as before...
|
||||
if let Some(parent_span) = parent_span.clone() {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_span = parent_span.with_output(serde_json::json!({
|
||||
"error": format!("Error starting stream: {:?}", e)
|
||||
"error": error_message
|
||||
}));
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!("Failed to log error span: {}", log_err);
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- End Added Error Handling ---
|
||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||
return Err(permanent_error); // Return the permanent error
|
||||
}
|
||||
Err(last_retriable_error) => {
|
||||
// Error after retries exhausted
|
||||
let error_message = format!(
|
||||
"Error starting LLM stream after multiple retries: {:?}",
|
||||
last_retriable_error
|
||||
);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log etc. as before...
|
||||
if let Some(parent_span) = parent_span.clone() {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_span = parent_span.with_output(serde_json::json!({
|
||||
"error": error_message
|
||||
}));
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!("Failed to log error span: {}", log_err);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(last_retriable_error); // Return the last retriable error
|
||||
}
|
||||
};
|
||||
|
||||
// We store the parent span to use for creating individual tool spans
|
||||
// This avoids creating a general assistant span that would never be completed
|
||||
let parent_for_tool_spans = parent_span.clone();
|
||||
|
||||
// Process the streaming chunks
|
||||
let mut buffer = MessageBuffer::new();
|
||||
let mut _is_complete = false;
|
||||
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
|
||||
|
||||
while let Some(chunk_result) = stream_rx.recv().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if chunk.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
// Replaced `while let` with `loop` and explicit timeout
|
||||
match tokio::time::timeout(Duration::from_secs(STREAM_TIMEOUT_SECS), stream_rx.recv())
|
||||
.await
|
||||
{
|
||||
Ok(Some(chunk_result)) => {
|
||||
// Received a message within timeout
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if chunk.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
buffer.message_id = Some(chunk.id.clone());
|
||||
let delta = &chunk.choices[0].delta;
|
||||
buffer.message_id = Some(chunk.id.clone());
|
||||
let delta = &chunk.choices[0].delta;
|
||||
|
||||
// Accumulate content if present
|
||||
if let Some(content) = &delta.content {
|
||||
buffer.content.push_str(content);
|
||||
}
|
||||
// Accumulate content if present
|
||||
if let Some(content) = &delta.content {
|
||||
buffer.content.push_str(content);
|
||||
}
|
||||
|
||||
// Process tool calls if present
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
buffer
|
||||
.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||
});
|
||||
// Process tool calls if present
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
buffer
|
||||
.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||
});
|
||||
|
||||
// Get or create the pending tool call
|
||||
let pending_call = buffer.tool_calls.entry(id.clone()).or_default();
|
||||
// Get or create the pending tool call
|
||||
let pending_call =
|
||||
buffer.tool_calls.entry(id.clone()).or_default();
|
||||
|
||||
// Update the pending call with the delta
|
||||
pending_call.update_from_delta(tool_call);
|
||||
}
|
||||
}
|
||||
// Update the pending call with the delta
|
||||
pending_call.update_from_delta(tool_call);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we should flush the buffer
|
||||
if buffer.should_flush() {
|
||||
buffer.flush(&agent).await?;
|
||||
}
|
||||
// Check if we should flush the buffer
|
||||
if buffer.should_flush() {
|
||||
buffer.flush(&agent).await?;
|
||||
}
|
||||
|
||||
// Check if this is the final chunk
|
||||
if chunk.choices[0].finish_reason.is_some() {
|
||||
_is_complete = true;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// --- Added Error Handling ---
|
||||
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log error in parent span
|
||||
if let Some(parent) = &parent_for_tool_spans {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
// Create error info
|
||||
let error_info = serde_json::json!({
|
||||
"error": format!("Error in stream: {:?}", e)
|
||||
});
|
||||
|
||||
// Log error as output to parent span
|
||||
let error_span = parent.clone().with_output(error_info);
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!("Failed to log stream error span: {}", log_err);
|
||||
// Check if this is the final chunk
|
||||
if chunk.choices[0].finish_reason.is_some() {
|
||||
_is_complete = true;
|
||||
// Don't break here; the loop exits once the stream closes (Ok(None)) or the inactivity timeout fires
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Format the error string
|
||||
let error_message =
|
||||
format!("Error receiving chunk from LLM stream: {:?}", e);
|
||||
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
|
||||
// Log error in span
|
||||
if let Some(parent_span) = parent_span.clone() {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_span = parent_span.with_output(serde_json::json!({
|
||||
"error": error_message // Use formatted string
|
||||
}));
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!("Failed to log stream error span: {}", log_err);
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- End Updated Error Handling ---
|
||||
// Send string error over broadcast channel before returning
|
||||
let agent_error = AgentError(error_message.clone()); // Create string error
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
// agent_error is moved into the channel here; error_message was cloned above so it can still be returned below
|
||||
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||
warn!("Failed to send stream error over channel: {}", send_err);
|
||||
}
|
||||
} else {
|
||||
warn!("Stream sender not available for sending error.");
|
||||
}
|
||||
// Return anyhow::Error as before
|
||||
return Err(anyhow::anyhow!(error_message));
|
||||
}
|
||||
}
|
||||
// --- End Added Error Handling ---
|
||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||
}
|
||||
Ok(None) => {
|
||||
// Stream closed gracefully
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout occurred
|
||||
// Format the timeout message
|
||||
let timeout_msg = format!(
|
||||
"LLM stream timed out after {} seconds of inactivity.",
|
||||
STREAM_TIMEOUT_SECS
|
||||
);
|
||||
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
|
||||
|
||||
// Send string timeout error over broadcast channel
|
||||
let agent_error = AgentError(timeout_msg.clone()); // Create string error
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
if let Err(send_err) = sender.send(Err(agent_error)) {
|
||||
warn!("Failed to send timeout error over channel: {}", send_err);
|
||||
}
|
||||
} else {
|
||||
warn!("Stream sender not available for sending timeout error.");
|
||||
}
|
||||
// We could return an error here, or just break and let the agent finish
|
||||
// For now, let's break and proceed with whatever was buffered
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -829,7 +967,7 @@ impl Agent {
|
|||
// Ensure we don't block if the receiver dropped
|
||||
// Handle the Result from get_stream_sender
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
if let Err(e) = sender.send(Ok(final_message.clone())) {
|
||||
if let Err(e) = sender.send(Ok(final_message.clone())) {
|
||||
tracing::debug!(
|
||||
"Failed to send final assistant message (receiver likely dropped): {}",
|
||||
e
|
||||
|
@ -851,8 +989,8 @@ impl Agent {
|
|||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Failed to get updated thread state after adding assistant message")
|
||||
})?;
|
||||
anyhow::anyhow!("Failed to get updated thread state after adding assistant message")
|
||||
})?;
|
||||
|
||||
// --- Tool Execution Logic ---
|
||||
// If the LLM wants to use tools, execute them
|
||||
|
@ -911,7 +1049,7 @@ impl Agent {
|
|||
tool_call.function.name, e
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
// Optionally log to Braintrust span here
|
||||
// Return anyhow::Error as before
|
||||
return Err(anyhow::anyhow!(err_msg));
|
||||
}
|
||||
};
|
||||
|
@ -924,59 +1062,90 @@ impl Agent {
|
|||
"id": tool_call.id
|
||||
});
|
||||
|
||||
// Execute the tool using the executor from RegisteredTool
|
||||
let result = match registered_tool
|
||||
.executor
|
||||
.execute(params, tool_call.id.clone())
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
// --- Tool Execution with Timeout ---
|
||||
const TOOL_TIMEOUT_SECS: u64 = 60; // Timeout for tool execution
|
||||
let tool_execution_result = tokio::time::timeout(
|
||||
Duration::from_secs(TOOL_TIMEOUT_SECS),
|
||||
registered_tool
|
||||
.executor
|
||||
.execute(params, tool_call.id.clone()),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Process tool execution result (timeout or actual result/error)
|
||||
let result: Result<Value> = match tool_execution_result {
|
||||
Ok(Ok(r)) => Ok(r), // Tool executed successfully within timeout
|
||||
Ok(Err(e)) => Err(e), // Tool returned an error within timeout
|
||||
Err(_) => {
|
||||
// Tool execution timed out
|
||||
let timeout_msg = format!(
|
||||
"Tool '{}' timed out after {} seconds.",
|
||||
tool_call.function.name, TOOL_TIMEOUT_SECS
|
||||
);
|
||||
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", timeout_msg);
|
||||
// Return an error indicating timeout, wrapped in anyhow
|
||||
Err(anyhow::anyhow!(format!(
|
||||
"Tool '{}' timed out after {} seconds.",
|
||||
tool_call.function.name, TOOL_TIMEOUT_SECS
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Handle the result (success, error, or timeout error)
|
||||
let tool_message = match result {
|
||||
Ok(r) => {
|
||||
// Tool succeeded
|
||||
let result_str = serde_json::to_string(&r)?;
|
||||
AgentMessage::tool(
|
||||
None,
|
||||
result_str.clone(),
|
||||
tool_call.id.clone(),
|
||||
Some(tool_call.function.name.clone()),
|
||||
MessageProgress::Complete,
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
// --- Added Error Handling ---
|
||||
// Tool failed (either execution error or timeout)
|
||||
// Error `e` is already anyhow::Error here
|
||||
let error_message = format!(
|
||||
"Tool execution error for {}: {:?}",
|
||||
"Tool execution failed for {}: {:?}",
|
||||
tool_call.function.name, e
|
||||
);
|
||||
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
|
||||
// Log error in tool span
|
||||
|
||||
// Log error differently for timeout vs execution error if needed
|
||||
error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
|
||||
|
||||
// Log error in Braintrust span
|
||||
if let Some(tool_span) = &tool_span {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
let error_info = serde_json::json!({
|
||||
"error": format!("Tool execution error: {:?}", e)
|
||||
"error": error_message // Generic failure message
|
||||
});
|
||||
|
||||
// Create a new span with the error output
|
||||
let error_span = tool_span.clone().with_output(error_info);
|
||||
|
||||
// Log span non-blockingly (client handles the background processing)
|
||||
if let Err(log_err) = client.log_span(error_span).await {
|
||||
error!(
|
||||
"Failed to log tool execution error span: {}",
|
||||
"Failed to log tool execution failure span: {}",
|
||||
log_err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- End Added Error Handling ---
|
||||
let error_message = format!(
|
||||
"Tool execution error for {}: {:?}",
|
||||
tool_call.function.name, e
|
||||
);
|
||||
error!("{}", error_message); // Log locally
|
||||
return Err(anyhow::anyhow!(error_message)); // Return immediately
|
||||
// --- End Braintrust Logging ---
|
||||
|
||||
// Create an error tool message to send back to the LLM
|
||||
AgentMessage::tool(
|
||||
None,
|
||||
serde_json::json!({ "error": error_message }).to_string(), // Send descriptive error string
|
||||
tool_call.id.clone(),
|
||||
Some(tool_call.function.name.clone()),
|
||||
MessageProgress::Complete,
|
||||
)
|
||||
// Note: We are NOT returning the error here, instead we send
|
||||
// the error back as a tool result message to the LLM.
|
||||
}
|
||||
};
|
||||
|
||||
let result_str = serde_json::to_string(&result)?;
|
||||
let tool_message = AgentMessage::tool(
|
||||
None,
|
||||
result_str.clone(),
|
||||
tool_call.id.clone(),
|
||||
Some(tool_call.function.name.clone()),
|
||||
MessageProgress::Complete,
|
||||
);
|
||||
|
||||
// Log the combined assistant+tool span with the tool result as output
|
||||
// Log the combined assistant+tool span with the tool result/error as output
|
||||
if let Some(tool_span) = &tool_span {
|
||||
if let Some(client) = &*BRAINTRUST_CLIENT {
|
||||
// Only log completed messages
|
||||
|
@ -1034,10 +1203,11 @@ impl Agent {
|
|||
tool_call.function.name
|
||||
);
|
||||
error!("{}", err_msg);
|
||||
// Create a fake tool result indicating the error
|
||||
|
||||
// Create a fake tool result indicating the error (string based)
|
||||
let error_result = AgentMessage::tool(
|
||||
None,
|
||||
serde_json::json!({"error": err_msg}).to_string(),
|
||||
serde_json::json!({ "error": err_msg.clone() }).to_string(), // Use the string message
|
||||
tool_call.id.clone(),
|
||||
Some(tool_call.function.name.clone()),
|
||||
MessageProgress::Complete,
|
||||
|
@ -1045,14 +1215,24 @@ impl Agent {
|
|||
// Broadcast the error message
|
||||
// Handle the Result from get_stream_sender
|
||||
if let Ok(sender) = agent.get_stream_sender().await {
|
||||
if let Err(e) = sender.send(Ok(error_result.clone())) {
|
||||
if let Err(e) = sender.send(Ok(error_result.clone())) {
|
||||
tracing::debug!(
|
||||
"Failed to send tool error message (receiver likely dropped): {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
// Also send the specific error type over the channel
|
||||
if let Err(e) = sender.send(Err(AgentError(err_msg))) {
|
||||
// Send string error
|
||||
tracing::warn!(
|
||||
"Failed to send tool not found error over channel: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("Stream sender not available when sending tool error message.");
|
||||
tracing::debug!(
|
||||
"Stream sender not available when sending tool error message."
|
||||
);
|
||||
}
|
||||
// Update thread and push the error result for the next LLM call
|
||||
agent.update_current_thread(error_result.clone()).await?;
|
||||
|
@ -1268,7 +1448,6 @@ pub trait AgentExt {
|
|||
(*self.get_agent_arc()).get_current_thread().await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -1289,7 +1468,10 @@ mod tests {
|
|||
|
||||
#[async_trait::async_trait]
|
||||
impl ModeProvider for MockModeProvider {
|
||||
async fn get_configuration_for_state(&self, _state: &HashMap<String, Value>) -> Result<ModeConfiguration> {
|
||||
async fn get_configuration_for_state(
|
||||
&self,
|
||||
_state: &HashMap<String, Value>,
|
||||
) -> Result<ModeConfiguration> {
|
||||
// Return a default/empty configuration for testing basic agent functions
|
||||
Ok(ModeConfiguration {
|
||||
prompt: "Test Prompt".to_string(),
|
||||
|
@ -1400,7 +1582,7 @@ mod tests {
|
|||
"test_agent_no_tools".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
mock_provider,
|
||||
mock_provider,
|
||||
));
|
||||
|
||||
let thread = AgentThread::new(
|
||||
|
@ -1519,7 +1701,7 @@ mod tests {
|
|||
"test_agent_disabled".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
mock_provider,
|
||||
mock_provider,
|
||||
));
|
||||
|
||||
// Create weather tool
|
||||
|
@ -1643,3 +1825,4 @@ mod tests {
|
|||
assert_eq!(agent.get_state_bool("bool_key").await, None);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -150,7 +150,7 @@ impl BusterMultiAgent {
|
|||
|
||||
// Create agent, passing the provider
|
||||
let agent = Arc::new(Agent::new(
|
||||
"o4-mini".to_string(), // Initial model (can be overridden by first mode)
|
||||
"gemini-2.5-pro-exp-03-25".to_string(), // Initial model (can be overridden by first mode)
|
||||
user_id,
|
||||
session_id,
|
||||
"buster_multi_agent".to_string(),
|
||||
|
|
|
@ -30,7 +30,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
// Note: This prompt doesn't use {DATASETS}
|
||||
|
||||
// 2. Define the model for this mode (Using default based on original MODEL = None)
|
||||
let model = "o4-mini".to_string();
|
||||
let model = "gemini-2.5-pro-exp-03-25".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<
|
||||
|
|
|
@ -32,7 +32,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
// Note: This prompt doesn't use {TODAYS_DATE}
|
||||
|
||||
// 2. Define the model for this mode
|
||||
let model = "o4-mini".to_string(); // Use o4-mini as requested
|
||||
let model = "gemini-2.5-pro-exp-03-25".to_string(); // Use gemini-2.5-pro-exp-03-25 as requested
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
|
|
|
@ -42,7 +42,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
.replace("{TODAYS_DATE}", &agent_data.todays_date);
|
||||
|
||||
// 2. Define the model for this mode (Using a default, adjust if needed)
|
||||
let model = "o4-mini".to_string(); // Assuming default based on original MODEL = None
|
||||
let model = "gemini-2.5-pro-exp-03-25".to_string(); // Assuming default based on original MODEL = None
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
|
|
|
@ -26,8 +26,8 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
|
||||
// 2. Define the model for this mode (Using a default, adjust if needed)
|
||||
// Since the original MODEL was None, we might use the agent's default
|
||||
// or specify a standard one like "o4-mini". Let's use "o4-mini".
|
||||
let model = "o4-mini".to_string();
|
||||
// or specify a standard one like "gemini-2.5-pro-exp-03-25". Let's use "gemini-2.5-pro-exp-03-25".
|
||||
let model = "gemini-2.5-pro-exp-03-25".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<
|
||||
|
|
|
@ -31,7 +31,7 @@ pub struct ModeAgentData {
|
|||
pub struct ModeConfiguration {
|
||||
/// The system prompt to use for the LLM call in this mode.
|
||||
pub prompt: String,
|
||||
/// The specific LLM model identifier (e.g., "o4-mini") to use for this mode.
|
||||
/// The specific LLM model identifier (e.g., "gemini-2.5-pro-exp-03-25") to use for this mode.
|
||||
pub model: String,
|
||||
/// An async function/closure responsible for clearing existing tools
|
||||
/// and loading the specific tools required for this mode onto the agent.
|
||||
|
|
|
@ -28,7 +28,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
.replace("{DATASETS}", &agent_data.dataset_with_descriptions.join("\n\n"));
|
||||
|
||||
// 2. Define the model for this mode (Using default based on original MODEL = None)
|
||||
let model = "o4-mini".to_string();
|
||||
let model = "gemini-2.5-pro-exp-03-25".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<
|
||||
|
|
|
@ -673,7 +673,7 @@ mod tests {
|
|||
fn test_tool_parameter_validation() {
|
||||
let tool = FilterDashboardsTool {
|
||||
agent: Arc::new(Agent::new(
|
||||
"o4-mini".to_string(),
|
||||
"gemini-2.5-pro-exp-03-25".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
|
|
Loading…
Reference in New Issue