diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 0c11950e7..1099f1bf9 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -4,7 +4,7 @@ use crate::utils::{ }; use anyhow::Result; use serde_json::Value; -use std::{collections::HashMap, env}; +use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::mpsc; use serde::Serialize; @@ -17,7 +17,7 @@ pub struct Agent { /// Client for communicating with the LLM provider llm_client: LiteLLMClient, /// Registry of available tools, mapped by their names - tools: HashMap>>, + tools: Arc>>>, /// The model identifier to use (e.g., "gpt-4") model: String, } @@ -35,7 +35,7 @@ impl Agent { Self { llm_client, - tools, + tools: Arc::new(tools), model, } } @@ -46,7 +46,10 @@ impl Agent { /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor + 'static) { - self.tools.insert(name, Box::new(tool)); + // Get a mutable reference to the HashMap inside the Arc + Arc::get_mut(&mut self.tools) + .expect("Failed to get mutable reference to tools") + .insert(name, Box::new(tool)); } /// Add multiple tools to the agent at once @@ -57,8 +60,10 @@ impl Agent { &mut self, tools: HashMap, ) { + let tools_map = Arc::get_mut(&mut self.tools) + .expect("Failed to get mutable reference to tools"); for (name, tool) in tools { - self.tools.insert(name, Box::new(tool)); + tools_map.insert(name, Box::new(tool)); } } @@ -175,8 +180,8 @@ impl Agent { let (tx, rx) = mpsc::channel(100); let mut pending_tool_calls = HashMap::new(); - // Clone the tools map for use in the spawned task - let tools_for_execution = self.tools.clone(); + // Clone the Arc for use in the spawned task + let tools_ref = self.tools.clone(); println!("DEBUG: Stream initialized, starting processing task"); @@ -220,7 +225,7 @@ impl Agent { // Check if this tool call is complete and ready for execution if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) { - if let Some(tool) = tools_for_execution.get(&complete_tool_call.function.name) { + if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) { println!("DEBUG: Executing tool: {}", complete_tool_call.function.name); // Execute the tool