From a2f3433555e753c1ae37d6f43d752e0952bb5cd1 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 7 Feb 2025 10:15:55 -0700 Subject: [PATCH] refactor(tools): Implement ValueToolExecutor for generic tool output conversion - Add `ValueToolExecutor` to convert tool outputs to `serde_json::Value` - Introduce `IntoValueTool` trait for easy value type conversion - Update agent tool addition methods to use new value conversion mechanism - Simplify tool registration by automatically converting tool outputs - Remove previous manual boxing and type conversion logic --- .../post_thread/agent_thread.rs | 22 +++--- api/src/utils/agent/agent.rs | 7 +- .../utils/tools/file_tools/create_files.rs | 70 ++++++------------- api/src/utils/tools/mod.rs | 35 ++++++++-- 4 files changed, 62 insertions(+), 72 deletions(-) diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs index 824f86a6f..334e353c4 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs @@ -13,12 +13,12 @@ use crate::{ ws_utils::send_ws_message, }, utils::{ - agent::Agent, + agent::{Agent, AgentThread}, clients::ai::litellm::Message, - tools::file_tools::{ + tools::{file_tools::{ CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool, SearchFilesTool, SendToUserTool, - }, + }, ToolExecutor, IntoValueTool}, }, }; @@ -50,16 +50,12 @@ impl AgentThreadHandler { let open_files_tool = OpenFilesTool; let send_to_user_tool = SendToUserTool; - // Add each tool individually - agent.add_tool( - search_data_catalog_tool.get_name(), - search_data_catalog_tool, - ); - agent.add_tool(search_files_tool.get_name(), search_files_tool); - agent.add_tool(modify_files_tool.get_name(), modify_files_tool); - agent.add_tool(create_files_tool.get_name(), create_files_tool); - agent.add_tool(open_files_tool.get_name(), open_files_tool); - agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool); + agent.add_tool(search_data_catalog_tool.get_name(), search_data_catalog_tool.into_value_tool()); + agent.add_tool(search_files_tool.get_name(), search_files_tool.into_value_tool()); + agent.add_tool(modify_files_tool.get_name(), modify_files_tool.into_value_tool()); + agent.add_tool(create_files_tool.get_name(), create_files_tool.into_value_tool()); + agent.add_tool(open_files_tool.get_name(), open_files_tool.into_value_tool()); + agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool.into_value_tool()); Ok(Self { agent }) } diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index ad8ac037e..63764121e 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -6,6 +6,7 @@ use anyhow::Result; use serde_json::Value; use std::{collections::HashMap, env}; use tokio::sync::mpsc; +use serde::Serialize; use super::types::AgentThread; @@ -44,7 +45,7 @@ impl Agent { /// # Arguments /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed - pub fn add_tool + 'static>(&mut self, name: String, tool: T) { + pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor + 'static) { self.tools.insert(name, Box::new(tool)); } @@ -52,9 +53,9 @@ impl Agent { /// /// # Arguments /// * `tools` - HashMap of tool names and their implementations - pub fn add_tools + 'static>( + pub fn add_tools + 'static>( &mut self, - tools: HashMap, + tools: HashMap, ) { for (name, tool) in tools { self.tools.insert(name, Box::new(tool)); diff --git a/api/src/utils/tools/file_tools/create_files.rs b/api/src/utils/tools/file_tools/create_files.rs index d3f05d247..0131804d4 100644 --- a/api/src/utils/tools/file_tools/create_files.rs +++ b/api/src/utils/tools/file_tools/create_files.rs @@ -17,7 +17,10 @@ use crate::{ utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}, }; -use super::{file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml}, FileModificationTool}; +use super::{ + file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml}, + FileModificationTool, +}; #[derive(Debug, Serialize, Deserialize, Clone)] struct FileParams { @@ -116,7 +119,10 @@ impl ToolExecutor for CreateFilesTool { if let Some(dashboard_id) = &dashboard_yml.id { let dashboard_file = DashboardFile { id: dashboard_id.clone(), - name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()), + name: dashboard_yml + .name + .clone() + .unwrap_or_else(|| "New Dashboard".to_string()), file_name: format!("{}.yml", file.name), content: serde_json::to_value(dashboard_yml.clone()).unwrap(), filter: None, @@ -160,11 +166,12 @@ impl ToolExecutor for CreateFilesTool { created_files.extend(metric_ymls.into_iter().map(FileEnum::Metric)); } Err(e) => { - failed_files.extend( - metric_records - .iter() - .map(|r| (r.file_name.clone(), format!("Failed to create metric file: {}", e))), - ); + failed_files.extend(metric_records.iter().map(|r| { + ( + r.file_name.clone(), + format!("Failed to create metric file: {}", e), + ) + })); } } } @@ -180,11 +187,12 @@ impl ToolExecutor for CreateFilesTool { created_files.extend(dashboard_ymls.into_iter().map(FileEnum::Dashboard)); } Err(e) => { - failed_files.extend( - dashboard_records - .iter() - .map(|r| (r.file_name.clone(), format!("Failed to create dashboard file: {}", e))), - ); + failed_files.extend(dashboard_records.iter().map(|r| { + ( + r.file_name.clone(), + format!("Failed to create dashboard file: {}", e), + ) + })); } } } @@ -256,41 +264,3 @@ impl ToolExecutor for CreateFilesTool { }) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_create_files_serialization() { - let tool = CreateFilesTool; - let output = CreateFilesOutput { - message: "Test message".to_string(), - files: vec![], - }; - - // Use the custom serialization - let result = tool.serialize_output(&output); - assert!(result.is_ok()); - } - - #[test] - fn test_create_files_with_content() { - let tool = CreateFilesTool; - let yml_content = "name: test\ntype: metric\ndescription: A test metric"; - let metric = MetricYml::new(yml_content.to_string()).unwrap(); - - let output = CreateFilesOutput { - message: "Test message".to_string(), - files: vec![FileEnum::Metric(metric)], - }; - - // Use the custom serialization - let result = tool.serialize_output(&output).unwrap(); - - // Verify line numbers were added - assert!(result.contains("1 | name: test")); - assert!(result.contains("2 | type: metric")); - assert!(result.contains("3 | description: A test metric")); - } -} diff --git a/api/src/utils/tools/mod.rs b/api/src/utils/tools/mod.rs index 5ebdadd1e..5b1092726 100644 --- a/api/src/utils/tools/mod.rs +++ b/api/src/utils/tools/mod.rs @@ -24,12 +24,35 @@ pub trait ToolExecutor: Send + Sync { fn get_name(&self) -> String; } -trait IntoBoxedTool { - fn boxed(self) -> Box>; -} +/// A wrapper type that converts any ToolExecutor to one that outputs Value +pub struct ValueToolExecutor(T); -impl + 'static> IntoBoxedTool for T { - fn boxed(self) -> Box> { - Box::new(self) +#[async_trait] +impl ToolExecutor for ValueToolExecutor { + type Output = Value; + + async fn execute(&self, tool_call: &ToolCall) -> Result { + let result = self.0.execute(tool_call).await?; + Ok(serde_json::to_value(result)?) + } + + fn get_schema(&self) -> Value { + self.0.get_schema() + } + + fn get_name(&self) -> String { + self.0.get_name() + } +} + +/// Extension trait to add value conversion methods to ToolExecutor +pub trait IntoValueTool { + fn into_value_tool(self) -> ValueToolExecutor where Self: ToolExecutor + Sized; +} + +// Implement IntoValueTool for all types that implement ToolExecutor +impl IntoValueTool for T { + fn into_value_tool(self) -> ValueToolExecutor { + ValueToolExecutor(self) } }