refactor(tools): Enhance ToolExecutor trait and file-related tools

- Add generic Output type to ToolExecutor trait
- Update file tools to use strongly-typed output structs
- Modify agent and tool implementations to support generic output
- Improve error handling and result reporting in file-related tools
- Add more detailed status messages for file operations
This commit is contained in:
dal 2025-02-06 23:45:48 -07:00
parent 4ec6e78648
commit 711bbe899a
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
11 changed files with 770 additions and 87 deletions

View File

@ -3,9 +3,9 @@ use crate::utils::{
tools::ToolExecutor,
};
use anyhow::Result;
use serde_json::Value;
use std::{collections::HashMap, env};
use tokio::sync::mpsc;
use async_trait::async_trait;
use super::types::AgentThread;
@ -16,14 +16,14 @@ pub struct Agent {
/// Client for communicating with the LLM provider
llm_client: LiteLLMClient,
/// Registry of available tools, mapped by their names
tools: HashMap<String, Box<dyn ToolExecutor>>,
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>,
/// The model identifier to use (e.g., "gpt-4")
model: String,
}
impl Agent {
/// Create a new Agent instance with a specific LLM client and model
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor>>) -> Self {
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>) -> Self {
let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
@ -41,7 +41,7 @@ impl Agent {
/// # Arguments
/// * `name` - The name of the tool, used to identify it in tool calls
/// * `tool` - The tool implementation that will be executed
pub fn add_tool<T: ToolExecutor + 'static>(&mut self, name: String, tool: T) {
pub fn add_tool<T: ToolExecutor<Output = Value> + 'static>(&mut self, name: String, tool: T) {
self.tools.insert(name, Box::new(tool));
}
@ -49,7 +49,7 @@ impl Agent {
///
/// # Arguments
/// * `tools` - HashMap of tool names and their implementations
pub fn add_tools<T: ToolExecutor + 'static>(&mut self, tools: HashMap<String, T>) {
pub fn add_tools<T: ToolExecutor<Output = Value> + 'static>(&mut self, tools: HashMap<String, T>) {
for (name, tool) in tools {
self.tools.insert(name, Box::new(tool));
}
@ -204,6 +204,7 @@ mod tests {
use crate::utils::clients::ai::litellm::ToolCall;
use super::*;
use axum::async_trait;
use dotenv::dotenv;
use serde_json::{json, Value};
@ -215,7 +216,9 @@ mod tests {
#[async_trait]
impl ToolExecutor for WeatherTool {
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
type Output = Value;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
Ok(json!({
"temperature": 20,
"unit": "fahrenheit"

View File

@ -22,18 +22,29 @@ struct BulkModifyFilesParams {
files_with_modifications: Vec<FileModification>,
}
/// Result payload for the `bulk_modify_files` tool.
#[derive(Debug, Serialize)]
pub struct BulkModifyFilesOutput {
    // Whether the modification run succeeded. Currently always `true` —
    // the tool body is still a stub (see the TODO in `execute`).
    success: bool,
}
pub struct BulkModifyFilesTool;
#[async_trait]
impl ToolExecutor for BulkModifyFilesTool {
type Output = BulkModifyFilesOutput;
fn get_name(&self) -> String {
"bulk_modify_files".to_string()
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let params: BulkModifyFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
// TODO: Implement actual file modification logic
Ok(Value::Array(vec![]))
let output = BulkModifyFilesOutput {
success: true,
};
Ok(output)
}
fn get_schema(&self) -> Value {

View File

@ -17,9 +17,9 @@ use crate::{
utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor},
};
use super::file_types::{dashboard_yml::DashboardYml, metric_yml::MetricYml};
use super::file_types::{dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml};
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
struct FileParams {
name: String,
file_type: String,
@ -31,15 +31,24 @@ struct CreateFilesParams {
files: Vec<FileParams>,
}
/// Result payload for the `create_files` tool.
#[derive(Debug, Serialize)]
pub struct CreateFilesOutput {
    // True when at least one file was created; `execute` sets this from
    // `!created_files.is_empty()`, so partial failures still report success.
    success: bool,
    // Human-readable summary combining the success count and any
    // per-file failure reasons.
    message: String,
    // The files that were actually created, as parsed YML payloads.
    files: Vec<FileEnum>,
}
pub struct CreateFilesTool;
#[async_trait]
impl ToolExecutor for CreateFilesTool {
type Output = CreateFilesOutput;
fn get_name(&self) -> String {
"create_files".to_string()
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let params: CreateFilesParams =
match serde_json::from_str(&tool_call.function.arguments.clone()) {
Ok(params) => params,
@ -53,15 +62,64 @@ impl ToolExecutor for CreateFilesTool {
let files = params.files;
let mut created_files = vec![];
let mut failed_files = vec![];
for file in files {
match file.file_type.as_str() {
"metric" => create_metric_file(file).await?,
"dashboard" => create_dashboard_file(file).await?,
_ => return Err(anyhow::anyhow!("Invalid file type: {}. Currently only `metric` and `dashboard` types are supported.", file.file_type)),
}
let created_file = match file.file_type.as_str() {
"metric" => match create_metric_file(file.clone()).await {
Ok(f) => {
created_files.push(f);
continue;
}
Err(e) => {
failed_files.push((file.name, e.to_string()));
continue;
}
},
"dashboard" => match create_dashboard_file(file.clone()).await {
Ok(f) => {
created_files.push(f);
continue;
}
Err(e) => {
failed_files.push((file.name, e.to_string()));
continue;
}
},
_ => {
failed_files.push((file.name, format!("Invalid file type: {}. Currently only `metric` and `dashboard` types are supported.", file.file_type)));
continue;
}
};
}
Ok(Value::Array(vec![]))
let message = if failed_files.is_empty() {
format!("Successfully created {} files.", created_files.len())
} else {
let success_msg = if !created_files.is_empty() {
format!("Successfully created {} files. ", created_files.len())
} else {
String::new()
};
let failures: Vec<String> = failed_files
.iter()
.map(|(name, error)| format!("Failed to create '{}': {}", name, error))
.collect();
format!("{}Failed to create {} files: {}",
success_msg,
failed_files.len(),
failures.join("; "))
};
Ok(CreateFilesOutput {
success: !created_files.is_empty(),
message,
files: created_files,
})
}
fn get_schema(&self) -> Value {
@ -104,7 +162,7 @@ impl ToolExecutor for CreateFilesTool {
}
}
async fn create_metric_file(file: FileParams) -> Result<()> {
async fn create_metric_file(file: FileParams) -> Result<FileEnum> {
let metric_yml = match MetricYml::new(file.yml_content) {
Ok(metric_file) => metric_file,
Err(e) => return Err(e),
@ -128,7 +186,7 @@ async fn create_metric_file(file: FileParams) -> Result<()> {
id: metric_id.clone(),
name: metric_yml.title.clone(),
file_name: format!("{}.yml", file.name),
content: serde_json::to_value(metric_yml).unwrap(),
content: serde_json::to_value(metric_yml.clone()).unwrap(),
created_by: Uuid::new_v4(),
verification: Verification::NotRequested,
evaluation_obj: None,
@ -140,9 +198,8 @@ async fn create_metric_file(file: FileParams) -> Result<()> {
deleted_at: None,
};
let metric_file_record = match insert_into(metric_files::table)
.values(metric_file_record)
.returning(metric_files::all_columns)
match insert_into(metric_files::table)
.values(&metric_file_record)
.execute(&mut conn)
.await
{
@ -150,10 +207,10 @@ async fn create_metric_file(file: FileParams) -> Result<()> {
Err(e) => return Err(anyhow::anyhow!("Failed to create metric file: {}", e)),
};
Ok(())
Ok(FileEnum::Metric(metric_yml))
}
async fn create_dashboard_file(file: FileParams) -> Result<()> {
async fn create_dashboard_file(file: FileParams) -> Result<FileEnum> {
let dashboard_yml = match DashboardYml::new(file.yml_content) {
Ok(dashboard_file) => dashboard_file,
Err(e) => return Err(e),
@ -175,9 +232,12 @@ async fn create_dashboard_file(file: FileParams) -> Result<()> {
let dashboard_file_record = DashboardFile {
id: dashboard_id.clone(),
name: dashboard_yml.name.clone().unwrap_or_else(|| "New Dashboard".to_string()),
name: dashboard_yml
.name
.clone()
.unwrap_or_else(|| "New Dashboard".to_string()),
file_name: format!("{}.yml", file.name),
content: serde_json::to_value(dashboard_yml).unwrap(),
content: serde_json::to_value(dashboard_yml.clone()).unwrap(),
filter: None,
organization_id: Uuid::new_v4(),
created_by: Uuid::new_v4(),
@ -187,12 +247,14 @@ async fn create_dashboard_file(file: FileParams) -> Result<()> {
};
match insert_into(dashboard_files::table)
.values(dashboard_file_record)
.values(&dashboard_file_record)
.returning(dashboard_files::all_columns)
.execute(&mut conn)
.await
{
Ok(_) => Ok(()),
Err(e) => Err(anyhow::anyhow!(e)),
}
Ok(_) => (),
Err(e) => return Err(anyhow::anyhow!(e)),
};
Ok(FileEnum::Dashboard(dashboard_yml))
}

View File

@ -1,22 +1,24 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DashboardYml {
pub id: Option<Uuid>,
pub updated_at: Option<DateTime<Utc>>,
pub name: Option<String>,
pub rows: Vec<Row>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Row {
items: Vec<RowItem>, // max number of items in a row is 4, min is 1
row_height: u32, // max is 550, min is 320
column_sizes: Vec<u32>, // max sum of elements is 12 min is 3
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RowItem {
// This id is the id of the metric or item reference that goes here in the dashboard.
id: Uuid,
@ -40,6 +42,8 @@ impl DashboardYml {
file.name = Some(String::from("New Dashboard"));
}
file.updated_at = Some(Utc::now());
// Validate the file
match file.validate() {
Ok(_) => Ok(file),
@ -47,6 +51,7 @@ impl DashboardYml {
}
}
// TODO: The validate of the dashboard should also be whether metrics exist?
pub fn validate(&self) -> Result<()> {
// Validate the id and name
if self.id.is_none() {
@ -57,6 +62,10 @@ impl DashboardYml {
return Err(anyhow::anyhow!("Dashboard file name is required"));
}
if self.updated_at.is_none() {
return Err(anyhow::anyhow!("Dashboard file updated_at is required"));
}
// Validate each row
for row in &self.rows {
// Check row height constraints

View File

@ -1,10 +1,61 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use super::dashboard_yml::DashboardYml;
use super::metric_yml::MetricYml;
#[derive(Debug, Serialize, Deserialize)]
pub struct File {
pub name: String,
pub file_type: String,
pub yml_content: String,
}
}
/// A parsed file payload: either a metric or a dashboard definition.
///
/// `#[serde(untagged)]` serializes the variant as the bare inner object —
/// consumers see plain `MetricYml` / `DashboardYml` JSON with no enum tag,
/// and deserialization picks whichever variant matches first.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum FileEnum {
    Metric(MetricYml),
    Dashboard(DashboardYml),
}
impl FileEnum {
pub fn name(&self) -> anyhow::Result<String> {
match self {
Self::Metric(metric) => Ok(metric.title.clone()),
Self::Dashboard(dashboard) => match &dashboard.name {
Some(name) => Ok(name.clone()),
None => Err(anyhow::anyhow!("Dashboard name is required but not found")),
},
}
}
pub fn id(&self) -> anyhow::Result<Uuid> {
match self {
Self::Metric(metric) => match metric.id {
Some(id) => Ok(id),
None => Err(anyhow::anyhow!("Metric id is required but not found")),
},
Self::Dashboard(dashboard) => match dashboard.id {
Some(id) => Ok(id),
None => Err(anyhow::anyhow!("Dashboard id is required but not found")),
},
}
}
pub fn updated_at(&self) -> anyhow::Result<DateTime<Utc>> {
match self {
Self::Metric(metric) => match metric.updated_at {
Some(dt) => Ok(dt),
None => Err(anyhow::anyhow!(
"Metric updated_at is required but not found"
)),
},
Self::Dashboard(dashboard) => match dashboard.updated_at {
Some(dt) => Ok(dt),
None => Err(anyhow::anyhow!(
"Dashboard updated_at is required but not found"
)),
},
}
}
}

View File

@ -1,10 +1,12 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MetricYml {
pub id: Option<Uuid>,
pub updated_at: Option<DateTime<Utc>>,
pub title: String,
pub description: Option<String>,
pub sql: String,
@ -12,13 +14,13 @@ pub struct MetricYml {
pub data_metadata: Vec<DataMetadata>,
}
/// Name/type descriptor for one column of a metric's result set.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DataMetadata {
    pub name: String,
    pub data_type: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "selectedChartType")]
pub enum ChartConfig {
#[serde(rename = "bar")]
@ -38,7 +40,7 @@ pub enum ChartConfig {
}
// Base chart config shared by all chart types
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BaseChartConfig {
pub column_label_formats: std::collections::HashMap<String, ColumnLabelFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
@ -67,7 +69,7 @@ pub struct BaseChartConfig {
pub y2_axis_config: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ColumnLabelFormat {
pub column_type: String,
pub style: String,
@ -101,7 +103,7 @@ pub struct ColumnLabelFormat {
pub convert_number_to: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ColumnSettings {
#[serde(skip_serializing_if = "Option::is_none")]
pub show_data_labels: Option<bool>,
@ -123,7 +125,7 @@ pub struct ColumnSettings {
pub line_symbol_size_dot: Option<f64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GoalLine {
#[serde(skip_serializing_if = "Option::is_none")]
pub show: Option<bool>,
@ -137,7 +139,7 @@ pub struct GoalLine {
pub goal_line_color: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Trendline {
#[serde(skip_serializing_if = "Option::is_none")]
pub show: Option<bool>,
@ -151,7 +153,7 @@ pub struct Trendline {
pub trend_line_color: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BarLineChartConfig {
#[serde(flatten)]
pub base: BaseChartConfig,
@ -168,7 +170,7 @@ pub struct BarLineChartConfig {
pub line_group_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BarAndLineAxis {
pub x: Vec<String>,
pub y: Vec<String>,
@ -177,7 +179,7 @@ pub struct BarAndLineAxis {
pub tooltip: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ScatterChartConfig {
#[serde(flatten)]
pub base: BaseChartConfig,
@ -186,7 +188,7 @@ pub struct ScatterChartConfig {
pub scatter_dot_size: Option<Vec<f64>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ScatterAxis {
pub x: Vec<String>,
pub y: Vec<String>,
@ -198,7 +200,7 @@ pub struct ScatterAxis {
pub tooltip: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PieChartConfig {
#[serde(flatten)]
pub base: BaseChartConfig,
@ -219,7 +221,7 @@ pub struct PieChartConfig {
pub pie_minimum_slice_percentage: Option<f64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PieChartAxis {
pub x: Vec<String>,
pub y: Vec<String>,
@ -227,14 +229,14 @@ pub struct PieChartAxis {
pub tooltip: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ComboChartConfig {
#[serde(flatten)]
pub base: BaseChartConfig,
pub combo_chart_axis: ComboChartAxis,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ComboChartAxis {
pub x: Vec<String>,
pub y: Vec<String>,
@ -246,7 +248,7 @@ pub struct ComboChartAxis {
pub tooltip: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MetricChartConfig {
#[serde(flatten)]
pub base: BaseChartConfig,
@ -261,7 +263,7 @@ pub struct MetricChartConfig {
pub metric_value_label: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TableChartConfig {
#[serde(flatten)]
pub base: BaseChartConfig,
@ -288,12 +290,15 @@ impl MetricYml {
file.id = Some(Uuid::new_v4());
}
file.updated_at = Some(Utc::now());
match file.validate() {
Ok(_) => Ok(file),
Err(e) => Err(anyhow::anyhow!("Error compiling file: {}", e)),
}
}
//TODO: Need to validate a metric deeply.
// Currently a no-op placeholder: every metric passes validation unchanged.
pub fn validate(&self) -> Result<()> {
    Ok(())
}

View File

@ -1,49 +1,558 @@
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
use crate::{
database::{
lib::get_pg_pool,
models::{DashboardFile, MetricFile},
schema::{dashboard_files, metric_files},
},
utils::{
clients::ai::litellm::ToolCall,
tools::file_tools::file_types::{
dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml,
},
tools::ToolExecutor,
},
};
/// One file reference from an `open_files` tool call.
#[derive(Debug, Serialize, Deserialize)]
struct FileRequest {
    // UUID as a string; `execute` parses it and reports invalid formats
    // rather than failing the whole call.
    id: String,
    // "dashboard" or "metric" per the tool schema; other values are
    // silently ignored by the grouping logic in `execute`.
    file_type: String,
}
/// Parameters for the `open_files` tool: the list of files to open,
/// each identified by UUID and type (matches the `files` array required
/// by the tool's JSON schema).
#[derive(Debug, Serialize, Deserialize)]
struct OpenFilesParams {
    files: Vec<FileRequest>,
}
/// Result payload for the `open_files` tool.
#[derive(Debug, Serialize)]
pub struct OpenFilesOutput {
    // Human-readable status covering successes, missing files, and
    // input/processing errors (built by `build_status_message`).
    message: String,
    // The files that were found and parsed successfully.
    results: Vec<FileEnum>,
}
pub struct OpenFilesTool;
#[async_trait]
impl ToolExecutor for OpenFilesTool {
type Output = OpenFilesOutput;
fn get_name(&self) -> String {
"open_files".to_string()
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
let params: OpenFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
// TODO: Implement actual file opening logic
Ok(Value::Array(vec![]))
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
debug!("Starting file open operation");
let params: OpenFilesParams =
serde_json::from_str(&tool_call.function.arguments.clone())
.map_err(|e| {
error!(error = %e, "Failed to parse tool parameters");
anyhow::anyhow!("Failed to parse tool parameters: {}", e)
})?;
let mut results = Vec::new();
let mut error_messages = Vec::new();
// Track requested IDs by type for later comparison
let mut requested_ids: HashMap<String, HashSet<Uuid>> = HashMap::new();
let mut found_ids: HashMap<String, HashSet<Uuid>> = HashMap::new();
// Group requests by file type and track requested IDs
let grouped_requests = params
.files
.into_iter()
.filter_map(|req| match Uuid::parse_str(&req.id) {
Ok(id) => {
requested_ids
.entry(req.file_type.clone())
.or_default()
.insert(id);
Some((req.file_type, id))
}
Err(_) => {
warn!(invalid_id = %req.id, "Invalid UUID format");
error_messages.push(format!("Invalid UUID format for id: {}", req.id));
None
}
})
.fold(HashMap::new(), |mut acc, (file_type, id)| {
acc.entry(file_type).or_insert_with(Vec::new).push(id);
acc
});
// Process dashboard files
if let Some(dashboard_ids) = grouped_requests.get("dashboard") {
match get_dashboard_files(dashboard_ids).await {
Ok(dashboard_files) => {
for (dashboard_yml, id, _) in dashboard_files {
found_ids
.entry("dashboard".to_string())
.or_default()
.insert(id);
results.push(FileEnum::Dashboard(dashboard_yml));
}
}
Err(e) => {
error!(error = %e, "Failed to process dashboard files");
error_messages.push(format!("Error processing dashboard files: {}", e));
}
}
}
// Process metric files
if let Some(metric_ids) = grouped_requests.get("metric") {
match get_metric_files(metric_ids).await {
Ok(metric_files) => {
for (metric_yml, id, _) in metric_files {
found_ids
.entry("metric".to_string())
.or_default()
.insert(id);
results.push(FileEnum::Metric(metric_yml));
}
}
Err(e) => {
error!(error = %e, "Failed to process metric files");
error_messages.push(format!("Error processing metric files: {}", e));
}
}
}
// Generate message about missing files
let mut missing_files = Vec::new();
for (file_type, ids) in requested_ids.iter() {
let found = found_ids.get(file_type).cloned().unwrap_or_default();
let missing: Vec<_> = ids.difference(&found).collect();
if !missing.is_empty() {
warn!(
file_type = %file_type,
missing_count = missing.len(),
missing_ids = ?missing,
"Files not found"
);
missing_files.push(format!(
"{} {}s: {}",
missing.len(),
file_type,
missing
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(", ")
));
}
}
let message = build_status_message(&results, &missing_files, &error_messages);
info!(
total_requested = requested_ids.values().map(|ids| ids.len()).sum::<usize>(),
total_found = results.len(),
error_count = error_messages.len(),
"Completed file open operation"
);
Ok(OpenFilesOutput { message, results })
}
fn get_schema(&self) -> Value {
serde_json::json!({
"name": "open_files",
"strict": true,
"description": "Opens one or more dashboard or metric files and returns their YML contents",
"parameters": {
"type": "object",
"required": ["file_names"],
"required": ["files"],
"properties": {
"file_names": {
"files": {
"type": "array",
"items": {
"type": "string",
"description": "The name of a file to be opened"
"type": "object",
"required": ["id", "file_type"],
"properties": {
"id": {
"type": "string",
"description": "The UUID of the file"
},
"file_type": {
"type": "string",
"enum": ["dashboard", "metric"],
"description": "The type of file to read"
}
}
},
"description": "List of file names to be opened"
"description": "List of files to be opened"
}
},
"additionalProperties": false
},
"description": "Opens one or more files in read mode and displays **their entire contents** to the user. If you use this, the user will actually see the metric/dashboard you open."
}
}
})
}
}
}
/// Loads the non-deleted dashboard files with the given ids and parses each
/// stored JSON `content` blob into a `DashboardYml`.
///
/// Returns `(parsed_yml, row_id, updated_at_as_string)` triples. Rows whose
/// content fails to parse are logged at `warn` and silently omitted from the
/// result, so the returned vec can be shorter than `ids`. Database errors
/// (connection or query) abort the whole call with an error.
async fn get_dashboard_files(ids: &[Uuid]) -> Result<Vec<(DashboardYml, Uuid, String)>> {
    debug!(dashboard_ids = ?ids, "Fetching dashboard files");

    let mut conn = get_pg_pool()
        .get()
        .await
        .map_err(|e| {
            error!(error = %e, "Failed to get database connection");
            anyhow::anyhow!("Failed to get database connection: {}", e)
        })?;

    // Only fetch rows that match a requested id and are not soft-deleted.
    let files = match dashboard_files::table
        .filter(dashboard_files::id.eq_any(ids))
        .filter(dashboard_files::deleted_at.is_null())
        .load::<DashboardFile>(&mut conn)
        .await
    {
        Ok(files) => {
            debug!(count = files.len(), "Successfully loaded dashboard files from database");
            files
        }
        Err(e) => {
            error!(error = %e, "Failed to load dashboard files from database");
            return Err(anyhow::anyhow!(
                "Error loading dashboard files from database: {}",
                e
            ));
        }
    };

    // Parse each stored content blob; drop (but log) unparseable rows.
    let mut results = Vec::new();
    for file in files {
        match serde_json::from_value(file.content.clone()) {
            Ok(dashboard_yml) => {
                debug!(dashboard_id = %file.id, "Successfully parsed dashboard YAML");
                results.push((dashboard_yml, file.id, file.updated_at.to_string()));
            }
            Err(e) => {
                warn!(
                    error = %e,
                    dashboard_id = %file.id,
                    "Failed to parse dashboard YAML"
                );
            }
        }
    }

    info!(
        requested_count = ids.len(),
        found_count = results.len(),
        "Completed dashboard files retrieval"
    );

    Ok(results)
}
/// Loads the non-deleted metric files with the given ids and parses each
/// stored JSON `content` blob into a `MetricYml`.
///
/// Mirrors `get_dashboard_files`: returns `(parsed_yml, row_id,
/// updated_at_as_string)` triples, logs-and-skips rows whose content fails
/// to parse, and fails the whole call only on database errors.
async fn get_metric_files(ids: &[Uuid]) -> Result<Vec<(MetricYml, Uuid, String)>> {
    debug!(metric_ids = ?ids, "Fetching metric files");

    let mut conn = get_pg_pool()
        .get()
        .await
        .map_err(|e| {
            error!(error = %e, "Failed to get database connection");
            anyhow::anyhow!("Failed to get database connection: {}", e)
        })?;

    // Only fetch rows that match a requested id and are not soft-deleted.
    let files = match metric_files::table
        .filter(metric_files::id.eq_any(ids))
        .filter(metric_files::deleted_at.is_null())
        .load::<MetricFile>(&mut conn)
        .await
    {
        Ok(files) => {
            debug!(count = files.len(), "Successfully loaded metric files from database");
            files
        }
        Err(e) => {
            error!(error = %e, "Failed to load metric files from database");
            return Err(anyhow::anyhow!(
                "Error loading metric files from database: {}",
                e
            ));
        }
    };

    // Parse each stored content blob; drop (but log) unparseable rows.
    let mut results = Vec::new();
    for file in files {
        match serde_json::from_value(file.content.clone()) {
            Ok(metric_yml) => {
                debug!(metric_id = %file.id, "Successfully parsed metric YAML");
                results.push((metric_yml, file.id, file.updated_at.to_string()));
            }
            Err(e) => {
                warn!(
                    error = %e,
                    metric_id = %file.id,
                    "Failed to parse metric YAML"
                );
            }
        }
    }

    info!(
        requested_count = ids.len(),
        found_count = results.len(),
        "Completed metric files retrieval"
    );

    Ok(results)
}
/// Assembles the user-facing status line for an open-files run from the
/// three possible outcome buckets: files opened, files not found, and
/// input/processing errors. Sentences are joined with ". "; when every
/// bucket is empty a fixed invalid-input message is returned instead.
fn build_status_message(
    results: &[FileEnum],
    missing_files: &[String],
    error_messages: &[String],
) -> String {
    let mut segments: Vec<String> = Vec::new();

    // Success sentence only when something was actually opened.
    if !results.is_empty() {
        segments.push(format!("Successfully opened {} files", results.len()));
    }

    // Missing-file details, already formatted per type by the caller.
    if !missing_files.is_empty() {
        segments.push(format!(
            "Could not find the following files: {}",
            missing_files.join("; ")
        ));
    }

    // Any parse/lookup errors collected along the way.
    if !error_messages.is_empty() {
        segments.push(format!(
            "Encountered the following issues: {}",
            error_messages.join("; ")
        ));
    }

    // Nothing at all means the input never produced a single valid request.
    if segments.is_empty() {
        return "No files were processed due to invalid input".to_string();
    }

    segments.join(". ")
}
/// Unit tests for the open-files tool's pure helpers: status-message
/// building and parameter parsing. The database-backed retrieval tests
/// currently only build fixtures — see the TODOs about mocking the
/// connection.
#[cfg(test)]
mod tests {
    use crate::utils::tools::file_tools::file_types::metric_yml::{BarLineChartConfig, BaseChartConfig, BarAndLineAxis, ChartConfig, DataMetadata};

    use super::*;
    use chrono::Utc;
    use serde_json::json;

    // Minimal dashboard fixture: id/updated_at/name populated, no rows.
    fn create_test_dashboard() -> DashboardYml {
        DashboardYml {
            id: Some(Uuid::new_v4()),
            updated_at: Some(Utc::now()),
            name: Some("Test Dashboard".to_string()),
            rows: vec![],
        }
    }

    // Minimal metric fixture with a bar chart config and two data columns.
    fn create_test_metric() -> MetricYml {
        MetricYml {
            id: Some(Uuid::new_v4()),
            updated_at: Some(Utc::now()),
            title: "Test Metric".to_string(),
            description: Some("Test Description".to_string()),
            sql: "SELECT * FROM test_table".to_string(),
            chart_config: ChartConfig::Bar(BarLineChartConfig {
                base: BaseChartConfig {
                    column_label_formats: HashMap::new(),
                    column_settings: None,
                    colors: None,
                    show_legend: None,
                    grid_lines: None,
                    show_legend_headline: None,
                    goal_lines: None,
                    trendlines: None,
                    disable_tooltip: None,
                    y_axis_config: None,
                    x_axis_config: None,
                    category_axis_style_config: None,
                    y2_axis_config: None,
                },
                bar_and_line_axis: BarAndLineAxis {
                    x: vec![],
                    y: vec![],
                    category: vec![],
                    tooltip: None,
                },
                bar_layout: None,
                bar_sort_by: None,
                bar_group_type: None,
                bar_show_total_at_top: None,
                line_group_type: None,
            }),
            data_metadata: vec![
                DataMetadata {
                    name: "id".to_string(),
                    data_type: "number".to_string(),
                },
                DataMetadata {
                    name: "value".to_string(),
                    data_type: "string".to_string(),
                }
            ],
        }
    }

    // Only successes: message is the success sentence alone.
    #[test]
    fn test_build_status_message_all_success() {
        let results = vec![
            FileEnum::Dashboard(create_test_dashboard()),
            FileEnum::Metric(create_test_metric()),
        ];
        let missing_files = vec![];
        let error_messages = vec![];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(message, "Successfully opened 2 files");
    }

    // Successes plus missing files: both sentences, "; "-joined details.
    #[test]
    fn test_build_status_message_with_missing() {
        let results = vec![
            FileEnum::Dashboard(create_test_dashboard()),
            FileEnum::Metric(create_test_metric()),
        ];
        let missing_files = vec![
            "1 dashboard: abc-123".to_string(),
            "2 metrics: def-456, ghi-789".to_string(),
        ];
        let error_messages = vec![];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(
            message,
            "Successfully opened 2 files. Could not find the following files: 1 dashboard: abc-123; 2 metrics: def-456, ghi-789"
        );
    }

    // Errors only: no success sentence, just the issues list.
    #[test]
    fn test_build_status_message_with_errors() {
        let results = vec![];
        let missing_files = vec![];
        let error_messages = vec![
            "Invalid UUID format for id: xyz".to_string(),
            "Error processing metric files: connection failed".to_string(),
        ];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(
            message,
            "Encountered the following issues: Invalid UUID format for id: xyz; Error processing metric files: connection failed"
        );
    }

    // One of each bucket: all three sentences in order.
    // NOTE(review): pins the current "1 files" phrasing — a grammar fix
    // here would be a behavior change.
    #[test]
    fn test_build_status_message_mixed_results() {
        let results = vec![FileEnum::Metric(create_test_metric())];
        let missing_files = vec!["1 dashboard: abc-123".to_string()];
        let error_messages = vec!["Invalid UUID format for id: xyz".to_string()];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(
            message,
            "Successfully opened 1 files. Could not find the following files: 1 dashboard: abc-123. Encountered the following issues: Invalid UUID format for id: xyz"
        );
    }

    // Well-formed params deserialize into the expected two requests.
    #[test]
    fn test_parse_valid_params() {
        let params_json = json!({
            "files": [
                {"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "dashboard"},
                {"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "metric"}
            ]
        });

        let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
        assert_eq!(params.files.len(), 2);
        assert_eq!(params.files[0].file_type, "dashboard");
        assert_eq!(params.files[1].file_type, "metric");
    }

    // Deserialization accepts any id string; UUID validation happens later.
    #[test]
    fn test_parse_invalid_uuid() {
        let params_json = json!({
            "files": [
                {"id": "not-a-uuid", "file_type": "dashboard"},
                {"id": "also-not-a-uuid", "file_type": "metric"}
            ]
        });

        let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
        for file in &params.files {
            let uuid_result = Uuid::parse_str(&file.id);
            assert!(uuid_result.is_err());
        }
    }

    // Unknown file types also survive parsing; filtering is execute's job.
    #[test]
    fn test_parse_invalid_file_type() {
        let params_json = json!({
            "files": [
                {"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "invalid"},
                {"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "unknown"}
            ]
        });

        let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
        for file in &params.files {
            assert!(file.file_type != "dashboard" && file.file_type != "metric");
        }
    }

    // Mock tests for file retrieval
    // Fixture-only for now: `test_files` is unused until the database
    // connection can be mocked (see TODO below).
    #[tokio::test]
    async fn test_get_dashboard_files() {
        let test_id = Uuid::new_v4();
        let dashboard = create_test_dashboard();
        let test_files = vec![DashboardFile {
            id: test_id,
            name: dashboard.name.clone().unwrap_or_default(),
            file_name: "test.yml".to_string(),
            content: serde_json::to_value(&dashboard).unwrap(),
            filter: None,
            organization_id: Uuid::new_v4(),
            created_by: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }];

        // TODO: Mock database connection and return test_files
    }

    // Fixture-only for now, mirroring the dashboard case above.
    #[tokio::test]
    async fn test_get_metric_files() {
        let test_id = Uuid::new_v4();
        let metric = create_test_metric();
        let test_files = vec![MetricFile {
            id: test_id,
            name: metric.title.clone(),
            file_name: "test.yml".to_string(),
            content: serde_json::to_value(&metric).unwrap(),
            verification: crate::database::enums::Verification::NotRequested,
            evaluation_obj: None,
            evaluation_summary: None,
            evaluation_score: None,
            organization_id: Uuid::new_v4(),
            created_by: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }];

        // TODO: Mock database connection and return test_files
    }
}

View File

@ -1,7 +1,7 @@
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
@ -22,18 +22,30 @@ struct CatalogSearchResult {
metadata: Value,
}
/// Result payload for the `search_data_catalog` tool.
#[derive(Debug, Serialize)]
pub struct SearchDataCatalogOutput {
    // Currently always `true` with empty results — the search logic is
    // still a stub (see the TODO in `execute`).
    success: bool,
    results: Vec<CatalogSearchResult>,
}
pub struct SearchDataCatalogTool;
#[async_trait]
impl ToolExecutor for SearchDataCatalogTool {
type Output = SearchDataCatalogOutput;
fn get_name(&self) -> String {
"search_data_catalog".to_string()
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
let params: SearchDataCatalogParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let params: SearchDataCatalogParams =
serde_json::from_str(&tool_call.function.arguments.clone())?;
// TODO: Implement actual data catalog search logic
Ok(Value::Array(vec![]))
Ok(SearchDataCatalogOutput {
success: true,
results: vec![],
})
}
fn get_schema(&self) -> Value {
@ -67,4 +79,4 @@ impl ToolExecutor for SearchDataCatalogTool {
"description": "Searches the data catalog for relevant items including datasets, metrics, business terms, and logic definitions. Returns structured results with relevance scores. Use this to find data assets and their documentation."
})
}
}
}

View File

@ -1,7 +1,7 @@
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
@ -10,18 +10,28 @@ struct SearchFilesParams {
query_params: Vec<String>,
}
/// Result payload for the `search_files` tool.
#[derive(Debug, Serialize)]
pub struct SearchFilesOutput {
    // Currently always `true` — the search logic is still a stub
    // (see the TODO in `execute`).
    success: bool,
}
pub struct SearchFilesTool;
#[async_trait]
impl ToolExecutor for SearchFilesTool {
type Output = SearchFilesOutput;
fn get_name(&self) -> String {
"search_files".to_string()
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
let params: SearchFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let params: SearchFilesParams =
serde_json::from_str(&tool_call.function.arguments.clone())?;
// TODO: Implement actual file search logic
Ok(Value::Array(vec![]))
Ok(SearchFilesOutput {
success: true,
})
}
fn get_schema(&self) -> Value {
@ -46,4 +56,4 @@ impl ToolExecutor for SearchFilesTool {
"description": "Searches for metric and dashboard files using natural-language queries. Typically used if you suspect there might already be a relevant metric or dashboard in the repository. If results are found, you can then decide whether to open them with `open_files`."
})
}
}
}

View File

@ -10,18 +10,25 @@ struct SendToUserParams {
metric_id: String,
}
/// Result payload for the `send_to_user` tool.
#[derive(Debug, Serialize)]
pub struct SendToUserOutput {
    // Currently always `true` — the send logic is still a stub
    // (see the TODO in `execute`).
    success: bool,
}
pub struct SendToUserTool;
#[async_trait]
impl ToolExecutor for SendToUserTool {
type Output = SendToUserOutput;
fn get_name(&self) -> String {
"send_to_user".to_string()
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
let params: SendToUserParams = serde_json::from_str(&tool_call.function.arguments.clone())?;
// TODO: Implement actual send to user logic
Ok(Value::Array(vec![]))
Ok(SendToUserOutput { success: true })
}
fn get_schema(&self) -> Value {

View File

@ -1,5 +1,6 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::Serialize;
use serde_json::Value;
use crate::utils::clients::ai::litellm::ToolCall;
@ -10,8 +11,11 @@ pub mod file_tools;
/// Any struct that wants to be used as a tool must implement this trait.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
/// The type of the output of the tool
type Output: Serialize + Send;
/// Execute the tool with given arguments and return a result
async fn execute(&self, tool_call: &ToolCall) -> Result<Value>;
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output>;
/// Return the JSON schema that describes this tool's interface
fn get_schema(&self) -> serde_json::Value;
@ -21,11 +25,11 @@ pub trait ToolExecutor: Send + Sync {
}
/// Convenience trait for erasing a concrete tool into the boxed form the
/// agent's registry stores (`Box<dyn ToolExecutor<Output = Value>>`).
trait IntoBoxedTool {
    fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>>;
}
impl<T: ToolExecutor + 'static> IntoBoxedTool for T {
fn boxed(self) -> Box<dyn ToolExecutor> {
impl<T: ToolExecutor<Output = Value> + 'static> IntoBoxedTool for T {
fn boxed(self) -> Box<dyn ToolExecutor<Output = Value>> {
Box::new(self)
}
}