refactor(tools): Implement ValueToolExecutor for generic tool output conversion

- Add `ValueToolExecutor` to convert tool outputs to `serde_json::Value`
- Introduce `IntoValueTool` trait for easy value type conversion
- Update agent tool addition methods to use new value conversion mechanism
- Simplify tool registration by automatically converting tool outputs
- Remove previous manual boxing and type conversion logic
This commit is contained in:
dal 2025-02-07 10:15:55 -07:00
parent 2090e0b7d7
commit a2f3433555
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
4 changed files with 62 additions and 72 deletions

View File

@ -13,12 +13,12 @@ use crate::{
ws_utils::send_ws_message, ws_utils::send_ws_message,
}, },
utils::{ utils::{
agent::Agent, agent::{Agent, AgentThread},
clients::ai::litellm::Message, clients::ai::litellm::Message,
tools::file_tools::{ tools::{file_tools::{
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool, CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool,
SearchFilesTool, SendToUserTool, SearchFilesTool, SendToUserTool,
}, }, ToolExecutor, IntoValueTool},
}, },
}; };
@ -50,16 +50,12 @@ impl AgentThreadHandler {
let open_files_tool = OpenFilesTool; let open_files_tool = OpenFilesTool;
let send_to_user_tool = SendToUserTool; let send_to_user_tool = SendToUserTool;
// Add each tool individually agent.add_tool(search_data_catalog_tool.get_name(), search_data_catalog_tool.into_value_tool());
agent.add_tool( agent.add_tool(search_files_tool.get_name(), search_files_tool.into_value_tool());
search_data_catalog_tool.get_name(), agent.add_tool(modify_files_tool.get_name(), modify_files_tool.into_value_tool());
search_data_catalog_tool, agent.add_tool(create_files_tool.get_name(), create_files_tool.into_value_tool());
); agent.add_tool(open_files_tool.get_name(), open_files_tool.into_value_tool());
agent.add_tool(search_files_tool.get_name(), search_files_tool); agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool.into_value_tool());
agent.add_tool(modify_files_tool.get_name(), modify_files_tool);
agent.add_tool(create_files_tool.get_name(), create_files_tool);
agent.add_tool(open_files_tool.get_name(), open_files_tool);
agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool);
Ok(Self { agent }) Ok(Self { agent })
} }

View File

@ -6,6 +6,7 @@ use anyhow::Result;
use serde_json::Value; use serde_json::Value;
use std::{collections::HashMap, env}; use std::{collections::HashMap, env};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use serde::Serialize;
use super::types::AgentThread; use super::types::AgentThread;
@ -44,7 +45,7 @@ impl Agent {
/// # Arguments /// # Arguments
/// * `name` - The name of the tool, used to identify it in tool calls /// * `name` - The name of the tool, used to identify it in tool calls
/// * `tool` - The tool implementation that will be executed /// * `tool` - The tool implementation that will be executed
pub fn add_tool<T: ToolExecutor<Output = Value> + 'static>(&mut self, name: String, tool: T) { pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
self.tools.insert(name, Box::new(tool)); self.tools.insert(name, Box::new(tool));
} }
@ -52,9 +53,9 @@ impl Agent {
/// ///
/// # Arguments /// # Arguments
/// * `tools` - HashMap of tool names and their implementations /// * `tools` - HashMap of tool names and their implementations
pub fn add_tools<T: ToolExecutor<Output = Value> + 'static>( pub fn add_tools<E: ToolExecutor<Output = Value> + 'static>(
&mut self, &mut self,
tools: HashMap<String, T>, tools: HashMap<String, E>,
) { ) {
for (name, tool) in tools { for (name, tool) in tools {
self.tools.insert(name, Box::new(tool)); self.tools.insert(name, Box::new(tool));

View File

@ -17,7 +17,10 @@ use crate::{
utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}, utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor},
}; };
use super::{file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml}, FileModificationTool}; use super::{
file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml},
FileModificationTool,
};
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
struct FileParams { struct FileParams {
@ -116,7 +119,10 @@ impl ToolExecutor for CreateFilesTool {
if let Some(dashboard_id) = &dashboard_yml.id { if let Some(dashboard_id) = &dashboard_yml.id {
let dashboard_file = DashboardFile { let dashboard_file = DashboardFile {
id: dashboard_id.clone(), id: dashboard_id.clone(),
name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()), name: dashboard_yml
.name
.clone()
.unwrap_or_else(|| "New Dashboard".to_string()),
file_name: format!("{}.yml", file.name), file_name: format!("{}.yml", file.name),
content: serde_json::to_value(dashboard_yml.clone()).unwrap(), content: serde_json::to_value(dashboard_yml.clone()).unwrap(),
filter: None, filter: None,
@ -160,11 +166,12 @@ impl ToolExecutor for CreateFilesTool {
created_files.extend(metric_ymls.into_iter().map(FileEnum::Metric)); created_files.extend(metric_ymls.into_iter().map(FileEnum::Metric));
} }
Err(e) => { Err(e) => {
failed_files.extend( failed_files.extend(metric_records.iter().map(|r| {
metric_records (
.iter() r.file_name.clone(),
.map(|r| (r.file_name.clone(), format!("Failed to create metric file: {}", e))), format!("Failed to create metric file: {}", e),
); )
}));
} }
} }
} }
@ -180,11 +187,12 @@ impl ToolExecutor for CreateFilesTool {
created_files.extend(dashboard_ymls.into_iter().map(FileEnum::Dashboard)); created_files.extend(dashboard_ymls.into_iter().map(FileEnum::Dashboard));
} }
Err(e) => { Err(e) => {
failed_files.extend( failed_files.extend(dashboard_records.iter().map(|r| {
dashboard_records (
.iter() r.file_name.clone(),
.map(|r| (r.file_name.clone(), format!("Failed to create dashboard file: {}", e))), format!("Failed to create dashboard file: {}", e),
); )
}));
} }
} }
} }
@ -256,41 +264,3 @@ impl ToolExecutor for CreateFilesTool {
}) })
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_files_serialization() {
let tool = CreateFilesTool;
let output = CreateFilesOutput {
message: "Test message".to_string(),
files: vec![],
};
// Use the custom serialization
let result = tool.serialize_output(&output);
assert!(result.is_ok());
}
#[test]
fn test_create_files_with_content() {
let tool = CreateFilesTool;
let yml_content = "name: test\ntype: metric\ndescription: A test metric";
let metric = MetricYml::new(yml_content.to_string()).unwrap();
let output = CreateFilesOutput {
message: "Test message".to_string(),
files: vec![FileEnum::Metric(metric)],
};
// Use the custom serialization
let result = tool.serialize_output(&output).unwrap();
// Verify line numbers were added
assert!(result.contains("1 | name: test"));
assert!(result.contains("2 | type: metric"));
assert!(result.contains("3 | description: A test metric"));
}
}

View File

@ -24,12 +24,35 @@ pub trait ToolExecutor: Send + Sync {
fn get_name(&self) -> String; fn get_name(&self) -> String;
} }
trait IntoBoxedTool { /// A wrapper type that converts any ToolExecutor to one that outputs Value
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>>; pub struct ValueToolExecutor<T: ToolExecutor>(T);
}
impl<T: ToolExecutor<Output = Value> + 'static> IntoBoxedTool for T { #[async_trait]
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>> { impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
Box::new(self) type Output = Value;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let result = self.0.execute(tool_call).await?;
Ok(serde_json::to_value(result)?)
}
fn get_schema(&self) -> Value {
self.0.get_schema()
}
fn get_name(&self) -> String {
self.0.get_name()
}
}
/// Extension trait to add value conversion methods to ToolExecutor
pub trait IntoValueTool {
fn into_value_tool(self) -> ValueToolExecutor<Self> where Self: ToolExecutor + Sized;
}
// Implement IntoValueTool for all types that implement ToolExecutor
impl<T: ToolExecutor> IntoValueTool for T {
fn into_value_tool(self) -> ValueToolExecutor<Self> {
ValueToolExecutor(self)
} }
} }