made it so tools can inherit the agent attributes

This commit is contained in:
dal 2025-02-20 08:01:54 -07:00
parent d452f4fb5f
commit afba56b5e0
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 53 additions and 74 deletions

View File

@ -280,9 +280,7 @@ impl Agent {
// Execute each requested tool
for tool_call in tool_calls {
if let Some(tool) = self.tools.get(&tool_call.function.name) {
let result = tool
.execute(tool_call, &thread.user_id, &thread.id, None)
.await?;
let result = tool.execute(tool_call).await?;
let result_str = serde_json::to_string(&result)?;
let tool_message = Message::tool(
None,
@ -375,30 +373,29 @@ mod tests {
dotenv().ok();
}
struct WeatherTool;
struct WeatherTool {
agent: Arc<Agent>,
}
impl WeatherTool {
fn new(agent: Arc<Agent>) -> Self {
Self { agent }
}
}
#[async_trait]
impl ToolExecutor for WeatherTool {
type Output = Value;
async fn execute(
&self,
tool_call: &ToolCall,
user_id: &Uuid,
session_id: &Uuid,
stream_tx: Option<mpsc::Sender<Result<Message>>>,
) -> Result<Self::Output> {
// Simulate some progress messages if streaming is enabled
if let Some(tx) = &stream_tx {
let progress = Message::tool(
None,
"Fetching weather data...".to_string(),
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
Some(MessageProgress::InProgress),
);
tx.send(Ok(progress)).await?;
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
// Send progress using agent's stream sender
self.agent.get_stream_sender().await.send(Ok(Message::tool(
None,
"Fetching weather data...".to_string(),
tool_call.id.clone(),
Some(self.get_name()),
Some(MessageProgress::InProgress),
))).await?;
// Simulate a delay
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
@ -408,17 +405,14 @@ mod tests {
"unit": "fahrenheit"
});
// Send completion message if streaming
if let Some(tx) = &stream_tx {
let complete = Message::tool(
None,
serde_json::to_string(&result)?,
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
Some(MessageProgress::Complete),
);
tx.send(Ok(complete)).await?;
}
// Send completion message using agent's stream sender
self.agent.get_stream_sender().await.send(Ok(Message::tool(
None,
serde_json::to_string(&result)?,
tool_call.id.clone(),
Some(self.get_name()),
Some(MessageProgress::Complete),
))).await?;
Ok(result)
}
@ -475,11 +469,13 @@ mod tests {
async fn test_agent_convo_with_tools() {
setup();
// Create LLM client and agent
// Create agent first
let mut agent = Agent::new("o1".to_string(), HashMap::new());
let weather_tool = WeatherTool;
// Create weather tool with reference to agent
let weather_tool = WeatherTool::new(Arc::new(agent.clone()));
// Add tool to agent
agent.add_tool(weather_tool.get_name(), weather_tool);
let thread = AgentThread::new(
@ -505,7 +501,7 @@ mod tests {
// Create LLM client and agent
let mut agent = Agent::new("o1".to_string(), HashMap::new());
let weather_tool = WeatherTool;
let weather_tool = WeatherTool::new(Arc::new(agent.clone()));
agent.add_tool(weather_tool.get_name(), weather_tool);

View File

@ -3,9 +3,9 @@ use axum::async_trait;
use litellm::{Message, ToolCall};
use serde::Serialize;
use serde_json::Value;
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::utils::agent::Agent;
pub mod agents_as_tools;
pub mod data_tools;
@ -14,66 +14,49 @@ pub mod interaction_tools;
/// A trait that defines how tools should be implemented.
/// Any struct that wants to be used as a tool must implement this trait.
/// Tools are constructed with a reference to their agent and can access its capabilities.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
/// The type of the output of the tool
type Output: Serialize + Send;
/// Execute the tool with the given parameters and optionally stream progress
async fn execute(
&self,
tool_call: &ToolCall,
user_id: &Uuid,
session_id: &Uuid,
stream_tx: Option<mpsc::Sender<Result<Message>>>,
) -> Result<Self::Output>;
/// Execute the tool with the given parameters.
/// The tool has access to its agent's capabilities through its stored agent reference.
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output>;
/// Get the JSON schema for this tool
fn get_schema(&self) -> Value;
/// Get the name of this tool
fn get_name(&self) -> String;
/// Helper method to send a progress message if streaming is enabled
async fn send_progress(
&self,
stream_tx: &Option<mpsc::Sender<Result<Message>>>,
message: Message,
) -> Result<()> {
if let Some(tx) = stream_tx {
tx.send(Ok(message)).await?;
}
Ok(())
}
}
/// A wrapper type that converts any ToolExecutor to one that outputs Value
pub struct ValueToolExecutor<T: ToolExecutor>(T);
pub struct ValueToolExecutor<T: ToolExecutor> {
inner: T,
}
impl<T: ToolExecutor> ValueToolExecutor<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}
}
#[async_trait]
impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
type Output = Value;
async fn execute(
&self,
tool_call: &ToolCall,
user_id: &Uuid,
session_id: &Uuid,
stream_tx: Option<mpsc::Sender<Result<Message>>>,
) -> Result<Self::Output> {
let result = self
.0
.execute(tool_call, user_id, session_id, stream_tx)
.await?;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let result = self.inner.execute(tool_call).await?;
Ok(serde_json::to_value(result)?)
}
fn get_schema(&self) -> Value {
self.0.get_schema()
self.inner.get_schema()
}
fn get_name(&self) -> String {
self.0.get_name()
self.inner.get_name()
}
}
@ -87,6 +70,6 @@ pub trait IntoValueTool {
// Implement IntoValueTool for all types that implement ToolExecutor
impl<T: ToolExecutor> IntoValueTool for T {
fn into_value_tool(self) -> ValueToolExecutor<Self> {
ValueToolExecutor(self)
ValueToolExecutor::new(self)
}
}