use crate::tools::{IntoToolCallExecutor, ToolExecutor}; use anyhow::Result; use litellm::{ AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, }; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::{broadcast, RwLock}; use uuid::Uuid; use std::time::{Duration, Instant}; use crate::models::AgentThread; #[derive(Debug, Clone)] pub struct AgentError(pub String); impl std::error::Error for AgentError {} impl std::fmt::Display for AgentError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } type MessageResult = Result; #[derive(Debug)] struct MessageBuffer { content: String, tool_calls: HashMap, last_flush: Instant, message_id: Option, first_message_sent: bool, } impl MessageBuffer { fn new() -> Self { Self { content: String::new(), tool_calls: HashMap::new(), last_flush: Instant::now(), message_id: None, first_message_sent: false, } } fn should_flush(&self) -> bool { self.last_flush.elapsed() >= Duration::from_millis(50) } fn has_changes(&self) -> bool { !self.content.is_empty() || !self.tool_calls.is_empty() } async fn flush(&mut self, agent: &Agent) -> Result<()> { if !self.has_changes() { return Ok(()); } // Create tool calls vector if we have any let tool_calls: Option> = if !self.tool_calls.is_empty() { Some( self.tool_calls .values() .filter_map(|p| { if p.function_name.is_some() { Some(p.clone().into_tool_call()) } else { None } }) .collect(), ) } else { None }; // Create and send the message let message = AgentMessage::assistant( self.message_id.clone(), if self.content.is_empty() { None } else { Some(self.content.clone()) }, tool_calls, MessageProgress::InProgress, Some(!self.first_message_sent), Some(agent.name.clone()), ); agent.get_stream_sender().await.send(Ok(message))?; // Update state self.first_message_sent = true; self.last_flush = Instant::now(); self.content.clear(); // Clear content but keep tool calls as they may still be accumulating Ok(()) } } #[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. It maintains a registry of available tools /// and handles the recursive nature of tool calls. pub struct Agent { /// Client for communicating with the LLM provider llm_client: LiteLLMClient, /// Registry of available tools, mapped by their names tools: Arc< RwLock< HashMap + Send + Sync>>, >, >, /// The model identifier to use (e.g., "gpt-4") model: String, /// Flexible state storage for maintaining memory across interactions state: Arc>>, /// The current thread being processed, if any current_thread: Arc>>, /// Sender for streaming messages from this agent and sub-agents stream_tx: Arc>>>, /// The user ID for the current thread user_id: Uuid, /// The session ID for the current thread session_id: Uuid, /// Agent name name: String, /// Shutdown signal sender shutdown_tx: Arc>>, } impl Agent { /// Create a new Agent instance with a specific LLM client and model pub fn new( model: String, tools: HashMap + Send + Sync>>, user_id: Uuid, session_id: Uuid, name: String, ) -> Self { let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set"); let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set"); let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url)); let (tx, _rx) = broadcast::channel(1000); let (shutdown_tx, _) = broadcast::channel(1); Self { llm_client, tools: Arc::new(RwLock::new(tools)), model, state: Arc::new(RwLock::new(HashMap::new())), current_thread: Arc::new(RwLock::new(None)), stream_tx: Arc::new(RwLock::new(Some(tx))), user_id, session_id, shutdown_tx: Arc::new(RwLock::new(shutdown_tx)), name, } } /// Create a new Agent that shares state and stream with an existing agent pub fn from_existing(existing_agent: &Agent, name: String) -> Self { let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set"); let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set"); let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url)); Self { llm_client, tools: Arc::new(RwLock::new(HashMap::new())), model: existing_agent.model.clone(), state: Arc::clone(&existing_agent.state), current_thread: Arc::clone(&existing_agent.current_thread), stream_tx: Arc::clone(&existing_agent.stream_tx), user_id: existing_agent.user_id, session_id: existing_agent.session_id, shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), name, } } pub async fn get_enabled_tools(&self) -> Vec { // Collect all registered tools and their schemas let tools = self.tools.read().await; let mut enabled_tools = Vec::new(); for (_, tool) in tools.iter() { if tool.is_enabled().await { enabled_tools.push(Tool { tool_type: "function".to_string(), function: tool.get_schema(), }); } } enabled_tools } /// Get a new receiver for the broadcast channel pub async fn get_stream_receiver(&self) -> broadcast::Receiver { self.stream_tx.read().await.as_ref().unwrap().subscribe() } /// Get a clone of the current stream sender pub async fn get_stream_sender(&self) -> broadcast::Sender { self.stream_tx.read().await.as_ref().unwrap().clone() } /// Get a value from the agent's state by key pub async fn get_state_value(&self, key: &str) -> Option { self.state.read().await.get(key).cloned() } /// Set a value in the agent's state pub async fn set_state_value(&self, key: String, value: Value) { self.state.write().await.insert(key, value); } /// Update multiple state values at once using a closure pub async fn update_state(&self, f: F) where F: FnOnce(&mut HashMap), { let mut state = self.state.write().await; f(&mut state); } /// Clear all state values pub async fn clear_state(&self) { self.state.write().await.clear(); } /// Get the current thread being processed, if any pub async fn get_current_thread(&self) -> Option { self.current_thread.read().await.clone() } pub fn get_user_id(&self) -> Uuid { self.user_id } pub fn get_session_id(&self) -> Uuid { self.session_id } pub fn get_model_name(&self) -> &str { &self.model } /// Get the complete conversation history of the current thread pub async fn get_conversation_history(&self) -> Option> { self.current_thread .read() .await .as_ref() .map(|thread| thread.messages.clone()) } /// Update the current thread with a new message async fn update_current_thread(&self, message: AgentMessage) -> Result<()> { let mut thread_lock = self.current_thread.write().await; if let Some(thread) = thread_lock.as_mut() { thread.messages.push(message); } Ok(()) } /// Add a new tool with the agent /// /// # Arguments /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed pub async fn add_tool(&self, name: String, tool: T) where T: ToolExecutor + 'static, T::Params: serde::de::DeserializeOwned, T::Output: serde::Serialize, { let mut tools = self.tools.write().await; // Convert the tool to a ToolCallExecutor let value_tool = tool.into_tool_call_executor(); tools.insert(name, Box::new(value_tool)); } /// Add multiple tools to the agent at once /// /// # Arguments /// * `tools` - HashMap of tool names and their implementations pub async fn add_tools(&self, tools: HashMap) where E: ToolExecutor + 'static, E::Params: serde::de::DeserializeOwned, E::Output: serde::Serialize, { let mut tools_map = self.tools.write().await; for (name, tool) in tools { // Convert each tool to a ToolCallExecutor let value_tool = tool.into_tool_call_executor(); tools_map.insert(name, Box::new(value_tool)); } } /// Process a thread of conversation, potentially executing tools and continuing /// the conversation recursively until a final response is reached. /// /// This is a convenience wrapper around process_thread_streaming that collects /// all streamed messages into a final response. /// /// # Arguments /// * `thread` - The conversation thread to process /// /// # Returns /// * A Result containing the final Message from the assistant pub async fn process_thread(&self, thread: &AgentThread) -> Result { let mut rx = self.process_thread_streaming(thread).await?; let mut final_message = None; while let Ok(msg) = rx.recv().await { final_message = Some(msg?); } final_message.ok_or_else(|| anyhow::anyhow!("No messages received from processing")) } /// Process a thread of conversation with streaming responses. This is the primary /// interface for processing conversations. /// /// # Arguments /// * `thread` - The conversation thread to process /// /// # Returns /// * A Result containing a receiver for streamed messages pub async fn process_thread_streaming( &self, thread: &AgentThread, ) -> Result> { // Spawn the processing task let agent_clone = self.clone(); let thread_clone = thread.clone(); // Get shutdown receiver let mut shutdown_rx = self.get_shutdown_receiver().await; tokio::spawn(async move { tokio::select! { result = agent_clone.process_thread_with_depth(&thread_clone, 0) => { if let Err(e) = result { let err_msg = format!("Error processing thread: {:?}", e); let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg))); // Send Done message after error let _ = agent_clone.get_stream_sender().await.send(Ok(AgentMessage::Done)); } }, _ = shutdown_rx.recv() => { // Send shutdown notification let _ = agent_clone.get_stream_sender().await.send( Ok(AgentMessage::assistant( Some("shutdown_message".to_string()), Some("Processing interrupted due to shutdown signal".to_string()), None, MessageProgress::Complete, None, Some(agent_clone.name.clone()), )) ); // Send Done message after shutdown let _ = agent_clone.get_stream_sender().await.send(Ok(AgentMessage::Done)); } } }); Ok(self.get_stream_receiver().await) } async fn process_thread_with_depth( &self, thread: &AgentThread, recursion_depth: u32, ) -> Result<()> { // Set the initial thread { let mut current = self.current_thread.write().await; *current = Some(thread.clone()); } if recursion_depth >= 30 { let message = AgentMessage::assistant( Some("max_recursion_depth_message".to_string()), Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()), None, MessageProgress::Complete, None, Some(self.name.clone()), ); self.get_stream_sender().await.send(Ok(message))?; self.close().await; return Ok(()); } // Collect all registered tools and their schemas let tools = self.get_enabled_tools().await; // Create the tool-enabled request let request = ChatCompletionRequest { model: self.model.clone(), messages: thread.messages.clone(), tools: if tools.is_empty() { None } else { Some(tools) }, tool_choice: Some(ToolChoice::Auto), stream: Some(true), // Enable streaming metadata: Some(Metadata { generation_name: "agent".to_string(), user_id: thread.user_id.to_string(), session_id: thread.id.to_string(), trace_id: None, }), ..Default::default() }; // Get the streaming response from the LLM let mut stream_rx = match self.llm_client.stream_chat_completion(request).await { Ok(rx) => rx, Err(e) => return Err(anyhow::anyhow!("Error starting stream: {:?}", e)), }; // Process the streaming chunks let mut buffer = MessageBuffer::new(); let mut is_complete = false; while let Some(chunk_result) = stream_rx.recv().await { match chunk_result { Ok(chunk) => { if chunk.choices.is_empty() { continue; } buffer.message_id = Some(chunk.id.clone()); let delta = &chunk.choices[0].delta; // Accumulate content if present if let Some(content) = &delta.content { buffer.content.push_str(content); } // Process tool calls if present if let Some(tool_calls) = &delta.tool_calls { for tool_call in tool_calls { let id = tool_call.id.clone().unwrap_or_else(|| { buffer.tool_calls .keys() .next() .map(|s| s.clone()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) }); // Get or create the pending tool call let pending_call = buffer.tool_calls .entry(id.clone()) .or_insert_with(PendingToolCall::new); // Update the pending call with the delta pending_call.update_from_delta(tool_call); } } // Check if we should flush the buffer if buffer.should_flush() { buffer.flush(self).await?; } // Check if this is the final chunk if chunk.choices[0].finish_reason.is_some() { is_complete = true; } } Err(e) => return Err(anyhow::anyhow!("Error in stream: {:?}", e)), } } // Create and send the final message let final_tool_calls: Option> = if !buffer.tool_calls.is_empty() { Some( buffer.tool_calls .values() .map(|p| p.clone().into_tool_call()) .collect(), ) } else { None }; let final_message = AgentMessage::assistant( buffer.message_id, if buffer.content.is_empty() { None } else { Some(buffer.content) }, final_tool_calls.clone(), MessageProgress::Complete, Some(false), Some(self.name.clone()), ); // Broadcast the final assistant message self.get_stream_sender() .await .send(Ok(final_message.clone()))?; // Update thread with assistant message self.update_current_thread(final_message.clone()).await?; // If this is an auto response without tool calls, it means we're done if final_tool_calls.is_none() { // Send Done message and return self.get_stream_sender() .await .send(Ok(AgentMessage::Done))?; return Ok(()); } // If the LLM wants to use tools, execute them and continue if let Some(tool_calls) = final_tool_calls { let mut results = Vec::new(); // Execute each requested tool for tool_call in tool_calls { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { let params: Value = serde_json::from_str(&tool_call.function.arguments)?; let result = tool.execute(params, tool_call.id.clone()).await?; let result_str = serde_json::to_string(&result)?; let tool_message = AgentMessage::tool( None, result_str, tool_call.id.clone(), Some(tool_call.function.name.clone()), MessageProgress::Complete, ); // Broadcast the tool message as soon as we receive it self.get_stream_sender() .await .send(Ok(tool_message.clone()))?; // Update thread with tool response self.update_current_thread(tool_message.clone()).await?; results.push(tool_message); } } // Create a new thread with the tool results and continue recursively let mut new_thread = thread.clone(); new_thread.messages.push(final_message); new_thread.messages.extend(results); Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await } else { // Send Done message and return self.get_stream_sender() .await .send(Ok(AgentMessage::Done))?; Ok(()) } } /// Get a receiver for the shutdown signal pub async fn get_shutdown_receiver(&self) -> broadcast::Receiver<()> { self.shutdown_tx.read().await.subscribe() } /// Signal shutdown to all receivers pub async fn shutdown(&self) -> Result<()> { // Send shutdown signal self.shutdown_tx.read().await.send(())?; Ok(()) } /// Get a reference to the tools map pub async fn get_tools( &self, ) -> tokio::sync::RwLockReadGuard< '_, HashMap + Send + Sync>>, > { self.tools.read().await } // Add this new method alongside other channel-related methods pub async fn close(&self) { let mut tx = self.stream_tx.write().await; *tx = None; } } #[derive(Debug, Default, Clone)] struct PendingToolCall { id: Option, call_type: Option, function_name: Option, arguments: String, code_interpreter: Option, retrieval: Option, } impl PendingToolCall { fn new() -> Self { Self::default() } fn update_from_delta(&mut self, tool_call: &DeltaToolCall) { if let Some(id) = &tool_call.id { self.id = Some(id.clone()); } if let Some(call_type) = &tool_call.call_type { self.call_type = Some(call_type.clone()); } if let Some(function) = &tool_call.function { if let Some(name) = &function.name { self.function_name = Some(name.clone()); } if let Some(args) = &function.arguments { self.arguments.push_str(args); } } if let Some(_) = &tool_call.code_interpreter { self.code_interpreter = None; } if let Some(_) = &tool_call.retrieval { self.retrieval = None; } } fn into_tool_call(self) -> ToolCall { ToolCall { id: self.id.unwrap_or_default(), function: FunctionCall { name: self.function_name.unwrap_or_default(), arguments: self.arguments, }, call_type: self.call_type.unwrap_or_default(), code_interpreter: None, retrieval: None, } } } /// A trait that provides convenient access to Agent functionality /// when the agent is stored behind an Arc #[async_trait::async_trait] pub trait AgentExt { fn get_agent(&self) -> &Arc; async fn stream_process_thread( &self, thread: &AgentThread, ) -> Result> { (*self.get_agent()).process_thread_streaming(thread).await } async fn process_thread(&self, thread: &AgentThread) -> Result { (*self.get_agent()).process_thread(thread).await } async fn get_current_thread(&self) -> Option { (*self.get_agent()).get_current_thread().await } } #[cfg(test)] mod tests { use super::*; use crate::tools::ToolExecutor; use async_trait::async_trait; use litellm::MessageProgress; use serde_json::{json, Value}; use uuid::Uuid; fn setup() { dotenv::dotenv().ok(); } struct WeatherTool { agent: Arc, } impl WeatherTool { fn new(agent: Arc) -> Self { Self { agent } } } impl WeatherTool { async fn send_progress( &self, content: String, tool_id: String, progress: MessageProgress, ) -> Result<()> { let message = AgentMessage::tool( None, content, tool_id, Some(self.get_name()), progress, ); self.agent.get_stream_sender().await.send(Ok(message))?; Ok(()) } } #[async_trait] impl ToolExecutor for WeatherTool { type Output = Value; type Params = Value; async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { self.send_progress( "Fetching weather data...".to_string(), "123".to_string(), MessageProgress::InProgress, ) .await?; // Simulate a delay tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; let result = json!({ "temperature": 20, "unit": "fahrenheit" }); self.send_progress( serde_json::to_string(&result)?, "123".to_string(), MessageProgress::Complete, ) .await?; Ok(result) } async fn is_enabled(&self) -> bool { true } fn get_schema(&self) -> Value { json!({ "name": "get_weather", "description": "Get current weather information for a specific location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g., San Francisco, CA" }, "unit": { "type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use" } }, "required": ["location"] } }) } fn get_name(&self) -> String { "get_weather".to_string() } } #[tokio::test] async fn test_agent_convo_no_tools() { setup(); // Create LLM client and agent let agent = Agent::new( "o1".to_string(), HashMap::new(), Uuid::new_v4(), Uuid::new_v4(), "test_agent".to_string(), ); let thread = AgentThread::new( None, Uuid::new_v4(), vec![AgentMessage::user("Hello, world!".to_string())], ); let response = match agent.process_thread(&thread).await { Ok(response) => response, Err(e) => panic!("Error processing thread: {:?}", e), }; } #[tokio::test] async fn test_agent_convo_with_tools() { setup(); // Create agent first let mut agent = Agent::new( "o1".to_string(), HashMap::new(), Uuid::new_v4(), Uuid::new_v4(), "test_agent".to_string(), ); // Create weather tool with reference to agent let weather_tool = WeatherTool::new(Arc::new(agent.clone())); // Add tool to agent agent.add_tool(weather_tool.get_name(), weather_tool); let thread = AgentThread::new( None, Uuid::new_v4(), vec![AgentMessage::user( "What is the weather in vineyard ut?".to_string(), )], ); let response = match agent.process_thread(&thread).await { Ok(response) => response, Err(e) => panic!("Error processing thread: {:?}", e), }; } #[tokio::test] async fn test_agent_with_multiple_steps() { setup(); // Create LLM client and agent let mut agent = Agent::new( "o1".to_string(), HashMap::new(), Uuid::new_v4(), Uuid::new_v4(), "test_agent".to_string(), ); let weather_tool = WeatherTool::new(Arc::new(agent.clone())); agent.add_tool(weather_tool.get_name(), weather_tool); let thread = AgentThread::new( None, Uuid::new_v4(), vec![AgentMessage::user( "What is the weather in vineyard ut and san francisco?".to_string(), )], ); let response = match agent.process_thread(&thread).await { Ok(response) => response, Err(e) => panic!("Error processing thread: {:?}", e), }; } #[tokio::test] async fn test_agent_state_management() { setup(); // Create agent let agent = Agent::new( "o1".to_string(), HashMap::new(), Uuid::new_v4(), Uuid::new_v4(), "test_agent".to_string(), ); // Test setting single values agent .set_state_value("test_key".to_string(), json!("test_value")) .await; let value = agent.get_state_value("test_key").await; assert_eq!(value, Some(json!("test_value"))); // Test updating multiple values agent .update_state(|state| { state.insert("key1".to_string(), json!(1)); state.insert("key2".to_string(), json!({"nested": "value"})); }) .await; assert_eq!(agent.get_state_value("key1").await, Some(json!(1))); assert_eq!( agent.get_state_value("key2").await, Some(json!({"nested": "value"})) ); // Test clearing state agent.clear_state().await; assert_eq!(agent.get_state_value("test_key").await, None); assert_eq!(agent.get_state_value("key1").await, None); assert_eq!(agent.get_state_value("key2").await, None); } }