Merge pull request #239 from buster-so/dal/agent-resiliency

Dal/agent resiliency
This commit is contained in:
dal 2025-04-28 07:09:52 -07:00 committed by GitHub
commit d06d3deadb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 345 additions and 159 deletions

View File

@ -106,6 +106,7 @@ rayon = "1.10.0"
diesel_migrations = "2.0.0"
html-escape = "0.2.13"
tokio-cron-scheduler = "0.13.0"
tokio-retry = "0.3.0"
[profile.release]
debug = false

View File

@ -32,6 +32,8 @@ redis = { workspace = true }
reqwest = { workspace = true }
sqlx = { workspace = true }
stored_values = { path = "../stored_values" }
tokio-retry = { workspace = true }
thiserror = { workspace = true }
# Development dependencies
[dev-dependencies]

View File

@ -2,15 +2,16 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor};
use anyhow::Result;
use braintrust::{BraintrustClient, TraceBuilder};
use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
AgentMessage, ChatCompletionChunk, ChatCompletionRequest, DeltaToolCall, FunctionCall,
LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
};
use once_cell::sync::Lazy;
use serde_json::Value;
use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use tracing::error;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio_retry::{strategy::ExponentialBackoff, Retry};
use tracing::{error, warn};
use uuid::Uuid;
// Type definition for tool registry to simplify complex type
@ -37,6 +38,7 @@ static BRAINTRUST_CLIENT: Lazy<Option<Arc<BraintrustClient>>> = Lazy::new(|| {
}
});
// --- Reverted AgentError Struct ---
#[derive(Debug, Clone)]
pub struct AgentError(pub String);
@ -47,6 +49,7 @@ impl std::fmt::Display for AgentError {
write!(f, "{}", self.0)
}
}
// --- End Reverted AgentError Struct ---
type MessageResult = Result<AgentMessage, AgentError>;
@ -150,7 +153,10 @@ type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
#[async_trait::async_trait]
pub trait ModeProvider {
// Fetches the complete configuration for a given agent state
async fn get_configuration_for_state(&self, state: &HashMap<String, Value>) -> Result<ModeConfiguration>;
async fn get_configuration_for_state(
&self,
state: &HashMap<String, Value>,
) -> Result<ModeConfiguration>;
}
// --- End ModeProvider Trait ---
@ -216,7 +222,7 @@ impl Agent {
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
name,
terminating_tool_names: Arc::new(RwLock::new(Vec::new())), // Initialize empty list
mode_provider, // Store the provider
mode_provider, // Store the provider
}
}
@ -243,7 +249,7 @@ impl Agent {
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown
name,
terminating_tool_names: Arc::new(RwLock::new(Vec::new())), // Sub-agent starts with empty term tools?
mode_provider: Arc::clone(&mode_provider), // Share provider
mode_provider: Arc::clone(&mode_provider), // Share provider
}
}
@ -273,10 +279,14 @@ impl Agent {
/// Get a new receiver for the broadcast channel.
/// Returns an error if the stream channel has been closed or was not initialized.
pub async fn get_stream_receiver(&self) -> Result<broadcast::Receiver<MessageResult>, AgentError> {
pub async fn get_stream_receiver(
&self,
) -> Result<broadcast::Receiver<MessageResult>, AgentError> {
match self.stream_tx.read().await.as_ref() {
Some(tx) => Ok(tx.subscribe()),
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
None => Err(AgentError(
"Stream channel is closed or not initialized.".to_string(),
)), // Use string error
}
}
@ -285,7 +295,9 @@ impl Agent {
pub async fn get_stream_sender(&self) -> Result<broadcast::Sender<MessageResult>, AgentError> {
match self.stream_tx.read().await.as_ref() {
Some(tx) => Ok(tx.clone()),
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
None => Err(AgentError(
"Stream channel is closed or not initialized.".to_string(),
)), // Use string error
}
}
@ -477,12 +489,14 @@ impl Agent {
tokio::select! {
result = Agent::process_thread_with_depth(agent_arc_clone, thread_clone.clone(), &thread_clone, 0, None, None) => {
if let Err(e) = result {
// Log the error
let err_msg = format!("Error processing thread: {:?}", e);
error!("{}", err_msg); // Log the error
error!("{}", err_msg);
// Use the clone created before select!
// Handle the Result from get_stream_sender
let agent_error = AgentError(err_msg); // Use reverted struct
if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(AgentError(err_msg.clone()))) {
if let Err(send_err) = sender.send(Err(agent_error)) {
tracing::warn!("Failed to send error message to stream: {}", send_err);
}
} else {
@ -532,7 +546,11 @@ impl Agent {
});
// Handle the Result from get_stream_receiver
agent_for_ok.get_stream_receiver().await.map_err(|e| e.into())
// Add mapping back for the outer function signature
agent_for_ok
.get_stream_receiver()
.await
.map_err(anyhow::Error::from)
}
async fn process_thread_with_depth(
@ -603,9 +621,11 @@ impl Agent {
// Limit recursion to a maximum of 15 times
if recursion_depth >= 15 {
let max_depth_msg = format!("Maximum recursion depth ({}) reached.", recursion_depth);
warn!("{}", max_depth_msg);
let message = AgentMessage::assistant(
Some("max_recursion_depth_message".to_string()),
Some("I apologize, but I've reached the maximum number of actions (15). Please try breaking your request into smaller parts.".to_string()),
Some(max_depth_msg.clone()), // Send the message string
None,
MessageProgress::Complete,
None,
@ -613,29 +633,42 @@ impl Agent {
);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
// Send the Ok message first
if let Err(e) = sender.send(Ok(message)) {
tracing::warn!(
"Channel send error when sending recursion limit message: {}",
warn!(
"Channel send error when sending max recursion depth message: {}",
e
);
}
// Send the error itself over the channel
if let Err(e) = sender.send(Err(AgentError(max_depth_msg))) {
// Send string error
warn!(
"Channel send error when sending max recursion depth error: {}",
e
);
}
} else {
tracing::warn!("Stream sender not available when sending recursion limit message.");
warn!("Stream sender not available when sending max recursion depth info.");
}
agent.close().await; // Ensure stream is closed
return Ok(()); // Don't return error, just stop processing
return Ok(()); // Stop processing gracefully, error sent via channel
}
// --- Fetch and Apply Mode Configuration ---
let state = agent.get_state().await;
let mode_config = agent.mode_provider.get_configuration_for_state(&state).await?;
let mode_config = agent
.mode_provider
.get_configuration_for_state(&state)
.await?;
// Apply Tool Loading via the closure provided by the mode
agent.clear_tools().await; // Clear previous mode's tools
(mode_config.tool_loader)(&agent).await?; // Explicitly cast self
// Apply Terminating Tools for this mode
{ // Scope for write lock
{
// Scope for write lock
let mut term_tools_lock = agent.terminating_tool_names.write().await;
term_tools_lock.clear();
term_tools_lock.extend(mode_config.terminating_tools);
@ -647,7 +680,8 @@ impl Agent {
let system_message = AgentMessage::developer(mode_config.prompt);
let mut llm_messages = vec![system_message];
llm_messages.extend(
agent.current_thread // Use self.current_thread which is updated
agent
.current_thread // Use self.current_thread which is updated
.read()
.await
.as_ref()
@ -676,7 +710,7 @@ impl Agent {
messages: llm_messages,
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Required), // Or adjust based on mode?
stream: Some(true), // Enable streaming
stream: Some(true), // Enable streaming
metadata: Some(Metadata {
generation_name: "agent".to_string(),
user_id: thread_ref.user_id.to_string(),
@ -687,111 +721,215 @@ impl Agent {
..Default::default()
};
// --- Retry Logic for Initial Stream Request ---
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms
// Define a condition for retrying: only on network-related errors
let retry_condition = |e: &anyhow::Error| -> bool {
if let Some(req_err) = e.downcast_ref::<reqwest::Error>() {
// Retry on specific transient errors
req_err.is_timeout() || req_err.is_connect() || req_err.is_request()
} else {
false // Don't retry if it's not a reqwest network error
}
};
// The retry operation now wraps the actual result or a permanent error in an outer Ok
// Retriable errors are returned as the Err variant for Retry::spawn
let stream_rx_result = Retry::spawn(retry_strategy, || {
// Clone necessary data for the closure
let agent_clone = agent.clone();
let request_clone = request.clone();
let retry_condition_clone = retry_condition; // Clone the condition closure
async move {
match agent_clone
.llm_client
.stream_chat_completion(request_clone)
.await
{
Ok(rx) => Ok(Ok(rx)), // Outer Ok, Inner Ok: Success
Err(e) => {
if retry_condition_clone(&e) {
// Check if error is retriable
Err(e) // Outer Err: Signal retry
} else {
// Outer Ok, Inner Err: Permanent failure, stop retrying
Ok(Err(e))
}
}
}
}
})
.await;
// --- End Retry Logic ---
// Get the streaming response from the LLM
let mut stream_rx = match agent
.llm_client
.stream_chat_completion(request.clone())
.await
{
Ok(rx) => rx,
Err(e) => {
// --- Added Error Handling ---
let error_message = format!("Error starting LLM stream: {:?}", e);
// Handle the nested result from the retry logic
let mut stream_rx: mpsc::Receiver<Result<ChatCompletionChunk>> = match stream_rx_result {
Ok(Ok(rx)) => rx, // Success case
Ok(Err(permanent_error)) => {
// Permanent error case (non-retriable)
let error_message = format!(
"Error starting LLM stream (non-retriable): {:?}",
permanent_error
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in span
// Log etc. as before...
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error starting stream: {:?}", e)
"error": error_message
}));
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log error span: {}", log_err);
}
}
}
// --- End Added Error Handling ---
return Err(anyhow::anyhow!(error_message)); // Return immediately
return Err(permanent_error); // Return the permanent error
}
Err(last_retriable_error) => {
// Error after retries exhausted
let error_message = format!(
"Error starting LLM stream after multiple retries: {:?}",
last_retriable_error
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log etc. as before...
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({
"error": error_message
}));
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log error span: {}", log_err);
}
}
}
return Err(last_retriable_error); // Return the last retriable error
}
};
// We store the parent span to use for creating individual tool spans
// This avoids creating a general assistant span that would never be completed
let parent_for_tool_spans = parent_span.clone();
// Process the streaming chunks
let mut buffer = MessageBuffer::new();
let mut _is_complete = false;
const STREAM_TIMEOUT_SECS: u64 = 120; // Timeout after 120 seconds of inactivity
while let Some(chunk_result) = stream_rx.recv().await {
match chunk_result {
Ok(chunk) => {
if chunk.choices.is_empty() {
continue;
}
loop {
// Replaced `while let` with `loop` and explicit timeout
match tokio::time::timeout(Duration::from_secs(STREAM_TIMEOUT_SECS), stream_rx.recv())
.await
{
Ok(Some(chunk_result)) => {
// Received a message within timeout
match chunk_result {
Ok(chunk) => {
if chunk.choices.is_empty() {
continue;
}
buffer.message_id = Some(chunk.id.clone());
let delta = &chunk.choices[0].delta;
buffer.message_id = Some(chunk.id.clone());
let delta = &chunk.choices[0].delta;
// Accumulate content if present
if let Some(content) = &delta.content {
buffer.content.push_str(content);
}
// Accumulate content if present
if let Some(content) = &delta.content {
buffer.content.push_str(content);
}
// Process tool calls if present
if let Some(tool_calls) = &delta.tool_calls {
for tool_call in tool_calls {
let id = tool_call.id.clone().unwrap_or_else(|| {
buffer
.tool_calls
.keys()
.next()
.cloned()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
});
// Process tool calls if present
if let Some(tool_calls) = &delta.tool_calls {
for tool_call in tool_calls {
let id = tool_call.id.clone().unwrap_or_else(|| {
buffer
.tool_calls
.keys()
.next()
.cloned()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
});
// Get or create the pending tool call
let pending_call = buffer.tool_calls.entry(id.clone()).or_default();
// Get or create the pending tool call
let pending_call =
buffer.tool_calls.entry(id.clone()).or_default();
// Update the pending call with the delta
pending_call.update_from_delta(tool_call);
}
}
// Update the pending call with the delta
pending_call.update_from_delta(tool_call);
}
}
// Check if we should flush the buffer
if buffer.should_flush() {
buffer.flush(&agent).await?;
}
// Check if we should flush the buffer
if buffer.should_flush() {
buffer.flush(&agent).await?;
}
// Check if this is the final chunk
if chunk.choices[0].finish_reason.is_some() {
_is_complete = true;
}
}
Err(e) => {
// --- Added Error Handling ---
let error_message = format!("Error receiving chunk from LLM stream: {:?}", e);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in parent span
if let Some(parent) = &parent_for_tool_spans {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create error info
let error_info = serde_json::json!({
"error": format!("Error in stream: {:?}", e)
});
// Log error as output to parent span
let error_span = parent.clone().with_output(error_info);
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log stream error span: {}", log_err);
// Check if this is the final chunk
if chunk.choices[0].finish_reason.is_some() {
_is_complete = true;
// Don't break here yet, let the loop condition handle it
}
}
Err(e) => {
// Format the error string
let error_message =
format!("Error receiving chunk from LLM stream: {:?}", e);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", error_message);
// Log error in span
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({
"error": error_message // Use formatted string
}));
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!("Failed to log stream error span: {}", log_err);
}
}
}
// --- End Updated Error Handling ---
// Send string error over broadcast channel before returning
let agent_error = AgentError(error_message.clone()); // Create string error
if let Ok(sender) = agent.get_stream_sender().await {
// clone() is now valid for AgentError(String)
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send stream error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending error.");
}
// Return anyhow::Error as before
return Err(anyhow::anyhow!(error_message));
}
}
// --- End Added Error Handling ---
return Err(anyhow::anyhow!(error_message)); // Return immediately
}
Ok(None) => {
// Stream closed gracefully
break;
}
Err(_) => {
// Timeout occurred
// Format the timeout message
let timeout_msg = format!(
"LLM stream timed out after {} seconds of inactivity.",
STREAM_TIMEOUT_SECS
);
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, "{}", timeout_msg);
// Send string timeout error over broadcast channel
let agent_error = AgentError(timeout_msg.clone()); // Create string error
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(agent_error)) {
warn!("Failed to send timeout error over channel: {}", send_err);
}
} else {
warn!("Stream sender not available for sending timeout error.");
}
// We could return an error here, or just break and let the agent finish
// For now, let's break and proceed with whatever was buffered
break;
}
}
}
@ -829,7 +967,7 @@ impl Agent {
// Ensure we don't block if the receiver dropped
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(final_message.clone())) {
if let Err(e) = sender.send(Ok(final_message.clone())) {
tracing::debug!(
"Failed to send final assistant message (receiver likely dropped): {}",
e
@ -851,8 +989,8 @@ impl Agent {
.as_ref()
.cloned()
.ok_or_else(|| {
anyhow::anyhow!("Failed to get updated thread state after adding assistant message")
})?;
anyhow::anyhow!("Failed to get updated thread state after adding assistant message")
})?;
// --- Tool Execution Logic ---
// If the LLM wants to use tools, execute them
@ -911,7 +1049,7 @@ impl Agent {
tool_call.function.name, e
);
error!("{}", err_msg);
// Optionally log to Braintrust span here
// Return anyhow::Error as before
return Err(anyhow::anyhow!(err_msg));
}
};
@ -924,59 +1062,90 @@ impl Agent {
"id": tool_call.id
});
// Execute the tool using the executor from RegisteredTool
let result = match registered_tool
.executor
.execute(params, tool_call.id.clone())
.await
{
Ok(r) => r,
// --- Tool Execution with Timeout ---
const TOOL_TIMEOUT_SECS: u64 = 60; // Timeout for tool execution
let tool_execution_result = tokio::time::timeout(
Duration::from_secs(TOOL_TIMEOUT_SECS),
registered_tool
.executor
.execute(params, tool_call.id.clone()),
)
.await;
// Process tool execution result (timeout or actual result/error)
let result: Result<Value> = match tool_execution_result {
Ok(Ok(r)) => Ok(r), // Tool executed successfully within timeout
Ok(Err(e)) => Err(e), // Tool returned an error within timeout
Err(_) => {
// Tool execution timed out
let timeout_msg = format!(
"Tool '{}' timed out after {} seconds.",
tool_call.function.name, TOOL_TIMEOUT_SECS
);
warn!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", timeout_msg);
// Return an error indicating timeout, wrapped in anyhow
Err(anyhow::anyhow!(format!(
"Tool '{}' timed out after {} seconds.",
tool_call.function.name, TOOL_TIMEOUT_SECS
)))
}
};
// Handle the result (success, error, or timeout error)
let tool_message = match result {
Ok(r) => {
// Tool succeeded
let result_str = serde_json::to_string(&r)?;
AgentMessage::tool(
None,
result_str.clone(),
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
)
}
Err(e) => {
// --- Added Error Handling ---
// Tool failed (either execution error or timeout)
// Error `e` is already anyhow::Error here
let error_message = format!(
"Tool execution error for {}: {:?}",
"Tool execution failed for {}: {:?}",
tool_call.function.name, e
);
tracing::error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
// Log error in tool span
// Log error differently for timeout vs execution error if needed
error!(agent_name = %agent.name, chat_id = %agent.session_id, user_id = %agent.user_id, tool_name = %tool_call.function.name, "{}", error_message);
// Log error in Braintrust span
if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_info = serde_json::json!({
"error": format!("Tool execution error: {:?}", e)
"error": error_message // Generic failure message
});
// Create a new span with the error output
let error_span = tool_span.clone().with_output(error_info);
// Log span non-blockingly (client handles the background processing)
if let Err(log_err) = client.log_span(error_span).await {
error!(
"Failed to log tool execution error span: {}",
"Failed to log tool execution failure span: {}",
log_err
);
}
}
}
// --- End Added Error Handling ---
let error_message = format!(
"Tool execution error for {}: {:?}",
tool_call.function.name, e
);
error!("{}", error_message); // Log locally
return Err(anyhow::anyhow!(error_message)); // Return immediately
// --- End Braintrust Logging ---
// Create an error tool message to send back to the LLM
AgentMessage::tool(
None,
serde_json::json!({ "error": error_message }).to_string(), // Send descriptive error string
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
)
// Note: We are NOT returning the error here, instead we send
// the error back as a tool result message to the LLM.
}
};
let result_str = serde_json::to_string(&result)?;
let tool_message = AgentMessage::tool(
None,
result_str.clone(),
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
);
// Log the combined assistant+tool span with the tool result as output
// Log the combined assistant+tool span with the tool result/error as output
if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Only log completed messages
@ -1034,10 +1203,11 @@ impl Agent {
tool_call.function.name
);
error!("{}", err_msg);
// Create a fake tool result indicating the error
// Create a fake tool result indicating the error (string based)
let error_result = AgentMessage::tool(
None,
serde_json::json!({"error": err_msg}).to_string(),
serde_json::json!({ "error": err_msg.clone() }).to_string(), // Use the string message
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
@ -1045,14 +1215,24 @@ impl Agent {
// Broadcast the error message
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(error_result.clone())) {
if let Err(e) = sender.send(Ok(error_result.clone())) {
tracing::debug!(
"Failed to send tool error message (receiver likely dropped): {}",
e
);
}
// Also send the specific error type over the channel
if let Err(e) = sender.send(Err(AgentError(err_msg))) {
// Send string error
tracing::warn!(
"Failed to send tool not found error over channel: {}",
e
);
}
} else {
tracing::debug!("Stream sender not available when sending tool error message.");
tracing::debug!(
"Stream sender not available when sending tool error message."
);
}
// Update thread and push the error result for the next LLM call
agent.update_current_thread(error_result.clone()).await?;
@ -1268,7 +1448,6 @@ pub trait AgentExt {
(*self.get_agent_arc()).get_current_thread().await
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -1289,7 +1468,10 @@ mod tests {
#[async_trait::async_trait]
impl ModeProvider for MockModeProvider {
async fn get_configuration_for_state(&self, _state: &HashMap<String, Value>) -> Result<ModeConfiguration> {
async fn get_configuration_for_state(
&self,
_state: &HashMap<String, Value>,
) -> Result<ModeConfiguration> {
// Return a default/empty configuration for testing basic agent functions
Ok(ModeConfiguration {
prompt: "Test Prompt".to_string(),
@ -1643,3 +1825,4 @@ mod tests {
assert_eq!(agent.get_state_bool("bool_key").await, None);
}
}

View File

@ -150,7 +150,7 @@ impl BusterMultiAgent {
// Create agent, passing the provider
let agent = Arc::new(Agent::new(
"o4-mini".to_string(), // Initial model (can be overridden by first mode)
"gemini-2.5-pro-exp-03-25".to_string(), // Initial model (can be overridden by first mode)
user_id,
session_id,
"buster_multi_agent".to_string(),

View File

@ -30,7 +30,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
// Note: This prompt doesn't use {DATASETS}
// 2. Define the model for this mode (Using default based on original MODEL = None)
let model = "o4-mini".to_string();
let model = "gemini-2.5-pro-exp-03-25".to_string();
// 3. Define the tool loader closure
let tool_loader: Box<

View File

@ -32,7 +32,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
// Note: This prompt doesn't use {TODAYS_DATE}
// 2. Define the model for this mode
let model = "o4-mini".to_string(); // Use o4-mini as requested
let model = "gemini-2.5-pro-exp-03-25".to_string(); // Use gemini-2.5-pro-exp-03-25 as requested
// 3. Define the tool loader closure
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =

View File

@ -42,7 +42,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
.replace("{TODAYS_DATE}", &agent_data.todays_date);
// 2. Define the model for this mode (Using a default, adjust if needed)
let model = "o4-mini".to_string(); // Assuming default based on original MODEL = None
let model = "gemini-2.5-pro-exp-03-25".to_string(); // Assuming default based on original MODEL = None
// 3. Define the tool loader closure
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =

View File

@ -26,8 +26,8 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
// 2. Define the model for this mode (Using a default, adjust if needed)
// Since the original MODEL was None, we might use the agent's default
// or specify a standard one like "o4-mini". Let's use "o4-mini".
let model = "o4-mini".to_string();
// or specify a standard one like "gemini-2.5-pro-exp-03-25". Let's use "gemini-2.5-pro-exp-03-25".
let model = "gemini-2.5-pro-exp-03-25".to_string();
// 3. Define the tool loader closure
let tool_loader: Box<

View File

@ -31,7 +31,7 @@ pub struct ModeAgentData {
pub struct ModeConfiguration {
/// The system prompt to use for the LLM call in this mode.
pub prompt: String,
/// The specific LLM model identifier (e.g., "o4-mini") to use for this mode.
/// The specific LLM model identifier (e.g., "gemini-2.5-pro-exp-03-25") to use for this mode.
pub model: String,
/// An async function/closure responsible for clearing existing tools
/// and loading the specific tools required for this mode onto the agent.

View File

@ -28,7 +28,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
.replace("{DATASETS}", &agent_data.dataset_with_descriptions.join("\n\n"));
// 2. Define the model for this mode (Using default based on original MODEL = None)
let model = "o4-mini".to_string();
let model = "gemini-2.5-pro-exp-03-25".to_string();
// 3. Define the tool loader closure
let tool_loader: Box<

View File

@ -673,7 +673,7 @@ mod tests {
fn test_tool_parameter_validation() {
let tool = FilterDashboardsTool {
agent: Arc::new(Agent::new(
"o4-mini".to_string(),
"gemini-2.5-pro-exp-03-25".to_string(),
HashMap::new(),
Uuid::new_v4(),
Uuid::new_v4(),