mirror of https://github.com/buster-so/buster.git
refactor(tools): Enhance ToolExecutor trait and file-related tools
- Add generic Output type to ToolExecutor trait - Update file tools to use strongly-typed output structs - Modify agent and tool implementations to support generic output - Improve error handling and result reporting in file-related tools - Add more detailed status messages for file operations
This commit is contained in:
parent
4ec6e78648
commit
711bbe899a
|
@ -3,9 +3,9 @@ use crate::utils::{
|
|||
tools::ToolExecutor,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, env};
|
||||
use tokio::sync::mpsc;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::types::AgentThread;
|
||||
|
||||
|
@ -16,14 +16,14 @@ pub struct Agent {
|
|||
/// Client for communicating with the LLM provider
|
||||
llm_client: LiteLLMClient,
|
||||
/// Registry of available tools, mapped by their names
|
||||
tools: HashMap<String, Box<dyn ToolExecutor>>,
|
||||
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>,
|
||||
/// The model identifier to use (e.g., "gpt-4")
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
/// Create a new Agent instance with a specific LLM client and model
|
||||
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor>>) -> Self {
|
||||
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>) -> Self {
|
||||
let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
|
||||
let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
|
||||
|
||||
|
@ -41,7 +41,7 @@ impl Agent {
|
|||
/// # Arguments
|
||||
/// * `name` - The name of the tool, used to identify it in tool calls
|
||||
/// * `tool` - The tool implementation that will be executed
|
||||
pub fn add_tool<T: ToolExecutor + 'static>(&mut self, name: String, tool: T) {
|
||||
pub fn add_tool<T: ToolExecutor<Output = Value> + 'static>(&mut self, name: String, tool: T) {
|
||||
self.tools.insert(name, Box::new(tool));
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ impl Agent {
|
|||
///
|
||||
/// # Arguments
|
||||
/// * `tools` - HashMap of tool names and their implementations
|
||||
pub fn add_tools<T: ToolExecutor + 'static>(&mut self, tools: HashMap<String, T>) {
|
||||
pub fn add_tools<T: ToolExecutor<Output = Value> + 'static>(&mut self, tools: HashMap<String, T>) {
|
||||
for (name, tool) in tools {
|
||||
self.tools.insert(name, Box::new(tool));
|
||||
}
|
||||
|
@ -204,6 +204,7 @@ mod tests {
|
|||
use crate::utils::clients::ai::litellm::ToolCall;
|
||||
|
||||
use super::*;
|
||||
use axum::async_trait;
|
||||
use dotenv::dotenv;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
|
@ -215,7 +216,9 @@ mod tests {
|
|||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for WeatherTool {
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
type Output = Value;
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
Ok(json!({
|
||||
"temperature": 20,
|
||||
"unit": "fahrenheit"
|
||||
|
|
|
@ -22,18 +22,29 @@ struct BulkModifyFilesParams {
|
|||
files_with_modifications: Vec<FileModification>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct BulkModifyFilesOutput {
|
||||
success: bool,
|
||||
}
|
||||
|
||||
pub struct BulkModifyFilesTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for BulkModifyFilesTool {
|
||||
type Output = BulkModifyFilesOutput;
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"bulk_modify_files".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let params: BulkModifyFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual file modification logic
|
||||
Ok(Value::Array(vec![]))
|
||||
let output = BulkModifyFilesOutput {
|
||||
success: true,
|
||||
};
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
|
|
@ -17,9 +17,9 @@ use crate::{
|
|||
utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor},
|
||||
};
|
||||
|
||||
use super::file_types::{dashboard_yml::DashboardYml, metric_yml::MetricYml};
|
||||
use super::file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct FileParams {
|
||||
name: String,
|
||||
file_type: String,
|
||||
|
@ -31,15 +31,24 @@ struct CreateFilesParams {
|
|||
files: Vec<FileParams>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CreateFilesOutput {
|
||||
success: bool,
|
||||
message: String,
|
||||
files: Vec<FileEnum>,
|
||||
}
|
||||
|
||||
pub struct CreateFilesTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for CreateFilesTool {
|
||||
type Output = CreateFilesOutput;
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"create_files".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let params: CreateFilesParams =
|
||||
match serde_json::from_str(&tool_call.function.arguments.clone()) {
|
||||
Ok(params) => params,
|
||||
|
@ -53,15 +62,64 @@ impl ToolExecutor for CreateFilesTool {
|
|||
|
||||
let files = params.files;
|
||||
|
||||
let mut created_files = vec![];
|
||||
|
||||
let mut failed_files = vec![];
|
||||
|
||||
for file in files {
|
||||
match file.file_type.as_str() {
|
||||
"metric" => create_metric_file(file).await?,
|
||||
"dashboard" => create_dashboard_file(file).await?,
|
||||
_ => return Err(anyhow::anyhow!("Invalid file type: {}. Currently only `metric` and `dashboard` types are supported.", file.file_type)),
|
||||
}
|
||||
let created_file = match file.file_type.as_str() {
|
||||
"metric" => match create_metric_file(file.clone()).await {
|
||||
Ok(f) => {
|
||||
created_files.push(f);
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
failed_files.push((file.name, e.to_string()));
|
||||
continue;
|
||||
}
|
||||
},
|
||||
"dashboard" => match create_dashboard_file(file.clone()).await {
|
||||
Ok(f) => {
|
||||
created_files.push(f);
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
failed_files.push((file.name, e.to_string()));
|
||||
continue;
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
failed_files.push((file.name, format!("Invalid file type: {}. Currently only `metric` and `dashboard` types are supported.", file.file_type)));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(Value::Array(vec![]))
|
||||
let message = if failed_files.is_empty() {
|
||||
format!("Successfully created {} files.", created_files.len())
|
||||
} else {
|
||||
let success_msg = if !created_files.is_empty() {
|
||||
format!("Successfully created {} files. ", created_files.len())
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let failures: Vec<String> = failed_files
|
||||
.iter()
|
||||
.map(|(name, error)| format!("Failed to create '{}': {}", name, error))
|
||||
.collect();
|
||||
|
||||
format!("{}Failed to create {} files: {}",
|
||||
success_msg,
|
||||
failed_files.len(),
|
||||
failures.join("; "))
|
||||
};
|
||||
|
||||
Ok(CreateFilesOutput {
|
||||
success: !created_files.is_empty(),
|
||||
message,
|
||||
files: created_files,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
@ -104,7 +162,7 @@ impl ToolExecutor for CreateFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn create_metric_file(file: FileParams) -> Result<()> {
|
||||
async fn create_metric_file(file: FileParams) -> Result<FileEnum> {
|
||||
let metric_yml = match MetricYml::new(file.yml_content) {
|
||||
Ok(metric_file) => metric_file,
|
||||
Err(e) => return Err(e),
|
||||
|
@ -128,7 +186,7 @@ async fn create_metric_file(file: FileParams) -> Result<()> {
|
|||
id: metric_id.clone(),
|
||||
name: metric_yml.title.clone(),
|
||||
file_name: format!("{}.yml", file.name),
|
||||
content: serde_json::to_value(metric_yml).unwrap(),
|
||||
content: serde_json::to_value(metric_yml.clone()).unwrap(),
|
||||
created_by: Uuid::new_v4(),
|
||||
verification: Verification::NotRequested,
|
||||
evaluation_obj: None,
|
||||
|
@ -140,9 +198,8 @@ async fn create_metric_file(file: FileParams) -> Result<()> {
|
|||
deleted_at: None,
|
||||
};
|
||||
|
||||
let metric_file_record = match insert_into(metric_files::table)
|
||||
.values(metric_file_record)
|
||||
.returning(metric_files::all_columns)
|
||||
match insert_into(metric_files::table)
|
||||
.values(&metric_file_record)
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
{
|
||||
|
@ -150,10 +207,10 @@ async fn create_metric_file(file: FileParams) -> Result<()> {
|
|||
Err(e) => return Err(anyhow::anyhow!("Failed to create metric file: {}", e)),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
Ok(FileEnum::Metric(metric_yml))
|
||||
}
|
||||
|
||||
async fn create_dashboard_file(file: FileParams) -> Result<()> {
|
||||
async fn create_dashboard_file(file: FileParams) -> Result<FileEnum> {
|
||||
let dashboard_yml = match DashboardYml::new(file.yml_content) {
|
||||
Ok(dashboard_file) => dashboard_file,
|
||||
Err(e) => return Err(e),
|
||||
|
@ -175,9 +232,12 @@ async fn create_dashboard_file(file: FileParams) -> Result<()> {
|
|||
|
||||
let dashboard_file_record = DashboardFile {
|
||||
id: dashboard_id.clone(),
|
||||
name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()),
|
||||
name: dashboard_yml
|
||||
.name
|
||||
.clone()
|
||||
.unwrap_or_else(|| "New Dashboard".to_string()),
|
||||
file_name: format!("{}.yml", file.name),
|
||||
content: serde_json::to_value(dashboard_yml).unwrap(),
|
||||
content: serde_json::to_value(dashboard_yml.clone()).unwrap(),
|
||||
filter: None,
|
||||
organization_id: Uuid::new_v4(),
|
||||
created_by: Uuid::new_v4(),
|
||||
|
@ -187,12 +247,14 @@ async fn create_dashboard_file(file: FileParams) -> Result<()> {
|
|||
};
|
||||
|
||||
match insert_into(dashboard_files::table)
|
||||
.values(dashboard_file_record)
|
||||
.values(&dashboard_file_record)
|
||||
.returning(dashboard_files::all_columns)
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => Err(anyhow::anyhow!(e)),
|
||||
}
|
||||
Ok(_) => (),
|
||||
Err(e) => return Err(anyhow::anyhow!(e)),
|
||||
};
|
||||
|
||||
Ok(FileEnum::Dashboard(dashboard_yml))
|
||||
}
|
||||
|
|
|
@ -1,22 +1,24 @@
|
|||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct DashboardYml {
|
||||
pub id: Option<Uuid>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub name: Option<String>,
|
||||
pub rows: Vec<Row>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Row {
|
||||
items: Vec<RowItem>, // max number of items in a row is 4, min is 1
|
||||
row_height: u32, // max is 550, min is 320
|
||||
column_sizes: Vec<u32>, // max sum of elements is 12 min is 3
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RowItem {
|
||||
// This id is the id of the metric or item reference that goes here in the dashboard.
|
||||
id: Uuid,
|
||||
|
@ -40,6 +42,8 @@ impl DashboardYml {
|
|||
file.name = Some(String::from("New Dashboard"));
|
||||
}
|
||||
|
||||
file.updated_at = Some(Utc::now());
|
||||
|
||||
// Validate the file
|
||||
match file.validate() {
|
||||
Ok(_) => Ok(file),
|
||||
|
@ -47,6 +51,7 @@ impl DashboardYml {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: The validate of the dashboard should also be whether metrics exist?
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
// Validate the id and name
|
||||
if self.id.is_none() {
|
||||
|
@ -57,6 +62,10 @@ impl DashboardYml {
|
|||
return Err(anyhow::anyhow!("Dashboard file name is required"));
|
||||
}
|
||||
|
||||
if self.updated_at.is_none() {
|
||||
return Err(anyhow::anyhow!("Dashboard file updated_at is required"));
|
||||
}
|
||||
|
||||
// Validate each row
|
||||
for row in &self.rows {
|
||||
// Check row height constraints
|
||||
|
|
|
@ -1,10 +1,61 @@
|
|||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::dashboard_yml::DashboardYml;
|
||||
use super::metric_yml::MetricYml;
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct File {
|
||||
pub name: String,
|
||||
pub file_type: String,
|
||||
pub yml_content: String,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum FileEnum {
|
||||
Metric(MetricYml),
|
||||
Dashboard(DashboardYml),
|
||||
}
|
||||
|
||||
impl FileEnum {
|
||||
pub fn name(&self) -> anyhow::Result<String> {
|
||||
match self {
|
||||
Self::Metric(metric) => Ok(metric.title.clone()),
|
||||
Self::Dashboard(dashboard) => match &dashboard.name {
|
||||
Some(name) => Ok(name.clone()),
|
||||
None => Err(anyhow::anyhow!("Dashboard name is required but not found")),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> anyhow::Result<Uuid> {
|
||||
match self {
|
||||
Self::Metric(metric) => match metric.id {
|
||||
Some(id) => Ok(id),
|
||||
None => Err(anyhow::anyhow!("Metric id is required but not found")),
|
||||
},
|
||||
Self::Dashboard(dashboard) => match dashboard.id {
|
||||
Some(id) => Ok(id),
|
||||
None => Err(anyhow::anyhow!("Dashboard id is required but not found")),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn updated_at(&self) -> anyhow::Result<DateTime<Utc>> {
|
||||
match self {
|
||||
Self::Metric(metric) => match metric.updated_at {
|
||||
Some(dt) => Ok(dt),
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Metric updated_at is required but not found"
|
||||
)),
|
||||
},
|
||||
Self::Dashboard(dashboard) => match dashboard.updated_at {
|
||||
Some(dt) => Ok(dt),
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Dashboard updated_at is required but not found"
|
||||
)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct MetricYml {
|
||||
pub id: Option<Uuid>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
pub title: String,
|
||||
pub description: Option<String>,
|
||||
pub sql: String,
|
||||
|
@ -12,13 +14,13 @@ pub struct MetricYml {
|
|||
pub data_metadata: Vec<DataMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct DataMetadata {
|
||||
pub name: String,
|
||||
pub data_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(tag = "selectedChartType")]
|
||||
pub enum ChartConfig {
|
||||
#[serde(rename = "bar")]
|
||||
|
@ -38,7 +40,7 @@ pub enum ChartConfig {
|
|||
}
|
||||
|
||||
// Base chart config shared by all chart types
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct BaseChartConfig {
|
||||
pub column_label_formats: std::collections::HashMap<String, ColumnLabelFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
@ -67,7 +69,7 @@ pub struct BaseChartConfig {
|
|||
pub y2_axis_config: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ColumnLabelFormat {
|
||||
pub column_type: String,
|
||||
pub style: String,
|
||||
|
@ -101,7 +103,7 @@ pub struct ColumnLabelFormat {
|
|||
pub convert_number_to: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ColumnSettings {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub show_data_labels: Option<bool>,
|
||||
|
@ -123,7 +125,7 @@ pub struct ColumnSettings {
|
|||
pub line_symbol_size_dot: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct GoalLine {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub show: Option<bool>,
|
||||
|
@ -137,7 +139,7 @@ pub struct GoalLine {
|
|||
pub goal_line_color: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Trendline {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub show: Option<bool>,
|
||||
|
@ -151,7 +153,7 @@ pub struct Trendline {
|
|||
pub trend_line_color: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct BarLineChartConfig {
|
||||
#[serde(flatten)]
|
||||
pub base: BaseChartConfig,
|
||||
|
@ -168,7 +170,7 @@ pub struct BarLineChartConfig {
|
|||
pub line_group_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct BarAndLineAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
|
@ -177,7 +179,7 @@ pub struct BarAndLineAxis {
|
|||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ScatterChartConfig {
|
||||
#[serde(flatten)]
|
||||
pub base: BaseChartConfig,
|
||||
|
@ -186,7 +188,7 @@ pub struct ScatterChartConfig {
|
|||
pub scatter_dot_size: Option<Vec<f64>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ScatterAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
|
@ -198,7 +200,7 @@ pub struct ScatterAxis {
|
|||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct PieChartConfig {
|
||||
#[serde(flatten)]
|
||||
pub base: BaseChartConfig,
|
||||
|
@ -219,7 +221,7 @@ pub struct PieChartConfig {
|
|||
pub pie_minimum_slice_percentage: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct PieChartAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
|
@ -227,14 +229,14 @@ pub struct PieChartAxis {
|
|||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ComboChartConfig {
|
||||
#[serde(flatten)]
|
||||
pub base: BaseChartConfig,
|
||||
pub combo_chart_axis: ComboChartAxis,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ComboChartAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
|
@ -246,7 +248,7 @@ pub struct ComboChartAxis {
|
|||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct MetricChartConfig {
|
||||
#[serde(flatten)]
|
||||
pub base: BaseChartConfig,
|
||||
|
@ -261,7 +263,7 @@ pub struct MetricChartConfig {
|
|||
pub metric_value_label: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct TableChartConfig {
|
||||
#[serde(flatten)]
|
||||
pub base: BaseChartConfig,
|
||||
|
@ -288,12 +290,15 @@ impl MetricYml {
|
|||
file.id = Some(Uuid::new_v4());
|
||||
}
|
||||
|
||||
file.updated_at = Some(Utc::now());
|
||||
|
||||
match file.validate() {
|
||||
Ok(_) => Ok(file),
|
||||
Err(e) => Err(anyhow::anyhow!("Error compiling file: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: Need to validate a metric deeply.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,49 +1,558 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
|
||||
use crate::{
|
||||
database::{
|
||||
lib::get_pg_pool,
|
||||
models::{DashboardFile, MetricFile},
|
||||
schema::{dashboard_files, metric_files},
|
||||
},
|
||||
utils::{
|
||||
clients::ai::litellm::ToolCall,
|
||||
tools::file_tools::file_types::{
|
||||
dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml,
|
||||
},
|
||||
tools::ToolExecutor,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FileRequest {
|
||||
id: String,
|
||||
file_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct OpenFilesParams {
|
||||
file_names: Vec<String>,
|
||||
files: Vec<FileRequest>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct OpenFilesOutput {
|
||||
message: String,
|
||||
results: Vec<FileEnum>,
|
||||
}
|
||||
|
||||
pub struct OpenFilesTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for OpenFilesTool {
|
||||
type Output = OpenFilesOutput;
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"open_files".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
let params: OpenFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual file opening logic
|
||||
Ok(Value::Array(vec![]))
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
debug!("Starting file open operation");
|
||||
let params: OpenFilesParams =
|
||||
serde_json::from_str(&tool_call.function.arguments.clone())
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "Failed to parse tool parameters");
|
||||
anyhow::anyhow!("Failed to parse tool parameters: {}", e)
|
||||
})?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
let mut error_messages = Vec::new();
|
||||
|
||||
// Track requested IDs by type for later comparison
|
||||
let mut requested_ids: HashMap<String, HashSet<Uuid>> = HashMap::new();
|
||||
let mut found_ids: HashMap<String, HashSet<Uuid>> = HashMap::new();
|
||||
|
||||
// Group requests by file type and track requested IDs
|
||||
let grouped_requests = params
|
||||
.files
|
||||
.into_iter()
|
||||
.filter_map(|req| match Uuid::parse_str(&req.id) {
|
||||
Ok(id) => {
|
||||
requested_ids
|
||||
.entry(req.file_type.clone())
|
||||
.or_default()
|
||||
.insert(id);
|
||||
Some((req.file_type, id))
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(invalid_id = %req.id, "Invalid UUID format");
|
||||
error_messages.push(format!("Invalid UUID format for id: {}", req.id));
|
||||
None
|
||||
}
|
||||
})
|
||||
.fold(HashMap::new(), |mut acc, (file_type, id)| {
|
||||
acc.entry(file_type).or_insert_with(Vec::new).push(id);
|
||||
acc
|
||||
});
|
||||
|
||||
// Process dashboard files
|
||||
if let Some(dashboard_ids) = grouped_requests.get("dashboard") {
|
||||
match get_dashboard_files(dashboard_ids).await {
|
||||
Ok(dashboard_files) => {
|
||||
for (dashboard_yml, id, _) in dashboard_files {
|
||||
found_ids
|
||||
.entry("dashboard".to_string())
|
||||
.or_default()
|
||||
.insert(id);
|
||||
results.push(FileEnum::Dashboard(dashboard_yml));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Failed to process dashboard files");
|
||||
error_messages.push(format!("Error processing dashboard files: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process metric files
|
||||
if let Some(metric_ids) = grouped_requests.get("metric") {
|
||||
match get_metric_files(metric_ids).await {
|
||||
Ok(metric_files) => {
|
||||
for (metric_yml, id, _) in metric_files {
|
||||
found_ids
|
||||
.entry("metric".to_string())
|
||||
.or_default()
|
||||
.insert(id);
|
||||
results.push(FileEnum::Metric(metric_yml));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Failed to process metric files");
|
||||
error_messages.push(format!("Error processing metric files: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate message about missing files
|
||||
let mut missing_files = Vec::new();
|
||||
for (file_type, ids) in requested_ids.iter() {
|
||||
let found = found_ids.get(file_type).cloned().unwrap_or_default();
|
||||
let missing: Vec<_> = ids.difference(&found).collect();
|
||||
if !missing.is_empty() {
|
||||
warn!(
|
||||
file_type = %file_type,
|
||||
missing_count = missing.len(),
|
||||
missing_ids = ?missing,
|
||||
"Files not found"
|
||||
);
|
||||
missing_files.push(format!(
|
||||
"{} {}s: {}",
|
||||
missing.len(),
|
||||
file_type,
|
||||
missing
|
||||
.iter()
|
||||
.map(|id| id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let message = build_status_message(&results, &missing_files, &error_messages);
|
||||
info!(
|
||||
total_requested = requested_ids.values().map(|ids| ids.len()).sum::<usize>(),
|
||||
total_found = results.len(),
|
||||
error_count = error_messages.len(),
|
||||
"Completed file open operation"
|
||||
);
|
||||
|
||||
Ok(OpenFilesOutput { message, results })
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"name": "open_files",
|
||||
"strict": true,
|
||||
"description": "Opens one or more dashboard or metric files and returns their YML contents",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["file_names"],
|
||||
"required": ["files"],
|
||||
"properties": {
|
||||
"file_names": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"description": "The name of a file to be opened"
|
||||
"type": "object",
|
||||
"required": ["id", "file_type"],
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the file"
|
||||
},
|
||||
"file_type": {
|
||||
"type": "string",
|
||||
"enum": ["dashboard", "metric"],
|
||||
"description": "The type of file to read"
|
||||
}
|
||||
}
|
||||
},
|
||||
"description": "List of file names to be opened"
|
||||
"description": "List of files to be opened"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"description": "Opens one or more files in read mode and displays **their entire contents** to the user. If you use this, the user will actually see the metric/dashboard you open."
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_dashboard_files(ids: &[Uuid]) -> Result<Vec<(DashboardYml, Uuid, String)>> {
|
||||
debug!(dashboard_ids = ?ids, "Fetching dashboard files");
|
||||
let mut conn = get_pg_pool()
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "Failed to get database connection");
|
||||
anyhow::anyhow!("Failed to get database connection: {}", e)
|
||||
})?;
|
||||
|
||||
let files = match dashboard_files::table
|
||||
.filter(dashboard_files::id.eq_any(ids))
|
||||
.filter(dashboard_files::deleted_at.is_null())
|
||||
.load::<DashboardFile>(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(files) => {
|
||||
debug!(count = files.len(), "Successfully loaded dashboard files from database");
|
||||
files
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Failed to load dashboard files from database");
|
||||
return Err(anyhow::anyhow!(
|
||||
"Error loading dashboard files from database: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
for file in files {
|
||||
match serde_json::from_value(file.content.clone()) {
|
||||
Ok(dashboard_yml) => {
|
||||
debug!(dashboard_id = %file.id, "Successfully parsed dashboard YAML");
|
||||
results.push((dashboard_yml, file.id, file.updated_at.to_string()));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
error = %e,
|
||||
dashboard_id = %file.id,
|
||||
"Failed to parse dashboard YAML"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
requested_count = ids.len(),
|
||||
found_count = results.len(),
|
||||
"Completed dashboard files retrieval"
|
||||
);
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn get_metric_files(ids: &[Uuid]) -> Result<Vec<(MetricYml, Uuid, String)>> {
|
||||
debug!(metric_ids = ?ids, "Fetching metric files");
|
||||
let mut conn = get_pg_pool()
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "Failed to get database connection");
|
||||
anyhow::anyhow!("Failed to get database connection: {}", e)
|
||||
})?;
|
||||
|
||||
let files = match metric_files::table
|
||||
.filter(metric_files::id.eq_any(ids))
|
||||
.filter(metric_files::deleted_at.is_null())
|
||||
.load::<MetricFile>(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(files) => {
|
||||
debug!(count = files.len(), "Successfully loaded metric files from database");
|
||||
files
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Failed to load metric files from database");
|
||||
return Err(anyhow::anyhow!(
|
||||
"Error loading metric files from database: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
for file in files {
|
||||
match serde_json::from_value(file.content.clone()) {
|
||||
Ok(metric_yml) => {
|
||||
debug!(metric_id = %file.id, "Successfully parsed metric YAML");
|
||||
results.push((metric_yml, file.id, file.updated_at.to_string()));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
error = %e,
|
||||
metric_id = %file.id,
|
||||
"Failed to parse metric YAML"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
requested_count = ids.len(),
|
||||
found_count = results.len(),
|
||||
"Completed metric files retrieval"
|
||||
);
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn build_status_message(
|
||||
results: &[FileEnum],
|
||||
missing_files: &[String],
|
||||
error_messages: &[String],
|
||||
) -> String {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
// Add success message if any files were found
|
||||
if !results.is_empty() {
|
||||
parts.push(format!("Successfully opened {} files", results.len()));
|
||||
}
|
||||
|
||||
// Add missing files information
|
||||
if !missing_files.is_empty() {
|
||||
parts.push(format!(
|
||||
"Could not find the following files: {}",
|
||||
missing_files.join("; ")
|
||||
));
|
||||
}
|
||||
|
||||
// Add any error messages
|
||||
if !error_messages.is_empty() {
|
||||
parts.push(format!(
|
||||
"Encountered the following issues: {}",
|
||||
error_messages.join("; ")
|
||||
));
|
||||
}
|
||||
|
||||
// If everything is empty, provide a clear message
|
||||
if parts.is_empty() {
|
||||
"No files were processed due to invalid input".to_string()
|
||||
} else {
|
||||
parts.join(". ")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::utils::tools::file_tools::file_types::metric_yml::{BarLineChartConfig, BaseChartConfig, BarAndLineAxis, ChartConfig, DataMetadata};
|
||||
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use serde_json::json;
|
||||
|
||||
fn create_test_dashboard() -> DashboardYml {
|
||||
DashboardYml {
|
||||
id: Some(Uuid::new_v4()),
|
||||
updated_at: Some(Utc::now()),
|
||||
name: Some("Test Dashboard".to_string()),
|
||||
rows: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_metric() -> MetricYml {
|
||||
MetricYml {
|
||||
id: Some(Uuid::new_v4()),
|
||||
updated_at: Some(Utc::now()),
|
||||
title: "Test Metric".to_string(),
|
||||
description: Some("Test Description".to_string()),
|
||||
sql: "SELECT * FROM test_table".to_string(),
|
||||
chart_config: ChartConfig::Bar(BarLineChartConfig {
|
||||
base: BaseChartConfig {
|
||||
column_label_formats: HashMap::new(),
|
||||
column_settings: None,
|
||||
colors: None,
|
||||
show_legend: None,
|
||||
grid_lines: None,
|
||||
show_legend_headline: None,
|
||||
goal_lines: None,
|
||||
trendlines: None,
|
||||
disable_tooltip: None,
|
||||
y_axis_config: None,
|
||||
x_axis_config: None,
|
||||
category_axis_style_config: None,
|
||||
y2_axis_config: None,
|
||||
},
|
||||
bar_and_line_axis: BarAndLineAxis {
|
||||
x: vec![],
|
||||
y: vec![],
|
||||
category: vec![],
|
||||
tooltip: None,
|
||||
},
|
||||
bar_layout: None,
|
||||
bar_sort_by: None,
|
||||
bar_group_type: None,
|
||||
bar_show_total_at_top: None,
|
||||
line_group_type: None,
|
||||
}),
|
||||
data_metadata: vec![
|
||||
DataMetadata {
|
||||
name: "id".to_string(),
|
||||
data_type: "number".to_string(),
|
||||
},
|
||||
DataMetadata {
|
||||
name: "value".to_string(),
|
||||
data_type: "string".to_string(),
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_status_message_all_success() {
|
||||
let results = vec![
|
||||
FileEnum::Dashboard(create_test_dashboard()),
|
||||
FileEnum::Metric(create_test_metric()),
|
||||
];
|
||||
let missing_files = vec![];
|
||||
let error_messages = vec![];
|
||||
|
||||
let message = build_status_message(&results, &missing_files, &error_messages);
|
||||
assert_eq!(message, "Successfully opened 2 files");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_status_message_with_missing() {
|
||||
let results = vec![
|
||||
FileEnum::Dashboard(create_test_dashboard()),
|
||||
FileEnum::Metric(create_test_metric()),
|
||||
];
|
||||
let missing_files = vec![
|
||||
"1 dashboard: abc-123".to_string(),
|
||||
"2 metrics: def-456, ghi-789".to_string(),
|
||||
];
|
||||
let error_messages = vec![];
|
||||
|
||||
let message = build_status_message(&results, &missing_files, &error_messages);
|
||||
assert_eq!(
|
||||
message,
|
||||
"Successfully opened 2 files. Could not find the following files: 1 dashboard: abc-123; 2 metrics: def-456, ghi-789"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_status_message_with_errors() {
|
||||
let results = vec![];
|
||||
let missing_files = vec![];
|
||||
let error_messages = vec![
|
||||
"Invalid UUID format for id: xyz".to_string(),
|
||||
"Error processing metric files: connection failed".to_string(),
|
||||
];
|
||||
|
||||
let message = build_status_message(&results, &missing_files, &error_messages);
|
||||
assert_eq!(
|
||||
message,
|
||||
"Encountered the following issues: Invalid UUID format for id: xyz; Error processing metric files: connection failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_status_message_mixed_results() {
|
||||
let results = vec![FileEnum::Metric(create_test_metric())];
|
||||
let missing_files = vec!["1 dashboard: abc-123".to_string()];
|
||||
let error_messages = vec!["Invalid UUID format for id: xyz".to_string()];
|
||||
|
||||
let message = build_status_message(&results, &missing_files, &error_messages);
|
||||
assert_eq!(
|
||||
message,
|
||||
"Successfully opened 1 files. Could not find the following files: 1 dashboard: abc-123. Encountered the following issues: Invalid UUID format for id: xyz"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_valid_params() {
|
||||
let params_json = json!({
|
||||
"files": [
|
||||
{"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "dashboard"},
|
||||
{"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "metric"}
|
||||
]
|
||||
});
|
||||
|
||||
let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
|
||||
assert_eq!(params.files.len(), 2);
|
||||
assert_eq!(params.files[0].file_type, "dashboard");
|
||||
assert_eq!(params.files[1].file_type, "metric");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_invalid_uuid() {
|
||||
let params_json = json!({
|
||||
"files": [
|
||||
{"id": "not-a-uuid", "file_type": "dashboard"},
|
||||
{"id": "also-not-a-uuid", "file_type": "metric"}
|
||||
]
|
||||
});
|
||||
|
||||
let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
|
||||
for file in ¶ms.files {
|
||||
let uuid_result = Uuid::parse_str(&file.id);
|
||||
assert!(uuid_result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_invalid_file_type() {
|
||||
let params_json = json!({
|
||||
"files": [
|
||||
{"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "invalid"},
|
||||
{"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "unknown"}
|
||||
]
|
||||
});
|
||||
|
||||
let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
|
||||
for file in ¶ms.files {
|
||||
assert!(file.file_type != "dashboard" && file.file_type != "metric");
|
||||
}
|
||||
}
|
||||
|
||||
// Mock tests for file retrieval
|
||||
#[tokio::test]
|
||||
async fn test_get_dashboard_files() {
|
||||
let test_id = Uuid::new_v4();
|
||||
let dashboard = create_test_dashboard();
|
||||
let test_files = vec![DashboardFile {
|
||||
id: test_id,
|
||||
name: dashboard.name.clone().unwrap_or_default(),
|
||||
file_name: "test.yml".to_string(),
|
||||
content: serde_json::to_value(&dashboard).unwrap(),
|
||||
filter: None,
|
||||
organization_id: Uuid::new_v4(),
|
||||
created_by: Uuid::new_v4(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
}];
|
||||
|
||||
// TODO: Mock database connection and return test_files
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_metric_files() {
|
||||
let test_id = Uuid::new_v4();
|
||||
let metric = create_test_metric();
|
||||
let test_files = vec![MetricFile {
|
||||
id: test_id,
|
||||
name: metric.title.clone(),
|
||||
file_name: "test.yml".to_string(),
|
||||
content: serde_json::to_value(&metric).unwrap(),
|
||||
verification: crate::database::enums::Verification::NotRequested,
|
||||
evaluation_obj: None,
|
||||
evaluation_summary: None,
|
||||
evaluation_score: None,
|
||||
organization_id: Uuid::new_v4(),
|
||||
created_by: Uuid::new_v4(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
}];
|
||||
|
||||
// TODO: Mock database connection and return test_files
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
|
||||
|
||||
|
@ -22,18 +22,30 @@ struct CatalogSearchResult {
|
|||
metadata: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchDataCatalogOutput {
|
||||
success: bool,
|
||||
results: Vec<CatalogSearchResult>,
|
||||
}
|
||||
|
||||
pub struct SearchDataCatalogTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for SearchDataCatalogTool {
|
||||
type Output = SearchDataCatalogOutput;
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"search_data_catalog".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
let params: SearchDataCatalogParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let params: SearchDataCatalogParams =
|
||||
serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual data catalog search logic
|
||||
Ok(Value::Array(vec![]))
|
||||
Ok(SearchDataCatalogOutput {
|
||||
success: true,
|
||||
results: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
@ -67,4 +79,4 @@ impl ToolExecutor for SearchDataCatalogTool {
|
|||
"description": "Searches the data catalog for relevant items including datasets, metrics, business terms, and logic definitions. Returns structured results with relevance scores. Use this to find data assets and their documentation."
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
|
||||
|
||||
|
@ -10,18 +10,28 @@ struct SearchFilesParams {
|
|||
query_params: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchFilesOutput {
|
||||
success: bool,
|
||||
}
|
||||
|
||||
pub struct SearchFilesTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for SearchFilesTool {
|
||||
type Output = SearchFilesOutput;
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"search_files".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
let params: SearchFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let params: SearchFilesParams =
|
||||
serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual file search logic
|
||||
Ok(Value::Array(vec![]))
|
||||
Ok(SearchFilesOutput {
|
||||
success: true,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
@ -46,4 +56,4 @@ impl ToolExecutor for SearchFilesTool {
|
|||
"description": "Searches for metric and dashboard files using natural-language queries. Typically used if you suspect there might already be a relevant metric or dashboard in the repository. If results are found, you can then decide whether to open them with `open_files`."
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,18 +10,25 @@ struct SendToUserParams {
|
|||
metric_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SendToUserOutput {
|
||||
success: bool,
|
||||
}
|
||||
|
||||
pub struct SendToUserTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for SendToUserTool {
|
||||
type Output = SendToUserOutput;
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"send_to_user".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let params: SendToUserParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual send to user logic
|
||||
Ok(Value::Array(vec![]))
|
||||
Ok(SendToUserOutput { success: true })
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::utils::clients::ai::litellm::ToolCall;
|
||||
|
@ -10,8 +11,11 @@ pub mod file_tools;
|
|||
/// Any struct that wants to be used as a tool must implement this trait.
|
||||
#[async_trait]
|
||||
pub trait ToolExecutor: Send + Sync {
|
||||
/// The type of the output of the tool
|
||||
type Output: Serialize + Send;
|
||||
|
||||
/// Execute the tool with given arguments and return a result
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value>;
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output>;
|
||||
|
||||
/// Return the JSON schema that describes this tool's interface
|
||||
fn get_schema(&self) -> serde_json::Value;
|
||||
|
@ -21,11 +25,11 @@ pub trait ToolExecutor: Send + Sync {
|
|||
}
|
||||
|
||||
trait IntoBoxedTool {
|
||||
fn boxed(self) -> Box<dyn ToolExecutor>;
|
||||
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>>;
|
||||
}
|
||||
|
||||
impl<T: ToolExecutor + 'static> IntoBoxedTool for T {
|
||||
fn boxed(self) -> Box<dyn ToolExecutor> {
|
||||
impl<T: ToolExecutor<Output = Value> + 'static> IntoBoxedTool for T {
|
||||
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>> {
|
||||
Box::new(self)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue