From 711bbe899af8c79261f5f30764880d124c95d38d Mon Sep 17 00:00:00 2001 From: dal Date: Thu, 6 Feb 2025 23:45:48 -0700 Subject: [PATCH] refactor(tools): Enhance ToolExecutor trait and file-related tools - Add generic Output type to ToolExecutor trait - Update file tools to use strongly-typed output structs - Modify agent and tool implementations to support generic output - Improve error handling and result reporting in file-related tools - Add more detailed status messages for file operations --- api/src/utils/agent/agent.rs | 15 +- .../tools/file_tools/bulk_modify_files.rs | 15 +- .../utils/tools/file_tools/create_files.rs | 106 +++- .../file_tools/file_types/dashboard_yml.rs | 15 +- .../utils/tools/file_tools/file_types/file.rs | 55 +- .../tools/file_tools/file_types/metric_yml.rs | 41 +- api/src/utils/tools/file_tools/open_files.rs | 545 +++++++++++++++++- .../tools/file_tools/search_data_catalog.rs | 22 +- .../utils/tools/file_tools/search_files.rs | 20 +- .../utils/tools/file_tools/send_to_user.rs | 11 +- api/src/utils/tools/mod.rs | 12 +- 11 files changed, 770 insertions(+), 87 deletions(-) diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 69b664aa5..68e86df45 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -3,9 +3,9 @@ use crate::utils::{ tools::ToolExecutor, }; use anyhow::Result; +use serde_json::Value; use std::{collections::HashMap, env}; use tokio::sync::mpsc; -use async_trait::async_trait; use super::types::AgentThread; @@ -16,14 +16,14 @@ pub struct Agent { /// Client for communicating with the LLM provider llm_client: LiteLLMClient, /// Registry of available tools, mapped by their names - tools: HashMap>, + tools: HashMap>>, /// The model identifier to use (e.g., "gpt-4") model: String, } impl Agent { /// Create a new Agent instance with a specific LLM client and model - pub fn new(model: String, tools: HashMap>) -> Self { + pub fn new(model: String, tools: HashMap>>) -> Self { let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set"); let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set"); @@ -41,7 +41,7 @@ impl Agent { /// # Arguments /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed - pub fn add_tool(&mut self, name: String, tool: T) { + pub fn add_tool + 'static>(&mut self, name: String, tool: T) { self.tools.insert(name, Box::new(tool)); } @@ -49,7 +49,7 @@ impl Agent { /// /// # Arguments /// * `tools` - HashMap of tool names and their implementations - pub fn add_tools(&mut self, tools: HashMap) { + pub fn add_tools + 'static>(&mut self, tools: HashMap) { for (name, tool) in tools { self.tools.insert(name, Box::new(tool)); } @@ -204,6 +204,7 @@ mod tests { use crate::utils::clients::ai::litellm::ToolCall; use super::*; + use axum::async_trait; use dotenv::dotenv; use serde_json::{json, Value}; @@ -215,7 +216,9 @@ mod tests { #[async_trait] impl ToolExecutor for WeatherTool { - async fn execute(&self, tool_call: &ToolCall) -> Result { + type Output = Value; + + async fn execute(&self, tool_call: &ToolCall) -> Result { Ok(json!({ "temperature": 20, "unit": "fahrenheit" diff --git a/api/src/utils/tools/file_tools/bulk_modify_files.rs b/api/src/utils/tools/file_tools/bulk_modify_files.rs index be1cc9f2c..a939f10fb 100644 --- a/api/src/utils/tools/file_tools/bulk_modify_files.rs +++ b/api/src/utils/tools/file_tools/bulk_modify_files.rs @@ -22,18 +22,29 @@ struct BulkModifyFilesParams { files_with_modifications: Vec, } +#[derive(Debug, Serialize)] +pub struct BulkModifyFilesOutput { + success: bool, +} + pub struct BulkModifyFilesTool; #[async_trait] impl ToolExecutor for BulkModifyFilesTool { + type Output = BulkModifyFilesOutput; + fn get_name(&self) -> String { "bulk_modify_files".to_string() } - async fn execute(&self, tool_call: &ToolCall) -> Result { + async fn execute(&self, tool_call: &ToolCall) -> Result { let params: BulkModifyFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?; // TODO: Implement actual file modification logic - Ok(Value::Array(vec![])) + let output = BulkModifyFilesOutput { + success: true, + }; + + Ok(output) } fn get_schema(&self) -> Value { diff --git a/api/src/utils/tools/file_tools/create_files.rs b/api/src/utils/tools/file_tools/create_files.rs index 21e77a7d8..df9d82e13 100644 --- a/api/src/utils/tools/file_tools/create_files.rs +++ b/api/src/utils/tools/file_tools/create_files.rs @@ -17,9 +17,9 @@ use crate::{ utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}, }; -use super::file_types::{dashboard_yml::DashboardYml, metric_yml::MetricYml}; +use super::file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] struct FileParams { name: String, file_type: String, @@ -31,15 +31,24 @@ struct CreateFilesParams { files: Vec, } +#[derive(Debug, Serialize)] +pub struct CreateFilesOutput { + success: bool, + message: String, + files: Vec, +} + pub struct CreateFilesTool; #[async_trait] impl ToolExecutor for CreateFilesTool { + type Output = CreateFilesOutput; + fn get_name(&self) -> String { "create_files".to_string() } - async fn execute(&self, tool_call: &ToolCall) -> Result { + async fn execute(&self, tool_call: &ToolCall) -> Result { let params: CreateFilesParams = match serde_json::from_str(&tool_call.function.arguments.clone()) { Ok(params) => params, @@ -53,15 +62,64 @@ impl ToolExecutor for CreateFilesTool { let files = params.files; + let mut created_files = vec![]; + + let mut failed_files = vec![]; + for file in files { - match file.file_type.as_str() { - "metric" => create_metric_file(file).await?, - "dashboard" => create_dashboard_file(file).await?, - _ => return Err(anyhow::anyhow!("Invalid file type: {}. Currently only `metric` and `dashboard` types are supported.", file.file_type)), - } + let created_file = match file.file_type.as_str() { + "metric" => match create_metric_file(file.clone()).await { + Ok(f) => { + created_files.push(f); + continue; + } + Err(e) => { + failed_files.push((file.name, e.to_string())); + continue; + } + }, + "dashboard" => match create_dashboard_file(file.clone()).await { + Ok(f) => { + created_files.push(f); + continue; + } + Err(e) => { + failed_files.push((file.name, e.to_string())); + continue; + } + }, + _ => { + failed_files.push((file.name, format!("Invalid file type: {}. Currently only `metric` and `dashboard` types are supported.", file.file_type))); + continue; + } + }; } - Ok(Value::Array(vec![])) + let message = if failed_files.is_empty() { + format!("Successfully created {} files.", created_files.len()) + } else { + let success_msg = if !created_files.is_empty() { + format!("Successfully created {} files. ", created_files.len()) + } else { + String::new() + }; + + let failures: Vec = failed_files + .iter() + .map(|(name, error)| format!("Failed to create '{}': {}", name, error)) + .collect(); + + format!("{}Failed to create {} files: {}", + success_msg, + failed_files.len(), + failures.join("; ")) + }; + + Ok(CreateFilesOutput { + success: !created_files.is_empty(), + message, + files: created_files, + }) } fn get_schema(&self) -> Value { @@ -104,7 +162,7 @@ impl ToolExecutor for CreateFilesTool { } } -async fn create_metric_file(file: FileParams) -> Result<()> { +async fn create_metric_file(file: FileParams) -> Result { let metric_yml = match MetricYml::new(file.yml_content) { Ok(metric_file) => metric_file, Err(e) => return Err(e), @@ -128,7 +186,7 @@ async fn create_metric_file(file: FileParams) -> Result<()> { id: metric_id.clone(), name: metric_yml.title.clone(), file_name: format!("{}.yml", file.name), - content: serde_json::to_value(metric_yml).unwrap(), + content: serde_json::to_value(metric_yml.clone()).unwrap(), created_by: Uuid::new_v4(), verification: Verification::NotRequested, evaluation_obj: None, @@ -140,9 +198,8 @@ async fn create_metric_file(file: FileParams) -> Result<()> { deleted_at: None, }; - let metric_file_record = match insert_into(metric_files::table) - .values(metric_file_record) - .returning(metric_files::all_columns) + match insert_into(metric_files::table) + .values(&metric_file_record) .execute(&mut conn) .await { @@ -150,10 +207,10 @@ async fn create_metric_file(file: FileParams) -> Result<()> { Err(e) => return Err(anyhow::anyhow!("Failed to create metric file: {}", e)), }; - Ok(()) + Ok(FileEnum::Metric(metric_yml)) } -async fn create_dashboard_file(file: FileParams) -> Result<()> { +async fn create_dashboard_file(file: FileParams) -> Result { let dashboard_yml = match DashboardYml::new(file.yml_content) { Ok(dashboard_file) => dashboard_file, Err(e) => return Err(e), @@ -175,9 +232,12 @@ async fn create_dashboard_file(file: FileParams) -> Result<()> { let dashboard_file_record = DashboardFile { id: dashboard_id.clone(), - name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()), + name: dashboard_yml + .name + .clone() + .unwrap_or_else(|| "New Dashboard".to_string()), file_name: format!("{}.yml", file.name), - content: serde_json::to_value(dashboard_yml).unwrap(), + content: serde_json::to_value(dashboard_yml.clone()).unwrap(), filter: None, organization_id: Uuid::new_v4(), created_by: Uuid::new_v4(), @@ -187,12 +247,14 @@ async fn create_dashboard_file(file: FileParams) -> Result<()> { }; match insert_into(dashboard_files::table) - .values(dashboard_file_record) + .values(&dashboard_file_record) .returning(dashboard_files::all_columns) .execute(&mut conn) .await { - Ok(_) => Ok(()), - Err(e) => Err(anyhow::anyhow!(e)), - } + Ok(_) => (), + Err(e) => return Err(anyhow::anyhow!(e)), + }; + + Ok(FileEnum::Dashboard(dashboard_yml)) } diff --git a/api/src/utils/tools/file_tools/file_types/dashboard_yml.rs b/api/src/utils/tools/file_tools/file_types/dashboard_yml.rs index 7b549bd77..8d19d70d6 100644 --- a/api/src/utils/tools/file_tools/file_types/dashboard_yml.rs +++ b/api/src/utils/tools/file_tools/file_types/dashboard_yml.rs @@ -1,22 +1,24 @@ use anyhow::Result; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use uuid::Uuid; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct DashboardYml { pub id: Option, + pub updated_at: Option>, pub name: Option, pub rows: Vec, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct Row { items: Vec, // max number of items in a row is 4, min is 1 row_height: u32, // max is 550, min is 320 column_sizes: Vec, // max sum of elements is 12 min is 3 } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct RowItem { // This id is the id of the metric or item reference that goes here in the dashboard. id: Uuid, @@ -40,6 +42,8 @@ impl DashboardYml { file.name = Some(String::from("New Dashboard")); } + file.updated_at = Some(Utc::now()); + // Validate the file match file.validate() { Ok(_) => Ok(file), @@ -47,6 +51,7 @@ impl DashboardYml { } } + // TODO: The validate of the dashboard should also be whether metrics exist? pub fn validate(&self) -> Result<()> { // Validate the id and name if self.id.is_none() { @@ -57,6 +62,10 @@ impl DashboardYml { return Err(anyhow::anyhow!("Dashboard file name is required")); } + if self.updated_at.is_none() { + return Err(anyhow::anyhow!("Dashboard file updated_at is required")); + } + // Validate each row for row in &self.rows { // Check row height constraints diff --git a/api/src/utils/tools/file_tools/file_types/file.rs b/api/src/utils/tools/file_tools/file_types/file.rs index 31027173e..458cefc0e 100644 --- a/api/src/utils/tools/file_tools/file_types/file.rs +++ b/api/src/utils/tools/file_tools/file_types/file.rs @@ -1,10 +1,61 @@ -use anyhow::Result; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use super::dashboard_yml::DashboardYml; +use super::metric_yml::MetricYml; #[derive(Debug, Serialize, Deserialize)] pub struct File { pub name: String, pub file_type: String, pub yml_content: String, -} \ No newline at end of file +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum FileEnum { + Metric(MetricYml), + Dashboard(DashboardYml), +} + +impl FileEnum { + pub fn name(&self) -> anyhow::Result { + match self { + Self::Metric(metric) => Ok(metric.title.clone()), + Self::Dashboard(dashboard) => match &dashboard.name { + Some(name) => Ok(name.clone()), + None => Err(anyhow::anyhow!("Dashboard name is required but not found")), + }, + } + } + + pub fn id(&self) -> anyhow::Result { + match self { + Self::Metric(metric) => match metric.id { + Some(id) => Ok(id), + None => Err(anyhow::anyhow!("Metric id is required but not found")), + }, + Self::Dashboard(dashboard) => match dashboard.id { + Some(id) => Ok(id), + None => Err(anyhow::anyhow!("Dashboard id is required but not found")), + }, + } + } + + pub fn updated_at(&self) -> anyhow::Result> { + match self { + Self::Metric(metric) => match metric.updated_at { + Some(dt) => Ok(dt), + None => Err(anyhow::anyhow!( + "Metric updated_at is required but not found" + )), + }, + Self::Dashboard(dashboard) => match dashboard.updated_at { + Some(dt) => Ok(dt), + None => Err(anyhow::anyhow!( + "Dashboard updated_at is required but not found" + )), + }, + } + } +} diff --git a/api/src/utils/tools/file_tools/file_types/metric_yml.rs b/api/src/utils/tools/file_tools/file_types/metric_yml.rs index 560935724..7841616a1 100644 --- a/api/src/utils/tools/file_tools/file_types/metric_yml.rs +++ b/api/src/utils/tools/file_tools/file_types/metric_yml.rs @@ -1,10 +1,12 @@ use anyhow::Result; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use uuid::Uuid; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct MetricYml { pub id: Option, + pub updated_at: Option>, pub title: String, pub description: Option, pub sql: String, @@ -12,13 +14,13 @@ pub struct MetricYml { pub data_metadata: Vec, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct DataMetadata { pub name: String, pub data_type: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[serde(tag = "selectedChartType")] pub enum ChartConfig { #[serde(rename = "bar")] @@ -38,7 +40,7 @@ pub enum ChartConfig { } // Base chart config shared by all chart types -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BaseChartConfig { pub column_label_formats: std::collections::HashMap, #[serde(skip_serializing_if = "Option::is_none")] @@ -67,7 +69,7 @@ pub struct BaseChartConfig { pub y2_axis_config: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ColumnLabelFormat { pub column_type: String, pub style: String, @@ -101,7 +103,7 @@ pub struct ColumnLabelFormat { pub convert_number_to: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ColumnSettings { #[serde(skip_serializing_if = "Option::is_none")] pub show_data_labels: Option, @@ -123,7 +125,7 @@ pub struct ColumnSettings { pub line_symbol_size_dot: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct GoalLine { #[serde(skip_serializing_if = "Option::is_none")] pub show: Option, @@ -137,7 +139,7 @@ pub struct GoalLine { pub goal_line_color: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct Trendline { #[serde(skip_serializing_if = "Option::is_none")] pub show: Option, @@ -151,7 +153,7 @@ pub struct Trendline { pub trend_line_color: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BarLineChartConfig { #[serde(flatten)] pub base: BaseChartConfig, @@ -168,7 +170,7 @@ pub struct BarLineChartConfig { pub line_group_type: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct BarAndLineAxis { pub x: Vec, pub y: Vec, @@ -177,7 +179,7 @@ pub struct BarAndLineAxis { pub tooltip: Option>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ScatterChartConfig { #[serde(flatten)] pub base: BaseChartConfig, @@ -186,7 +188,7 @@ pub struct ScatterChartConfig { pub scatter_dot_size: Option>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ScatterAxis { pub x: Vec, pub y: Vec, @@ -198,7 +200,7 @@ pub struct ScatterAxis { pub tooltip: Option>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct PieChartConfig { #[serde(flatten)] pub base: BaseChartConfig, @@ -219,7 +221,7 @@ pub struct PieChartConfig { pub pie_minimum_slice_percentage: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct PieChartAxis { pub x: Vec, pub y: Vec, @@ -227,14 +229,14 @@ pub struct PieChartAxis { pub tooltip: Option>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ComboChartConfig { #[serde(flatten)] pub base: BaseChartConfig, pub combo_chart_axis: ComboChartAxis, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ComboChartAxis { pub x: Vec, pub y: Vec, @@ -246,7 +248,7 @@ pub struct ComboChartAxis { pub tooltip: Option>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct MetricChartConfig { #[serde(flatten)] pub base: BaseChartConfig, @@ -261,7 +263,7 @@ pub struct MetricChartConfig { pub metric_value_label: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct TableChartConfig { #[serde(flatten)] pub base: BaseChartConfig, @@ -288,12 +290,15 @@ impl MetricYml { file.id = Some(Uuid::new_v4()); } + file.updated_at = Some(Utc::now()); + match file.validate() { Ok(_) => Ok(file), Err(e) => Err(anyhow::anyhow!("Error compiling file: {}", e)), } } + //TODO: Need to validate a metric deeply. pub fn validate(&self) -> Result<()> { Ok(()) } diff --git a/api/src/utils/tools/file_tools/open_files.rs b/api/src/utils/tools/file_tools/open_files.rs index b416d0248..5b1f098f5 100644 --- a/api/src/utils/tools/file_tools/open_files.rs +++ b/api/src/utils/tools/file_tools/open_files.rs @@ -1,49 +1,558 @@ use anyhow::Result; use async_trait::async_trait; -use serde_json::Value; +use diesel::prelude::*; +use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use tracing::{debug, error, info, warn}; +use uuid::Uuid; -use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}; +use crate::{ + database::{ + lib::get_pg_pool, + models::{DashboardFile, MetricFile}, + schema::{dashboard_files, metric_files}, + }, + utils::{ + clients::ai::litellm::ToolCall, + tools::file_tools::file_types::{ + dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml, + }, + tools::ToolExecutor, + }, +}; + +#[derive(Debug, Serialize, Deserialize)] +struct FileRequest { + id: String, + file_type: String, +} #[derive(Debug, Serialize, Deserialize)] struct OpenFilesParams { - file_names: Vec, + files: Vec, +} + +#[derive(Debug, Serialize)] +pub struct OpenFilesOutput { + message: String, + results: Vec, } pub struct OpenFilesTool; #[async_trait] impl ToolExecutor for OpenFilesTool { + type Output = OpenFilesOutput; + fn get_name(&self) -> String { "open_files".to_string() } - async fn execute(&self, tool_call: &ToolCall) -> Result { - let params: OpenFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?; - // TODO: Implement actual file opening logic - Ok(Value::Array(vec![])) + async fn execute(&self, tool_call: &ToolCall) -> Result { + debug!("Starting file open operation"); + let params: OpenFilesParams = + serde_json::from_str(&tool_call.function.arguments.clone()) + .map_err(|e| { + error!(error = %e, "Failed to parse tool parameters"); + anyhow::anyhow!("Failed to parse tool parameters: {}", e) + })?; + + let mut results = Vec::new(); + let mut error_messages = Vec::new(); + + // Track requested IDs by type for later comparison + let mut requested_ids: HashMap> = HashMap::new(); + let mut found_ids: HashMap> = HashMap::new(); + + // Group requests by file type and track requested IDs + let grouped_requests = params + .files + .into_iter() + .filter_map(|req| match Uuid::parse_str(&req.id) { + Ok(id) => { + requested_ids + .entry(req.file_type.clone()) + .or_default() + .insert(id); + Some((req.file_type, id)) + } + Err(_) => { + warn!(invalid_id = %req.id, "Invalid UUID format"); + error_messages.push(format!("Invalid UUID format for id: {}", req.id)); + None + } + }) + .fold(HashMap::new(), |mut acc, (file_type, id)| { + acc.entry(file_type).or_insert_with(Vec::new).push(id); + acc + }); + + // Process dashboard files + if let Some(dashboard_ids) = grouped_requests.get("dashboard") { + match get_dashboard_files(dashboard_ids).await { + Ok(dashboard_files) => { + for (dashboard_yml, id, _) in dashboard_files { + found_ids + .entry("dashboard".to_string()) + .or_default() + .insert(id); + results.push(FileEnum::Dashboard(dashboard_yml)); + } + } + Err(e) => { + error!(error = %e, "Failed to process dashboard files"); + error_messages.push(format!("Error processing dashboard files: {}", e)); + } + } + } + + // Process metric files + if let Some(metric_ids) = grouped_requests.get("metric") { + match get_metric_files(metric_ids).await { + Ok(metric_files) => { + for (metric_yml, id, _) in metric_files { + found_ids + .entry("metric".to_string()) + .or_default() + .insert(id); + results.push(FileEnum::Metric(metric_yml)); + } + } + Err(e) => { + error!(error = %e, "Failed to process metric files"); + error_messages.push(format!("Error processing metric files: {}", e)); + } + } + } + + // Generate message about missing files + let mut missing_files = Vec::new(); + for (file_type, ids) in requested_ids.iter() { + let found = found_ids.get(file_type).cloned().unwrap_or_default(); + let missing: Vec<_> = ids.difference(&found).collect(); + if !missing.is_empty() { + warn!( + file_type = %file_type, + missing_count = missing.len(), + missing_ids = ?missing, + "Files not found" + ); + missing_files.push(format!( + "{} {}s: {}", + missing.len(), + file_type, + missing + .iter() + .map(|id| id.to_string()) + .collect::>() + .join(", ") + )); + } + } + + let message = build_status_message(&results, &missing_files, &error_messages); + info!( + total_requested = requested_ids.values().map(|ids| ids.len()).sum::(), + total_found = results.len(), + error_count = error_messages.len(), + "Completed file open operation" + ); + + Ok(OpenFilesOutput { message, results }) } fn get_schema(&self) -> Value { serde_json::json!({ "name": "open_files", - "strict": true, + "description": "Opens one or more dashboard or metric files and returns their YML contents", "parameters": { "type": "object", - "required": ["file_names"], + "required": ["files"], "properties": { - "file_names": { + "files": { "type": "array", "items": { - "type": "string", - "description": "The name of a file to be opened" + "type": "object", + "required": ["id", "file_type"], + "properties": { + "id": { + "type": "string", + "description": "The UUID of the file" + }, + "file_type": { + "type": "string", + "enum": ["dashboard", "metric"], + "description": "The type of file to read" + } + } }, - "description": "List of file names to be opened" + "description": "List of files to be opened" } - }, - "additionalProperties": false - }, - "description": "Opens one or more files in read mode and displays **their entire contents** to the user. If you use this, the user will actually see the metric/dashboard you open." + } + } }) } -} \ No newline at end of file +} + +async fn get_dashboard_files(ids: &[Uuid]) -> Result> { + debug!(dashboard_ids = ?ids, "Fetching dashboard files"); + let mut conn = get_pg_pool() + .get() + .await + .map_err(|e| { + error!(error = %e, "Failed to get database connection"); + anyhow::anyhow!("Failed to get database connection: {}", e) + })?; + + let files = match dashboard_files::table + .filter(dashboard_files::id.eq_any(ids)) + .filter(dashboard_files::deleted_at.is_null()) + .load::(&mut conn) + .await + { + Ok(files) => { + debug!(count = files.len(), "Successfully loaded dashboard files from database"); + files + } + Err(e) => { + error!(error = %e, "Failed to load dashboard files from database"); + return Err(anyhow::anyhow!( + "Error loading dashboard files from database: {}", + e + )); + } + }; + + let mut results = Vec::new(); + for file in files { + match serde_json::from_value(file.content.clone()) { + Ok(dashboard_yml) => { + debug!(dashboard_id = %file.id, "Successfully parsed dashboard YAML"); + results.push((dashboard_yml, file.id, file.updated_at.to_string())); + } + Err(e) => { + warn!( + error = %e, + dashboard_id = %file.id, + "Failed to parse dashboard YAML" + ); + } + } + } + + info!( + requested_count = ids.len(), + found_count = results.len(), + "Completed dashboard files retrieval" + ); + Ok(results) +} + +async fn get_metric_files(ids: &[Uuid]) -> Result> { + debug!(metric_ids = ?ids, "Fetching metric files"); + let mut conn = get_pg_pool() + .get() + .await + .map_err(|e| { + error!(error = %e, "Failed to get database connection"); + anyhow::anyhow!("Failed to get database connection: {}", e) + })?; + + let files = match metric_files::table + .filter(metric_files::id.eq_any(ids)) + .filter(metric_files::deleted_at.is_null()) + .load::(&mut conn) + .await + { + Ok(files) => { + debug!(count = files.len(), "Successfully loaded metric files from database"); + files + } + Err(e) => { + error!(error = %e, "Failed to load metric files from database"); + return Err(anyhow::anyhow!( + "Error loading metric files from database: {}", + e + )); + } + }; + + let mut results = Vec::new(); + for file in files { + match serde_json::from_value(file.content.clone()) { + Ok(metric_yml) => { + debug!(metric_id = %file.id, "Successfully parsed metric YAML"); + results.push((metric_yml, file.id, file.updated_at.to_string())); + } + Err(e) => { + warn!( + error = %e, + metric_id = %file.id, + "Failed to parse metric YAML" + ); + } + } + } + + info!( + requested_count = ids.len(), + found_count = results.len(), + "Completed metric files retrieval" + ); + Ok(results) +} + +fn build_status_message( + results: &[FileEnum], + missing_files: &[String], + error_messages: &[String], +) -> String { + let mut parts = Vec::new(); + + // Add success message if any files were found + if !results.is_empty() { + parts.push(format!("Successfully opened {} files", results.len())); + } + + // Add missing files information + if !missing_files.is_empty() { + parts.push(format!( + "Could not find the following files: {}", + missing_files.join("; ") + )); + } + + // Add any error messages + if !error_messages.is_empty() { + parts.push(format!( + "Encountered the following issues: {}", + error_messages.join("; ") + )); + } + + // If everything is empty, provide a clear message + if parts.is_empty() { + "No files were processed due to invalid input".to_string() + } else { + parts.join(". ") + } +} + +#[cfg(test)] +mod tests { + use crate::utils::tools::file_tools::file_types::metric_yml::{BarLineChartConfig, BaseChartConfig, BarAndLineAxis, ChartConfig, DataMetadata}; + + use super::*; + use chrono::Utc; + use serde_json::json; + + fn create_test_dashboard() -> DashboardYml { + DashboardYml { + id: Some(Uuid::new_v4()), + updated_at: Some(Utc::now()), + name: Some("Test Dashboard".to_string()), + rows: vec![], + } + } + + fn create_test_metric() -> MetricYml { + MetricYml { + id: Some(Uuid::new_v4()), + updated_at: Some(Utc::now()), + title: "Test Metric".to_string(), + description: Some("Test Description".to_string()), + sql: "SELECT * FROM test_table".to_string(), + chart_config: ChartConfig::Bar(BarLineChartConfig { + base: BaseChartConfig { + column_label_formats: HashMap::new(), + column_settings: None, + colors: None, + show_legend: None, + grid_lines: None, + show_legend_headline: None, + goal_lines: None, + trendlines: None, + disable_tooltip: None, + y_axis_config: None, + x_axis_config: None, + category_axis_style_config: None, + y2_axis_config: None, + }, + bar_and_line_axis: BarAndLineAxis { + x: vec![], + y: vec![], + category: vec![], + tooltip: None, + }, + bar_layout: None, + bar_sort_by: None, + bar_group_type: None, + bar_show_total_at_top: None, + line_group_type: None, + }), + data_metadata: vec![ + DataMetadata { + name: "id".to_string(), + data_type: "number".to_string(), + }, + DataMetadata { + name: "value".to_string(), + data_type: "string".to_string(), + } + ], + } + } + + #[test] + fn test_build_status_message_all_success() { + let results = vec![ + FileEnum::Dashboard(create_test_dashboard()), + FileEnum::Metric(create_test_metric()), + ]; + let missing_files = vec![]; + let error_messages = vec![]; + + let message = build_status_message(&results, &missing_files, &error_messages); + assert_eq!(message, "Successfully opened 2 files"); + } + + #[test] + fn test_build_status_message_with_missing() { + let results = vec![ + FileEnum::Dashboard(create_test_dashboard()), + FileEnum::Metric(create_test_metric()), + ]; + let missing_files = vec![ + "1 dashboard: abc-123".to_string(), + "2 metrics: def-456, ghi-789".to_string(), + ]; + let error_messages = vec![]; + + let message = build_status_message(&results, &missing_files, &error_messages); + assert_eq!( + message, + "Successfully opened 2 files. Could not find the following files: 1 dashboard: abc-123; 2 metrics: def-456, ghi-789" + ); + } + + #[test] + fn test_build_status_message_with_errors() { + let results = vec![]; + let missing_files = vec![]; + let error_messages = vec![ + "Invalid UUID format for id: xyz".to_string(), + "Error processing metric files: connection failed".to_string(), + ]; + + let message = build_status_message(&results, &missing_files, &error_messages); + assert_eq!( + message, + "Encountered the following issues: Invalid UUID format for id: xyz; Error processing metric files: connection failed" + ); + } + + #[test] + fn test_build_status_message_mixed_results() { + let results = vec![FileEnum::Metric(create_test_metric())]; + let missing_files = vec!["1 dashboard: abc-123".to_string()]; + let error_messages = vec!["Invalid UUID format for id: xyz".to_string()]; + + let message = build_status_message(&results, &missing_files, &error_messages); + assert_eq!( + message, + "Successfully opened 1 files. Could not find the following files: 1 dashboard: abc-123. Encountered the following issues: Invalid UUID format for id: xyz" + ); + } + + #[test] + fn test_parse_valid_params() { + let params_json = json!({ + "files": [ + {"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "dashboard"}, + {"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "metric"} + ] + }); + + let params: OpenFilesParams = serde_json::from_value(params_json).unwrap(); + assert_eq!(params.files.len(), 2); + assert_eq!(params.files[0].file_type, "dashboard"); + assert_eq!(params.files[1].file_type, "metric"); + } + + #[test] + fn test_parse_invalid_uuid() { + let params_json = json!({ + "files": [ + {"id": "not-a-uuid", "file_type": "dashboard"}, + {"id": "also-not-a-uuid", "file_type": "metric"} + ] + }); + + let params: OpenFilesParams = serde_json::from_value(params_json).unwrap(); + for file in ¶ms.files { + let uuid_result = Uuid::parse_str(&file.id); + assert!(uuid_result.is_err()); + } + } + + #[test] + fn test_parse_invalid_file_type() { + let params_json = json!({ + "files": [ + {"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "invalid"}, + {"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "unknown"} + ] + }); + + let params: OpenFilesParams = serde_json::from_value(params_json).unwrap(); + for file in ¶ms.files { + assert!(file.file_type != "dashboard" && file.file_type != "metric"); + } + } + + // Mock tests for file retrieval + #[tokio::test] + async fn test_get_dashboard_files() { + let test_id = Uuid::new_v4(); + let dashboard = create_test_dashboard(); + let test_files = vec![DashboardFile { + id: test_id, + name: dashboard.name.clone().unwrap_or_default(), + file_name: "test.yml".to_string(), + content: serde_json::to_value(&dashboard).unwrap(), + filter: None, + organization_id: Uuid::new_v4(), + created_by: Uuid::new_v4(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }]; + + // TODO: Mock database connection and return test_files + } + + #[tokio::test] + async fn test_get_metric_files() { + let test_id = Uuid::new_v4(); + let metric = create_test_metric(); + let test_files = vec![MetricFile { + id: test_id, + name: metric.title.clone(), + file_name: "test.yml".to_string(), + content: serde_json::to_value(&metric).unwrap(), + verification: crate::database::enums::Verification::NotRequested, + evaluation_obj: None, + evaluation_summary: None, + evaluation_score: None, + organization_id: Uuid::new_v4(), + created_by: Uuid::new_v4(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }]; + + // TODO: Mock database connection and return test_files + } +} diff --git a/api/src/utils/tools/file_tools/search_data_catalog.rs b/api/src/utils/tools/file_tools/search_data_catalog.rs index 54d5a1991..668e7ef35 100644 --- a/api/src/utils/tools/file_tools/search_data_catalog.rs +++ b/api/src/utils/tools/file_tools/search_data_catalog.rs @@ -1,7 +1,7 @@ use anyhow::Result; use async_trait::async_trait; -use serde_json::Value; use serde::{Deserialize, Serialize}; +use serde_json::Value; use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}; @@ -22,18 +22,30 @@ struct CatalogSearchResult { metadata: Value, } +#[derive(Debug, Serialize)] +pub struct SearchDataCatalogOutput { + success: bool, + results: Vec, +} + pub struct SearchDataCatalogTool; #[async_trait] impl ToolExecutor for SearchDataCatalogTool { + type Output = SearchDataCatalogOutput; + fn get_name(&self) -> String { "search_data_catalog".to_string() } - async fn execute(&self, tool_call: &ToolCall) -> Result { - let params: SearchDataCatalogParams = serde_json::from_str(&tool_call.function.arguments.clone())?; + async fn execute(&self, tool_call: &ToolCall) -> Result { + let params: SearchDataCatalogParams = + serde_json::from_str(&tool_call.function.arguments.clone())?; // TODO: Implement actual data catalog search logic - Ok(Value::Array(vec![])) + Ok(SearchDataCatalogOutput { + success: true, + results: vec![], + }) } fn get_schema(&self) -> Value { @@ -67,4 +79,4 @@ impl ToolExecutor for SearchDataCatalogTool { "description": "Searches the data catalog for relevant items including datasets, metrics, business terms, and logic definitions. Returns structured results with relevance scores. Use this to find data assets and their documentation." }) } -} \ No newline at end of file +} diff --git a/api/src/utils/tools/file_tools/search_files.rs b/api/src/utils/tools/file_tools/search_files.rs index 6a8852bd0..f9e5eec87 100644 --- a/api/src/utils/tools/file_tools/search_files.rs +++ b/api/src/utils/tools/file_tools/search_files.rs @@ -1,7 +1,7 @@ use anyhow::Result; use async_trait::async_trait; -use serde_json::Value; use serde::{Deserialize, Serialize}; +use serde_json::Value; use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}; @@ -10,18 +10,28 @@ struct SearchFilesParams { query_params: Vec, } +#[derive(Debug, Serialize)] +pub struct SearchFilesOutput { + success: bool, +} + pub struct SearchFilesTool; #[async_trait] impl ToolExecutor for SearchFilesTool { + type Output = SearchFilesOutput; + fn get_name(&self) -> String { "search_files".to_string() } - async fn execute(&self, tool_call: &ToolCall) -> Result { - let params: SearchFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?; + async fn execute(&self, tool_call: &ToolCall) -> Result { + let params: SearchFilesParams = + serde_json::from_str(&tool_call.function.arguments.clone())?; // TODO: Implement actual file search logic - Ok(Value::Array(vec![])) + Ok(SearchFilesOutput { + success: true, + }) } fn get_schema(&self) -> Value { @@ -46,4 +56,4 @@ impl ToolExecutor for SearchFilesTool { "description": "Searches for metric and dashboard files using natural-language queries. Typically used if you suspect there might already be a relevant metric or dashboard in the repository. If results are found, you can then decide whether to open them with `open_files`." }) } -} \ No newline at end of file +} diff --git a/api/src/utils/tools/file_tools/send_to_user.rs b/api/src/utils/tools/file_tools/send_to_user.rs index 34979a56e..7eb556a6e 100644 --- a/api/src/utils/tools/file_tools/send_to_user.rs +++ b/api/src/utils/tools/file_tools/send_to_user.rs @@ -10,18 +10,25 @@ struct SendToUserParams { metric_id: String, } +#[derive(Debug, Serialize)] +pub struct SendToUserOutput { + success: bool, +} + pub struct SendToUserTool; #[async_trait] impl ToolExecutor for SendToUserTool { + type Output = SendToUserOutput; + fn get_name(&self) -> String { "send_to_user".to_string() } - async fn execute(&self, tool_call: &ToolCall) -> Result { + async fn execute(&self, tool_call: &ToolCall) -> Result { let params: SendToUserParams = serde_json::from_str(&tool_call.function.arguments.clone())?; // TODO: Implement actual send to user logic - Ok(Value::Array(vec![])) + Ok(SendToUserOutput { success: true }) } fn get_schema(&self) -> Value { diff --git a/api/src/utils/tools/mod.rs b/api/src/utils/tools/mod.rs index 6294617af..5ebdadd1e 100644 --- a/api/src/utils/tools/mod.rs +++ b/api/src/utils/tools/mod.rs @@ -1,5 +1,6 @@ use anyhow::Result; use async_trait::async_trait; +use serde::Serialize; use serde_json::Value; use crate::utils::clients::ai::litellm::ToolCall; @@ -10,8 +11,11 @@ pub mod file_tools; /// Any struct that wants to be used as a tool must implement this trait. #[async_trait] pub trait ToolExecutor: Send + Sync { + /// The type of the output of the tool + type Output: Serialize + Send; + /// Execute the tool with given arguments and return a result - async fn execute(&self, tool_call: &ToolCall) -> Result; + async fn execute(&self, tool_call: &ToolCall) -> Result; /// Return the JSON schema that describes this tool's interface fn get_schema(&self) -> serde_json::Value; @@ -21,11 +25,11 @@ pub trait ToolExecutor: Send + Sync { } trait IntoBoxedTool { - fn boxed(self) -> Box; + fn boxed(self) -> Box>; } -impl IntoBoxedTool for T { - fn boxed(self) -> Box { +impl + 'static> IntoBoxedTool for T { + fn boxed(self) -> Box> { Box::new(self) } }