refactor(tools): Implement ValueToolExecutor for generic tool output conversion

- Add `ValueToolExecutor` to convert tool outputs to `serde_json::Value`
- Introduce `IntoValueTool` trait for easy value type conversion
- Update agent tool addition methods to use new value conversion mechanism
- Simplify tool registration by automatically converting tool outputs
- Remove previous manual boxing and type conversion logic
This commit is contained in:
dal 2025-02-07 10:15:55 -07:00
parent 2090e0b7d7
commit a2f3433555
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
4 changed files with 62 additions and 72 deletions

View File

@ -13,12 +13,12 @@ use crate::{
ws_utils::send_ws_message,
},
utils::{
agent::Agent,
agent::{Agent, AgentThread},
clients::ai::litellm::Message,
tools::file_tools::{
tools::{file_tools::{
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool,
SearchFilesTool, SendToUserTool,
},
}, ToolExecutor, IntoValueTool},
},
};
@ -50,16 +50,12 @@ impl AgentThreadHandler {
let open_files_tool = OpenFilesTool;
let send_to_user_tool = SendToUserTool;
// Add each tool individually
agent.add_tool(
search_data_catalog_tool.get_name(),
search_data_catalog_tool,
);
agent.add_tool(search_files_tool.get_name(), search_files_tool);
agent.add_tool(modify_files_tool.get_name(), modify_files_tool);
agent.add_tool(create_files_tool.get_name(), create_files_tool);
agent.add_tool(open_files_tool.get_name(), open_files_tool);
agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool);
agent.add_tool(search_data_catalog_tool.get_name(), search_data_catalog_tool.into_value_tool());
agent.add_tool(search_files_tool.get_name(), search_files_tool.into_value_tool());
agent.add_tool(modify_files_tool.get_name(), modify_files_tool.into_value_tool());
agent.add_tool(create_files_tool.get_name(), create_files_tool.into_value_tool());
agent.add_tool(open_files_tool.get_name(), open_files_tool.into_value_tool());
agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool.into_value_tool());
Ok(Self { agent })
}

View File

@ -6,6 +6,7 @@ use anyhow::Result;
use serde_json::Value;
use std::{collections::HashMap, env};
use tokio::sync::mpsc;
use serde::Serialize;
use super::types::AgentThread;
@ -44,7 +45,7 @@ impl Agent {
/// # Arguments
/// * `name` - The name of the tool, used to identify it in tool calls
/// * `tool` - The tool implementation that will be executed
pub fn add_tool<T: ToolExecutor<Output = Value> + 'static>(&mut self, name: String, tool: T) {
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
self.tools.insert(name, Box::new(tool));
}
@ -52,9 +53,9 @@ impl Agent {
///
/// # Arguments
/// * `tools` - HashMap of tool names and their implementations
pub fn add_tools<T: ToolExecutor<Output = Value> + 'static>(
pub fn add_tools<E: ToolExecutor<Output = Value> + 'static>(
&mut self,
tools: HashMap<String, T>,
tools: HashMap<String, E>,
) {
for (name, tool) in tools {
self.tools.insert(name, Box::new(tool));

View File

@ -17,7 +17,10 @@ use crate::{
utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor},
};
use super::{file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml}, FileModificationTool};
use super::{
file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml},
FileModificationTool,
};
#[derive(Debug, Serialize, Deserialize, Clone)]
struct FileParams {
@ -116,7 +119,10 @@ impl ToolExecutor for CreateFilesTool {
if let Some(dashboard_id) = &dashboard_yml.id {
let dashboard_file = DashboardFile {
id: dashboard_id.clone(),
name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()),
name: dashboard_yml
.name
.clone()
.unwrap_or_else(|| "New Dashboard".to_string()),
file_name: format!("{}.yml", file.name),
content: serde_json::to_value(dashboard_yml.clone()).unwrap(),
filter: None,
@ -160,11 +166,12 @@ impl ToolExecutor for CreateFilesTool {
created_files.extend(metric_ymls.into_iter().map(FileEnum::Metric));
}
Err(e) => {
failed_files.extend(
metric_records
.iter()
.map(|r| (r.file_name.clone(), format!("Failed to create metric file: {}", e))),
);
failed_files.extend(metric_records.iter().map(|r| {
(
r.file_name.clone(),
format!("Failed to create metric file: {}", e),
)
}));
}
}
}
@ -180,11 +187,12 @@ impl ToolExecutor for CreateFilesTool {
created_files.extend(dashboard_ymls.into_iter().map(FileEnum::Dashboard));
}
Err(e) => {
failed_files.extend(
dashboard_records
.iter()
.map(|r| (r.file_name.clone(), format!("Failed to create dashboard file: {}", e))),
);
failed_files.extend(dashboard_records.iter().map(|r| {
(
r.file_name.clone(),
format!("Failed to create dashboard file: {}", e),
)
}));
}
}
}
@ -256,41 +264,3 @@ impl ToolExecutor for CreateFilesTool {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_files_serialization() {
let tool = CreateFilesTool;
let output = CreateFilesOutput {
message: "Test message".to_string(),
files: vec![],
};
// Use the custom serialization
let result = tool.serialize_output(&output);
assert!(result.is_ok());
}
#[test]
fn test_create_files_with_content() {
let tool = CreateFilesTool;
let yml_content = "name: test\ntype: metric\ndescription: A test metric";
let metric = MetricYml::new(yml_content.to_string()).unwrap();
let output = CreateFilesOutput {
message: "Test message".to_string(),
files: vec![FileEnum::Metric(metric)],
};
// Use the custom serialization
let result = tool.serialize_output(&output).unwrap();
// Verify line numbers were added
assert!(result.contains("1 | name: test"));
assert!(result.contains("2 | type: metric"));
assert!(result.contains("3 | description: A test metric"));
}
}

View File

@ -24,12 +24,35 @@ pub trait ToolExecutor: Send + Sync {
fn get_name(&self) -> String;
}
trait IntoBoxedTool {
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>>;
}
/// A wrapper type that converts any ToolExecutor to one that outputs Value
pub struct ValueToolExecutor<T: ToolExecutor>(T);
impl<T: ToolExecutor<Output = Value> + 'static> IntoBoxedTool for T {
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>> {
Box::new(self)
#[async_trait]
impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
type Output = Value;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let result = self.0.execute(tool_call).await?;
Ok(serde_json::to_value(result)?)
}
fn get_schema(&self) -> Value {
self.0.get_schema()
}
fn get_name(&self) -> String {
self.0.get_name()
}
}
/// Extension trait to add value conversion methods to ToolExecutor
pub trait IntoValueTool {
fn into_value_tool(self) -> ValueToolExecutor<Self> where Self: ToolExecutor + Sized;
}
// Implement IntoValueTool for all types that implement ToolExecutor
impl<T: ToolExecutor> IntoValueTool for T {
fn into_value_tool(self) -> ValueToolExecutor<Self> {
ValueToolExecutor(self)
}
}