mirror of https://github.com/buster-so/buster.git
refactor(tools): Implement ValueToolExecutor for generic tool output conversion
- Add `ValueToolExecutor` to convert tool outputs to `serde_json::Value` - Introduce `IntoValueTool` trait for easy value type conversion - Update agent tool addition methods to use new value conversion mechanism - Simplify tool registration by automatically converting tool outputs - Remove previous manual boxing and type conversion logic
This commit is contained in:
parent
2090e0b7d7
commit
a2f3433555
|
@ -13,12 +13,12 @@ use crate::{
|
||||||
ws_utils::send_ws_message,
|
ws_utils::send_ws_message,
|
||||||
},
|
},
|
||||||
utils::{
|
utils::{
|
||||||
agent::Agent,
|
agent::{Agent, AgentThread},
|
||||||
clients::ai::litellm::Message,
|
clients::ai::litellm::Message,
|
||||||
tools::file_tools::{
|
tools::{file_tools::{
|
||||||
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool,
|
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool,
|
||||||
SearchFilesTool, SendToUserTool,
|
SearchFilesTool, SendToUserTool,
|
||||||
},
|
}, ToolExecutor, IntoValueTool},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -50,16 +50,12 @@ impl AgentThreadHandler {
|
||||||
let open_files_tool = OpenFilesTool;
|
let open_files_tool = OpenFilesTool;
|
||||||
let send_to_user_tool = SendToUserTool;
|
let send_to_user_tool = SendToUserTool;
|
||||||
|
|
||||||
// Add each tool individually
|
agent.add_tool(search_data_catalog_tool.get_name(), search_data_catalog_tool.into_value_tool());
|
||||||
agent.add_tool(
|
agent.add_tool(search_files_tool.get_name(), search_files_tool.into_value_tool());
|
||||||
search_data_catalog_tool.get_name(),
|
agent.add_tool(modify_files_tool.get_name(), modify_files_tool.into_value_tool());
|
||||||
search_data_catalog_tool,
|
agent.add_tool(create_files_tool.get_name(), create_files_tool.into_value_tool());
|
||||||
);
|
agent.add_tool(open_files_tool.get_name(), open_files_tool.into_value_tool());
|
||||||
agent.add_tool(search_files_tool.get_name(), search_files_tool);
|
agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool.into_value_tool());
|
||||||
agent.add_tool(modify_files_tool.get_name(), modify_files_tool);
|
|
||||||
agent.add_tool(create_files_tool.get_name(), create_files_tool);
|
|
||||||
agent.add_tool(open_files_tool.get_name(), open_files_tool);
|
|
||||||
agent.add_tool(send_to_user_tool.get_name(), send_to_user_tool);
|
|
||||||
|
|
||||||
Ok(Self { agent })
|
Ok(Self { agent })
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ use anyhow::Result;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::{collections::HashMap, env};
|
use std::{collections::HashMap, env};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
use super::types::AgentThread;
|
use super::types::AgentThread;
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ impl Agent {
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `name` - The name of the tool, used to identify it in tool calls
|
/// * `name` - The name of the tool, used to identify it in tool calls
|
||||||
/// * `tool` - The tool implementation that will be executed
|
/// * `tool` - The tool implementation that will be executed
|
||||||
pub fn add_tool<T: ToolExecutor<Output = Value> + 'static>(&mut self, name: String, tool: T) {
|
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
|
||||||
self.tools.insert(name, Box::new(tool));
|
self.tools.insert(name, Box::new(tool));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,9 +53,9 @@ impl Agent {
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `tools` - HashMap of tool names and their implementations
|
/// * `tools` - HashMap of tool names and their implementations
|
||||||
pub fn add_tools<T: ToolExecutor<Output = Value> + 'static>(
|
pub fn add_tools<E: ToolExecutor<Output = Value> + 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
tools: HashMap<String, T>,
|
tools: HashMap<String, E>,
|
||||||
) {
|
) {
|
||||||
for (name, tool) in tools {
|
for (name, tool) in tools {
|
||||||
self.tools.insert(name, Box::new(tool));
|
self.tools.insert(name, Box::new(tool));
|
||||||
|
|
|
@ -17,7 +17,10 @@ use crate::{
|
||||||
utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor},
|
utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml}, FileModificationTool};
|
use super::{
|
||||||
|
file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml},
|
||||||
|
FileModificationTool,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
struct FileParams {
|
struct FileParams {
|
||||||
|
@ -116,7 +119,10 @@ impl ToolExecutor for CreateFilesTool {
|
||||||
if let Some(dashboard_id) = &dashboard_yml.id {
|
if let Some(dashboard_id) = &dashboard_yml.id {
|
||||||
let dashboard_file = DashboardFile {
|
let dashboard_file = DashboardFile {
|
||||||
id: dashboard_id.clone(),
|
id: dashboard_id.clone(),
|
||||||
name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()),
|
name: dashboard_yml
|
||||||
|
.name
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "New Dashboard".to_string()),
|
||||||
file_name: format!("{}.yml", file.name),
|
file_name: format!("{}.yml", file.name),
|
||||||
content: serde_json::to_value(dashboard_yml.clone()).unwrap(),
|
content: serde_json::to_value(dashboard_yml.clone()).unwrap(),
|
||||||
filter: None,
|
filter: None,
|
||||||
|
@ -160,11 +166,12 @@ impl ToolExecutor for CreateFilesTool {
|
||||||
created_files.extend(metric_ymls.into_iter().map(FileEnum::Metric));
|
created_files.extend(metric_ymls.into_iter().map(FileEnum::Metric));
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
failed_files.extend(
|
failed_files.extend(metric_records.iter().map(|r| {
|
||||||
metric_records
|
(
|
||||||
.iter()
|
r.file_name.clone(),
|
||||||
.map(|r| (r.file_name.clone(), format!("Failed to create metric file: {}", e))),
|
format!("Failed to create metric file: {}", e),
|
||||||
);
|
)
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -180,11 +187,12 @@ impl ToolExecutor for CreateFilesTool {
|
||||||
created_files.extend(dashboard_ymls.into_iter().map(FileEnum::Dashboard));
|
created_files.extend(dashboard_ymls.into_iter().map(FileEnum::Dashboard));
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
failed_files.extend(
|
failed_files.extend(dashboard_records.iter().map(|r| {
|
||||||
dashboard_records
|
(
|
||||||
.iter()
|
r.file_name.clone(),
|
||||||
.map(|r| (r.file_name.clone(), format!("Failed to create dashboard file: {}", e))),
|
format!("Failed to create dashboard file: {}", e),
|
||||||
);
|
)
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -256,41 +264,3 @@ impl ToolExecutor for CreateFilesTool {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_files_serialization() {
|
|
||||||
let tool = CreateFilesTool;
|
|
||||||
let output = CreateFilesOutput {
|
|
||||||
message: "Test message".to_string(),
|
|
||||||
files: vec![],
|
|
||||||
};
|
|
||||||
|
|
||||||
// Use the custom serialization
|
|
||||||
let result = tool.serialize_output(&output);
|
|
||||||
assert!(result.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_files_with_content() {
|
|
||||||
let tool = CreateFilesTool;
|
|
||||||
let yml_content = "name: test\ntype: metric\ndescription: A test metric";
|
|
||||||
let metric = MetricYml::new(yml_content.to_string()).unwrap();
|
|
||||||
|
|
||||||
let output = CreateFilesOutput {
|
|
||||||
message: "Test message".to_string(),
|
|
||||||
files: vec![FileEnum::Metric(metric)],
|
|
||||||
};
|
|
||||||
|
|
||||||
// Use the custom serialization
|
|
||||||
let result = tool.serialize_output(&output).unwrap();
|
|
||||||
|
|
||||||
// Verify line numbers were added
|
|
||||||
assert!(result.contains("1 | name: test"));
|
|
||||||
assert!(result.contains("2 | type: metric"));
|
|
||||||
assert!(result.contains("3 | description: A test metric"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,12 +24,35 @@ pub trait ToolExecutor: Send + Sync {
|
||||||
fn get_name(&self) -> String;
|
fn get_name(&self) -> String;
|
||||||
}
|
}
|
||||||
|
|
||||||
trait IntoBoxedTool {
|
/// A wrapper type that converts any ToolExecutor to one that outputs Value
|
||||||
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>>;
|
pub struct ValueToolExecutor<T: ToolExecutor>(T);
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: ToolExecutor<Output = Value> + 'static> IntoBoxedTool for T {
|
#[async_trait]
|
||||||
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>> {
|
impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
|
||||||
Box::new(self)
|
type Output = Value;
|
||||||
|
|
||||||
|
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||||
|
let result = self.0.execute(tool_call).await?;
|
||||||
|
Ok(serde_json::to_value(result)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_schema(&self) -> Value {
|
||||||
|
self.0.get_schema()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_name(&self) -> String {
|
||||||
|
self.0.get_name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extension trait to add value conversion methods to ToolExecutor
|
||||||
|
pub trait IntoValueTool {
|
||||||
|
fn into_value_tool(self) -> ValueToolExecutor<Self> where Self: ToolExecutor + Sized;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement IntoValueTool for all types that implement ToolExecutor
|
||||||
|
impl<T: ToolExecutor> IntoValueTool for T {
|
||||||
|
fn into_value_tool(self) -> ValueToolExecutor<Self> {
|
||||||
|
ValueToolExecutor(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue