mirror of https://github.com/buster-so/buster.git
refactor(agent): Improve thread-safe tool management with Arc
- Wrap agent tools in an Arc for safe concurrent access - Modify tool addition methods to work with Arc-wrapped HashMap - Ensure thread-safe tool registration and retrieval - Update stream processing to use Arc-cloned tools reference
This commit is contained in:
parent
2bf27a9eda
commit
864257bc24
|
@ -4,7 +4,7 @@ use crate::utils::{
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::{collections::HashMap, env};
|
use std::{collections::HashMap, env, sync::Arc};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ pub struct Agent {
|
||||||
/// Client for communicating with the LLM provider
|
/// Client for communicating with the LLM provider
|
||||||
llm_client: LiteLLMClient,
|
llm_client: LiteLLMClient,
|
||||||
/// Registry of available tools, mapped by their names
|
/// Registry of available tools, mapped by their names
|
||||||
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>,
|
tools: Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
|
||||||
/// The model identifier to use (e.g., "gpt-4")
|
/// The model identifier to use (e.g., "gpt-4")
|
||||||
model: String,
|
model: String,
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,7 @@ impl Agent {
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
llm_client,
|
llm_client,
|
||||||
tools,
|
tools: Arc::new(tools),
|
||||||
model,
|
model,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,10 @@ impl Agent {
|
||||||
/// * `name` - The name of the tool, used to identify it in tool calls
|
/// * `name` - The name of the tool, used to identify it in tool calls
|
||||||
/// * `tool` - The tool implementation that will be executed
|
/// * `tool` - The tool implementation that will be executed
|
||||||
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
|
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
|
||||||
self.tools.insert(name, Box::new(tool));
|
// Get a mutable reference to the HashMap inside the Arc
|
||||||
|
Arc::get_mut(&mut self.tools)
|
||||||
|
.expect("Failed to get mutable reference to tools")
|
||||||
|
.insert(name, Box::new(tool));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add multiple tools to the agent at once
|
/// Add multiple tools to the agent at once
|
||||||
|
@ -57,8 +60,10 @@ impl Agent {
|
||||||
&mut self,
|
&mut self,
|
||||||
tools: HashMap<String, E>,
|
tools: HashMap<String, E>,
|
||||||
) {
|
) {
|
||||||
|
let tools_map = Arc::get_mut(&mut self.tools)
|
||||||
|
.expect("Failed to get mutable reference to tools");
|
||||||
for (name, tool) in tools {
|
for (name, tool) in tools {
|
||||||
self.tools.insert(name, Box::new(tool));
|
tools_map.insert(name, Box::new(tool));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,8 +180,8 @@ impl Agent {
|
||||||
let (tx, rx) = mpsc::channel(100);
|
let (tx, rx) = mpsc::channel(100);
|
||||||
let mut pending_tool_calls = HashMap::new();
|
let mut pending_tool_calls = HashMap::new();
|
||||||
|
|
||||||
// Clone the tools map for use in the spawned task
|
// Clone the Arc for use in the spawned task
|
||||||
let tools_for_execution = self.tools.clone();
|
let tools_ref = self.tools.clone();
|
||||||
|
|
||||||
println!("DEBUG: Stream initialized, starting processing task");
|
println!("DEBUG: Stream initialized, starting processing task");
|
||||||
|
|
||||||
|
@ -220,7 +225,7 @@ impl Agent {
|
||||||
|
|
||||||
// Check if this tool call is complete and ready for execution
|
// Check if this tool call is complete and ready for execution
|
||||||
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) {
|
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) {
|
||||||
if let Some(tool) = tools_for_execution.get(&complete_tool_call.function.name) {
|
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) {
|
||||||
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
|
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
|
||||||
|
|
||||||
// Execute the tool
|
// Execute the tool
|
||||||
|
|
Loading…
Reference in New Issue