mirror of https://github.com/buster-so/buster.git
made it so tools can inherit the agent attributes
This commit is contained in:
parent
d452f4fb5f
commit
afba56b5e0
|
@ -280,9 +280,7 @@ impl Agent {
|
|||
// Execute each requested tool
|
||||
for tool_call in tool_calls {
|
||||
if let Some(tool) = self.tools.get(&tool_call.function.name) {
|
||||
let result = tool
|
||||
.execute(tool_call, &thread.user_id, &thread.id, None)
|
||||
.await?;
|
||||
let result = tool.execute(tool_call).await?;
|
||||
let result_str = serde_json::to_string(&result)?;
|
||||
let tool_message = Message::tool(
|
||||
None,
|
||||
|
@ -375,30 +373,29 @@ mod tests {
|
|||
dotenv().ok();
|
||||
}
|
||||
|
||||
struct WeatherTool;
|
||||
struct WeatherTool {
|
||||
agent: Arc<Agent>,
|
||||
}
|
||||
|
||||
impl WeatherTool {
|
||||
fn new(agent: Arc<Agent>) -> Self {
|
||||
Self { agent }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for WeatherTool {
|
||||
type Output = Value;
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
tool_call: &ToolCall,
|
||||
user_id: &Uuid,
|
||||
session_id: &Uuid,
|
||||
stream_tx: Option<mpsc::Sender<Result<Message>>>,
|
||||
) -> Result<Self::Output> {
|
||||
// Simulate some progress messages if streaming is enabled
|
||||
if let Some(tx) = &stream_tx {
|
||||
let progress = Message::tool(
|
||||
None,
|
||||
"Fetching weather data...".to_string(),
|
||||
tool_call.id.clone(),
|
||||
Some(tool_call.function.name.clone()),
|
||||
Some(MessageProgress::InProgress),
|
||||
);
|
||||
tx.send(Ok(progress)).await?;
|
||||
}
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
// Send progress using agent's stream sender
|
||||
self.agent.get_stream_sender().await.send(Ok(Message::tool(
|
||||
None,
|
||||
"Fetching weather data...".to_string(),
|
||||
tool_call.id.clone(),
|
||||
Some(self.get_name()),
|
||||
Some(MessageProgress::InProgress),
|
||||
))).await?;
|
||||
|
||||
// Simulate a delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
@ -408,17 +405,14 @@ mod tests {
|
|||
"unit": "fahrenheit"
|
||||
});
|
||||
|
||||
// Send completion message if streaming
|
||||
if let Some(tx) = &stream_tx {
|
||||
let complete = Message::tool(
|
||||
None,
|
||||
serde_json::to_string(&result)?,
|
||||
tool_call.id.clone(),
|
||||
Some(tool_call.function.name.clone()),
|
||||
Some(MessageProgress::Complete),
|
||||
);
|
||||
tx.send(Ok(complete)).await?;
|
||||
}
|
||||
// Send completion message using agent's stream sender
|
||||
self.agent.get_stream_sender().await.send(Ok(Message::tool(
|
||||
None,
|
||||
serde_json::to_string(&result)?,
|
||||
tool_call.id.clone(),
|
||||
Some(self.get_name()),
|
||||
Some(MessageProgress::Complete),
|
||||
))).await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -475,11 +469,13 @@ mod tests {
|
|||
async fn test_agent_convo_with_tools() {
|
||||
setup();
|
||||
|
||||
// Create LLM client and agent
|
||||
// Create agent first
|
||||
let mut agent = Agent::new("o1".to_string(), HashMap::new());
|
||||
|
||||
let weather_tool = WeatherTool;
|
||||
|
||||
|
||||
// Create weather tool with reference to agent
|
||||
let weather_tool = WeatherTool::new(Arc::new(agent.clone()));
|
||||
|
||||
// Add tool to agent
|
||||
agent.add_tool(weather_tool.get_name(), weather_tool);
|
||||
|
||||
let thread = AgentThread::new(
|
||||
|
@ -505,7 +501,7 @@ mod tests {
|
|||
// Create LLM client and agent
|
||||
let mut agent = Agent::new("o1".to_string(), HashMap::new());
|
||||
|
||||
let weather_tool = WeatherTool;
|
||||
let weather_tool = WeatherTool::new(Arc::new(agent.clone()));
|
||||
|
||||
agent.add_tool(weather_tool.get_name(), weather_tool);
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ use axum::async_trait;
|
|||
use litellm::{Message, ToolCall};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::agent::Agent;
|
||||
|
||||
pub mod agents_as_tools;
|
||||
pub mod data_tools;
|
||||
|
@ -14,66 +14,49 @@ pub mod interaction_tools;
|
|||
|
||||
/// A trait that defines how tools should be implemented.
|
||||
/// Any struct that wants to be used as a tool must implement this trait.
|
||||
/// Tools are constructed with a reference to their agent and can access its capabilities.
|
||||
#[async_trait]
|
||||
pub trait ToolExecutor: Send + Sync {
|
||||
/// The type of the output of the tool
|
||||
type Output: Serialize + Send;
|
||||
|
||||
/// Execute the tool with the given parameters and optionally stream progress
|
||||
async fn execute(
|
||||
&self,
|
||||
tool_call: &ToolCall,
|
||||
user_id: &Uuid,
|
||||
session_id: &Uuid,
|
||||
stream_tx: Option<mpsc::Sender<Result<Message>>>,
|
||||
) -> Result<Self::Output>;
|
||||
/// Execute the tool with the given parameters.
|
||||
/// The tool has access to its agent's capabilities through its stored agent reference.
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output>;
|
||||
|
||||
/// Get the JSON schema for this tool
|
||||
fn get_schema(&self) -> Value;
|
||||
|
||||
/// Get the name of this tool
|
||||
fn get_name(&self) -> String;
|
||||
|
||||
/// Helper method to send a progress message if streaming is enabled
|
||||
async fn send_progress(
|
||||
&self,
|
||||
stream_tx: &Option<mpsc::Sender<Result<Message>>>,
|
||||
message: Message,
|
||||
) -> Result<()> {
|
||||
if let Some(tx) = stream_tx {
|
||||
tx.send(Ok(message)).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper type that converts any ToolExecutor to one that outputs Value
|
||||
pub struct ValueToolExecutor<T: ToolExecutor>(T);
|
||||
pub struct ValueToolExecutor<T: ToolExecutor> {
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T: ToolExecutor> ValueToolExecutor<T> {
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
|
||||
type Output = Value;
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
tool_call: &ToolCall,
|
||||
user_id: &Uuid,
|
||||
session_id: &Uuid,
|
||||
stream_tx: Option<mpsc::Sender<Result<Message>>>,
|
||||
) -> Result<Self::Output> {
|
||||
let result = self
|
||||
.0
|
||||
.execute(tool_call, user_id, session_id, stream_tx)
|
||||
.await?;
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let result = self.inner.execute(tool_call).await?;
|
||||
Ok(serde_json::to_value(result)?)
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
self.0.get_schema()
|
||||
self.inner.get_schema()
|
||||
}
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
self.0.get_name()
|
||||
self.inner.get_name()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,6 +70,6 @@ pub trait IntoValueTool {
|
|||
// Implement IntoValueTool for all types that implement ToolExecutor
|
||||
impl<T: ToolExecutor> IntoValueTool for T {
|
||||
fn into_value_tool(self) -> ValueToolExecutor<Self> {
|
||||
ValueToolExecutor(self)
|
||||
ValueToolExecutor::new(self)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue