refactor(agent): Improve thread-safe tool management with Arc

- Wrap agent tools in an Arc for safe concurrent access
- Modify tool addition methods to work with Arc-wrapped HashMap
- Ensure thread-safe tool registration and retrieval
- Update stream processing to use Arc-cloned tools reference
This commit is contained in:
dal 2025-02-07 10:53:26 -07:00
parent 2bf27a9eda
commit 864257bc24
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 13 additions and 8 deletions

View File

@ -4,7 +4,7 @@ use crate::utils::{
};
use anyhow::Result;
use serde_json::Value;
use std::{collections::HashMap, env};
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::mpsc;
use serde::Serialize;
@ -17,7 +17,7 @@ pub struct Agent {
/// Client for communicating with the LLM provider
llm_client: LiteLLMClient,
/// Registry of available tools, mapped by their names
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>,
tools: Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
/// The model identifier to use (e.g., "gpt-4")
model: String,
}
@ -35,7 +35,7 @@ impl Agent {
Self {
llm_client,
tools,
tools: Arc::new(tools),
model,
}
}
@ -46,7 +46,10 @@ impl Agent {
/// * `name` - The name of the tool, used to identify it in tool calls
/// * `tool` - The tool implementation that will be executed
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
self.tools.insert(name, Box::new(tool));
// Get a mutable reference to the HashMap inside the Arc
Arc::get_mut(&mut self.tools)
.expect("Failed to get mutable reference to tools")
.insert(name, Box::new(tool));
}
/// Add multiple tools to the agent at once
@ -57,8 +60,10 @@ impl Agent {
&mut self,
tools: HashMap<String, E>,
) {
let tools_map = Arc::get_mut(&mut self.tools)
.expect("Failed to get mutable reference to tools");
for (name, tool) in tools {
self.tools.insert(name, Box::new(tool));
tools_map.insert(name, Box::new(tool));
}
}
@ -175,8 +180,8 @@ impl Agent {
let (tx, rx) = mpsc::channel(100);
let mut pending_tool_calls = HashMap::new();
// Clone the tools map for use in the spawned task
let tools_for_execution = self.tools.clone();
// Clone the Arc for use in the spawned task
let tools_ref = self.tools.clone();
println!("DEBUG: Stream initialized, starting processing task");
@ -220,7 +225,7 @@ impl Agent {
// Check if this tool call is complete and ready for execution
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) {
if let Some(tool) = tools_for_execution.get(&complete_tool_call.function.name) {
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) {
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
// Execute the tool