From 864257bc248e2bb8140b671f1b85cf9206ce728e Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 7 Feb 2025 10:53:26 -0700 Subject: [PATCH] refactor(agent): Improve thread-safe tool management with Arc - Wrap agent tools in an Arc for safe concurrent access - Modify tool addition methods to work with Arc-wrapped HashMap - Ensure thread-safe tool registration and retrieval - Update stream processing to use Arc-cloned tools reference --- api/src/utils/agent/agent.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 0c11950e7..1099f1bf9 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -4,7 +4,7 @@ use crate::utils::{ }; use anyhow::Result; use serde_json::Value; -use std::{collections::HashMap, env}; +use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::mpsc; use serde::Serialize; @@ -17,7 +17,7 @@ pub struct Agent { /// Client for communicating with the LLM provider llm_client: LiteLLMClient, /// Registry of available tools, mapped by their names - tools: HashMap>>, + tools: Arc>>>, /// The model identifier to use (e.g., "gpt-4") model: String, } @@ -35,7 +35,7 @@ impl Agent { Self { llm_client, - tools, + tools: Arc::new(tools), model, } } @@ -46,7 +46,10 @@ impl Agent { /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor + 'static) { - self.tools.insert(name, Box::new(tool)); + // Get a mutable reference to the HashMap inside the Arc + Arc::get_mut(&mut self.tools) + .expect("Failed to get mutable reference to tools") + .insert(name, Box::new(tool)); } /// Add multiple tools to the agent at once @@ -57,8 +60,10 @@ impl Agent { &mut self, tools: HashMap, ) { + let tools_map = Arc::get_mut(&mut self.tools) + .expect("Failed to get mutable reference to tools"); for (name, tool) in tools { - self.tools.insert(name, Box::new(tool)); + tools_map.insert(name, Box::new(tool)); } } @@ -175,8 +180,8 @@ impl Agent { let (tx, rx) = mpsc::channel(100); let mut pending_tool_calls = HashMap::new(); - // Clone the tools map for use in the spawned task - let tools_for_execution = self.tools.clone(); + // Clone the Arc for use in the spawned task + let tools_ref = self.tools.clone(); println!("DEBUG: Stream initialized, starting processing task"); @@ -220,7 +225,7 @@ impl Agent { // Check if this tool call is complete and ready for execution if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) { - if let Some(tool) = tools_for_execution.get(&complete_tool_call.function.name) { + if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) { println!("DEBUG: Executing tool: {}", complete_tool_call.function.name); // Execute the tool