use crate::utils::{ clients::ai::litellm::{ ChatCompletionRequest, DeltaFunctionCall, DeltaToolCall, FunctionCall, LiteLLMClient, Message, MessageProgress, Tool, ToolCall, ToolChoice, }, tools::ToolExecutor, }; use anyhow::Result; use serde::Serialize; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::mpsc; use super::types::AgentThread; #[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. It maintains a registry of available tools /// and handles the recursive nature of tool calls. pub struct Agent { /// Client for communicating with the LLM provider llm_client: LiteLLMClient, /// Registry of available tools, mapped by their names tools: Arc>>>, /// The model identifier to use (e.g., "gpt-4") model: String, } impl Agent { /// Create a new Agent instance with a specific LLM client and model pub fn new( model: String, tools: HashMap>>, ) -> Self { let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set"); let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set"); let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url)); Self { llm_client, tools: Arc::new(tools), model, } } /// Add a new tool with the agent /// /// # Arguments /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor + 'static) { // Get a mutable reference to the HashMap inside the Arc Arc::get_mut(&mut self.tools) .expect("Failed to get mutable reference to tools") .insert(name, Box::new(tool)); } /// Add multiple tools to the agent at once /// /// # Arguments /// * `tools` - HashMap of tool names and their implementations pub fn add_tools + 'static>( &mut self, tools: HashMap, ) { let tools_map = Arc::get_mut(&mut self.tools).expect("Failed to get mutable reference to tools"); for (name, tool) in tools { tools_map.insert(name, Box::new(tool)); } } /// Process a thread of conversation, potentially executing tools and continuing /// the conversation recursively until a final response is reached. /// /// # Arguments /// * `thread` - The conversation thread to process /// /// # Returns /// * A Result containing the final Message from the assistant pub async fn process_thread(&self, thread: &AgentThread) -> Result { self.process_thread_with_depth(thread, 0).await } async fn process_thread_with_depth( &self, thread: &AgentThread, recursion_depth: u32, ) -> Result { if recursion_depth >= 30 { return Ok(Message::assistant( Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()), None, None, )); } // Collect all registered tools and their schemas let tools: Vec = self .tools .iter() .map(|(name, tool)| Tool { tool_type: "function".to_string(), function: tool.get_schema(), }) .collect(); // First, make request with tool_choice set to none let initial_request = ChatCompletionRequest { model: self.model.clone(), messages: thread.messages.clone(), tools: if tools.is_empty() { None } else { Some(tools.clone()) }, tool_choice: Some(ToolChoice::None("none".to_string())), ..Default::default() }; // Get initial response let initial_response = self.llm_client.chat_completion(initial_request).await?; let initial_message = &initial_response.choices[0].message; // Ensure we have content from the initial message let initial_content = match initial_message { Message::Assistant { content, .. } => content.clone().unwrap_or_default(), _ => return Err(anyhow::anyhow!("Expected assistant message from LLM")), }; // Create a new thread with the initial response (ensuring content is present) let mut tool_thread = thread.clone(); tool_thread .messages .push(Message::assistant(Some(initial_content), None, None)); // Create the tool-enabled request let request = ChatCompletionRequest { model: self.model.clone(), messages: tool_thread.messages.clone(), tools: if tools.is_empty() { None } else { Some(tools) }, tool_choice: Some(ToolChoice::Auto("auto".to_string())), ..Default::default() }; // Get the response from the LLM let response = match self.llm_client.chat_completion(request).await { Ok(response) => response, Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)), }; let llm_message = &response.choices[0].message; // Create the assistant message let message = match llm_message { Message::Assistant { content, tool_calls, .. } => Message::assistant(content.clone(), tool_calls.clone(), None), _ => return Err(anyhow::anyhow!("Expected assistant message from LLM")), }; // If this is an auto response without tool calls, it means we're done if let Message::Assistant { tool_calls: None, .. } = &llm_message { return Ok(message); } // If the LLM wants to use tools, execute them and continue if let Message::Assistant { tool_calls: Some(tool_calls), .. } = &llm_message { let mut results = Vec::new(); // Execute each requested tool for tool_call in tool_calls { if let Some(tool) = self.tools.get(&tool_call.function.name) { let result = tool.execute(tool_call).await?; let result_str = serde_json::to_string(&result)?; results.push(Message::tool(result_str, tool_call.id.clone(), None)); } } // Create a new thread with the tool results and continue recursively let mut new_thread = thread.clone(); new_thread.messages.push(message); new_thread.messages.extend(results); Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await } else { Ok(message) } } /// Process a thread of conversation with streaming responses /// /// # Arguments /// * `thread` - The conversation thread to process /// /// # Returns /// * A Result containing a receiver for streamed messages pub async fn stream_process_thread( &self, thread: &AgentThread, ) -> Result>> { let (tx, rx) = mpsc::channel(100); let tools_ref = self.tools.clone(); let model = self.model.clone(); let llm_client = self.llm_client.clone(); // Clone thread for task ownership let thread = thread.clone(); tokio::spawn(async move { async fn process_stream_recursive( llm_client: &LiteLLMClient, model: &str, tools_ref: &Arc>>>, thread: &AgentThread, tx: &mpsc::Sender>, recursion_depth: u32, ) -> Result<()> { if recursion_depth >= 30 { let limit_message = Message::assistant( Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()), None, None, ); let _ = tx.send(Ok(limit_message)).await; return Ok(()); } // Collect all registered tools and their schemas let tools: Vec = tools_ref .iter() .map(|(name, tool)| Tool { tool_type: "function".to_string(), function: tool.get_schema(), }) .collect(); // First, make request with tool_choice set to none let initial_request = ChatCompletionRequest { model: model.to_string(), messages: thread.messages.clone(), tools: if tools.is_empty() { None } else { Some(tools.clone()) }, tool_choice: Some(ToolChoice::None("none".to_string())), stream: Some(true), ..Default::default() }; // Get streaming response for initial thoughts let mut initial_stream = llm_client.stream_chat_completion(initial_request).await?; let mut initial_message = Message::assistant(Some(String::new()), None, None); let mut has_started = false; // Process initial stream chunks while let Some(chunk_result) = initial_stream.recv().await { match chunk_result { Ok(chunk) => { let delta = &chunk.choices[0].delta; // Handle content updates - send delta directly if let Some(content) = &delta.content { // Send the delta chunk immediately with InProgress let _ = tx .send(Ok(Message::assistant( Some(content.clone()), None, Some(MessageProgress::InProgress), ))) .await; // Also accumulate for our thread history if let Message::Assistant { content: msg_content, .. } = &mut initial_message { if let Some(existing) = msg_content { existing.push_str(content); } } } } Err(e) => { let _ = tx.send(Err(anyhow::Error::from(e))).await; return Ok(()); } } } // Ensure we have content in the initial message let initial_content = match &initial_message { Message::Assistant { content, .. } => content.clone().unwrap_or_default(), _ => String::new(), }; // Create new thread with initial response (ensuring content is present) let mut tool_thread = thread.clone(); tool_thread .messages .push(Message::assistant(Some(initial_content), None, None)); // Create the tool-enabled request let request = ChatCompletionRequest { model: model.to_string(), messages: tool_thread.messages.clone(), tools: if tools.is_empty() { None } else { Some(tools) }, tool_choice: Some(ToolChoice::Auto("auto".to_string())), stream: Some(true), ..Default::default() }; // Get streaming response let mut stream = llm_client.stream_chat_completion(request).await?; let mut current_message = Message::assistant(Some(String::new()), None, None); let mut current_pending_tool: Option = None; let mut has_tool_calls = false; let mut tool_results = Vec::new(); // Process stream chunks while let Some(chunk_result) = stream.recv().await { match chunk_result { Ok(chunk) => { let delta = &chunk.choices[0].delta; // Check for tool call completion if let Some(finish_reason) = &chunk.choices[0].finish_reason { if finish_reason == "tool_calls" { has_tool_calls = true; // Tool call is complete - execute it if let Some(pending) = current_pending_tool.take() { let tool_call = pending.into_tool_call(); // Create and preserve the assistant message with the tool call let assistant_tool_message = Message::assistant( None, Some(vec![tool_call.clone()]), Some(MessageProgress::Complete), ); let _ = tx.send(Ok(assistant_tool_message.clone())).await; // Execute the tool if let Some(tool) = tools_ref.get(&tool_call.function.name) { match tool.execute(&tool_call).await { Ok(result) => { let result_str = serde_json::to_string(&result)?; let tool_result = Message::tool( result_str, tool_call.id.clone(), Some(MessageProgress::Complete), ); let _ = tx.send(Ok(tool_result.clone())).await; // Store both the assistant tool message and the tool result tool_results.push(assistant_tool_message); tool_results.push(tool_result); } Err(e) => { let error_msg = format!("Tool execution failed: {:?}", e); let tool_error = Message::tool( error_msg, tool_call.id.clone(), Some(MessageProgress::Complete), ); let _ = tx.send(Ok(tool_error.clone())).await; // Store both the assistant tool message and the error tool_results.push(assistant_tool_message); tool_results.push(tool_error); } } } } continue; } } // Handle content updates - only send if we have actual content if let Some(content) = &delta.content { if !content.trim().is_empty() { if let Message::Assistant { content: msg_content, .. } = &mut current_message { if let Some(existing) = msg_content { existing.push_str(content); } } let _ = tx .send(Ok(Message::assistant( Some(content.clone()), None, None, ))) .await; } } // Handle tool calls - only send when we have meaningful tool call data if let Some(tool_calls) = &delta.tool_calls { has_tool_calls = true; if current_pending_tool.is_none() { current_pending_tool = Some(PendingToolCall::new()); } if let Some(pending) = &mut current_pending_tool { for tool_call in tool_calls { pending.update_from_delta(tool_call); // Send an update if we have a name, regardless of arguments if let Some(name) = &pending.function_name { let temp_tool_call = ToolCall { id: pending.id.clone().unwrap_or_default(), function: FunctionCall { name: name.clone(), arguments: pending.arguments.clone(), }, call_type: pending.call_type.clone().unwrap_or_default(), code_interpreter: None, retrieval: None, }; let _ = tx .send(Ok(Message::assistant( None, Some(vec![temp_tool_call]), Some(MessageProgress::InProgress), ))) .await; } } } } } Err(e) => { let _ = tx.send(Err(anyhow::Error::from(e))).await; return Ok(()); } } } // If we didn't get any tool calls in the auto response, we're done if !has_tool_calls { // Only include current_message in the thread if it has content if let Message::Assistant { content: Some(content), .. } = ¤t_message { if !content.trim().is_empty() { // Send the complete message let complete_message = Message::assistant( Some(content.clone()), None, Some(MessageProgress::Complete), ); let _ = tx.send(Ok(complete_message.clone())).await; let mut new_thread = thread.clone(); new_thread.messages.push(current_message); return Ok(()); } } return Ok(()); } // Create new thread with tool results and recurse let mut new_thread = thread.clone(); // Only include current_message if it has content if let Message::Assistant { content: Some(content), .. } = ¤t_message { if !content.trim().is_empty() { new_thread.messages.push(current_message); } } new_thread.messages.extend(tool_results); // Recurse with new thread Box::pin(process_stream_recursive( llm_client, model, tools_ref, &new_thread, tx, recursion_depth + 1, )) .await?; Ok(()) } // Start recursive processing if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx, 0).await { let _ = tx.send(Err(e)).await; } }); Ok(rx) } } #[derive(Debug, Default)] struct PendingToolCall { id: Option, call_type: Option, function_name: Option, arguments: String, code_interpreter: Option, retrieval: Option, } impl PendingToolCall { fn new() -> Self { Self::default() } fn update_from_delta(&mut self, tool_call: &DeltaToolCall) { if let Some(id) = &tool_call.id { self.id = Some(id.clone()); } if let Some(call_type) = &tool_call.call_type { self.call_type = Some(call_type.clone()); } if let Some(function) = &tool_call.function { if let Some(name) = &function.name { self.function_name = Some(name.clone()); } if let Some(args) = &function.arguments { self.arguments.push_str(args); } } if let Some(code_interpreter) = &tool_call.code_interpreter { self.code_interpreter = None; } if let Some(retrieval) = &tool_call.retrieval { self.retrieval = None; } } fn into_tool_call(self) -> ToolCall { ToolCall { id: self.id.unwrap_or_default(), function: FunctionCall { name: self.function_name.unwrap_or_default(), arguments: self.arguments, }, call_type: self.call_type.unwrap_or_default(), code_interpreter: None, retrieval: None, } } } #[cfg(test)] mod tests { use crate::utils::clients::ai::litellm::ToolCall; use super::*; use axum::async_trait; use dotenv::dotenv; use serde_json::{json, Value}; fn setup() { dotenv().ok(); } struct WeatherTool; #[async_trait] impl ToolExecutor for WeatherTool { type Output = Value; async fn execute(&self, tool_call: &ToolCall) -> Result { Ok(json!({ "temperature": 20, "unit": "fahrenheit" })) } fn get_schema(&self) -> Value { json!({ "name": "get_weather", "description": "Get current weather information for a specific location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g., San Francisco, CA" }, "unit": { "type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use" } }, "required": ["location"] } }) } fn get_name(&self) -> String { "get_weather".to_string() } } #[tokio::test] async fn test_agent_convo_no_tools() { setup(); // Create LLM client and agent let agent = Agent::new("o1".to_string(), HashMap::new()); let thread = AgentThread::new(None, vec![Message::user("Hello, world!".to_string())]); let response = match agent.process_thread(&thread).await { Ok(response) => response, Err(e) => panic!("Error processing thread: {:?}", e), }; println!("Response: {:?}", response); } #[tokio::test] async fn test_agent_convo_with_tools() { setup(); // Create LLM client and agent let mut agent = Agent::new("o1".to_string(), HashMap::new()); let weather_tool = WeatherTool; agent.add_tool(weather_tool.get_name(), weather_tool); let thread = AgentThread::new( None, vec![Message::user( "What is the weather in vineyard ut?".to_string(), )], ); let response = match agent.process_thread(&thread).await { Ok(response) => response, Err(e) => panic!("Error processing thread: {:?}", e), }; println!("Response: {:?}", response); } #[tokio::test] async fn test_agent_with_multiple_steps() { setup(); // Create LLM client and agent let mut agent = Agent::new("o1".to_string(), HashMap::new()); let weather_tool = WeatherTool; agent.add_tool(weather_tool.get_name(), weather_tool); let thread = AgentThread::new( None, vec![Message::user( "What is the weather in vineyard ut and san francisco?".to_string(), )], ); let response = match agent.process_thread(&thread).await { Ok(response) => response, Err(e) => panic!("Error processing thread: {:?}", e), }; println!("Response: {:?}", response); } }