mirror of https://github.com/buster-so/buster.git
refactor(agent): Improve thread-safe tool management with Arc
- Wrap agent tools in an Arc for safe concurrent access - Modify tool addition methods to work with Arc-wrapped HashMap - Ensure thread-safe tool registration and retrieval - Update stream processing to use Arc-cloned tools reference
This commit is contained in:
parent
2bf27a9eda
commit
864257bc24
|
@ -4,7 +4,7 @@ use crate::utils::{
|
|||
};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, env};
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::mpsc;
|
||||
use serde::Serialize;
|
||||
|
||||
|
@ -17,7 +17,7 @@ pub struct Agent {
|
|||
/// Client for communicating with the LLM provider
|
||||
llm_client: LiteLLMClient,
|
||||
/// Registry of available tools, mapped by their names
|
||||
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>,
|
||||
tools: Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
|
||||
/// The model identifier to use (e.g., "gpt-4")
|
||||
model: String,
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ impl Agent {
|
|||
|
||||
Self {
|
||||
llm_client,
|
||||
tools,
|
||||
tools: Arc::new(tools),
|
||||
model,
|
||||
}
|
||||
}
|
||||
|
@ -46,7 +46,10 @@ impl Agent {
|
|||
/// * `name` - The name of the tool, used to identify it in tool calls
|
||||
/// * `tool` - The tool implementation that will be executed
|
||||
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
|
||||
self.tools.insert(name, Box::new(tool));
|
||||
// Get a mutable reference to the HashMap inside the Arc
|
||||
Arc::get_mut(&mut self.tools)
|
||||
.expect("Failed to get mutable reference to tools")
|
||||
.insert(name, Box::new(tool));
|
||||
}
|
||||
|
||||
/// Add multiple tools to the agent at once
|
||||
|
@ -57,8 +60,10 @@ impl Agent {
|
|||
&mut self,
|
||||
tools: HashMap<String, E>,
|
||||
) {
|
||||
let tools_map = Arc::get_mut(&mut self.tools)
|
||||
.expect("Failed to get mutable reference to tools");
|
||||
for (name, tool) in tools {
|
||||
self.tools.insert(name, Box::new(tool));
|
||||
tools_map.insert(name, Box::new(tool));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -175,8 +180,8 @@ impl Agent {
|
|||
let (tx, rx) = mpsc::channel(100);
|
||||
let mut pending_tool_calls = HashMap::new();
|
||||
|
||||
// Clone the tools map for use in the spawned task
|
||||
let tools_for_execution = self.tools.clone();
|
||||
// Clone the Arc for use in the spawned task
|
||||
let tools_ref = self.tools.clone();
|
||||
|
||||
println!("DEBUG: Stream initialized, starting processing task");
|
||||
|
||||
|
@ -220,7 +225,7 @@ impl Agent {
|
|||
|
||||
// Check if this tool call is complete and ready for execution
|
||||
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) {
|
||||
if let Some(tool) = tools_for_execution.get(&complete_tool_call.function.name) {
|
||||
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) {
|
||||
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
|
||||
|
||||
// Execute the tool
|
||||
|
|
Loading…
Reference in New Issue