From 372694bf1fa0311ea6fa7976f55a32fddbfcc70c Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 7 Feb 2025 07:47:35 -0700 Subject: [PATCH] feat(file_search): Implement advanced AI-powered file search tool - Add comprehensive file search functionality using LLM for intelligent file matching - Implement search across metric and dashboard files with relevance ranking - Create structured search result output with file metadata - Add robust error handling and logging for search operations - Include test coverage for search result parsing --- .../utils/tools/file_tools/search_files.rs | 244 +++++++++++++++++- 1 file changed, 238 insertions(+), 6 deletions(-) diff --git a/api/src/utils/tools/file_tools/search_files.rs b/api/src/utils/tools/file_tools/search_files.rs index f9e5eec87..d6c372b87 100644 --- a/api/src/utils/tools/file_tools/search_files.rs +++ b/api/src/utils/tools/file_tools/search_files.rs @@ -1,9 +1,26 @@ use anyhow::Result; use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{json, Value}; +use tracing::{debug, error, info, warn}; +use uuid::Uuid; -use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor}; +use crate::{ + database::{ + lib::get_pg_pool, + models::{DashboardFile, MetricFile}, + schema::{dashboard_files, metric_files}, + }, + utils::{ + clients::ai::litellm::{ + ChatCompletionRequest, LiteLLMClient, Message, ResponseFormat, Tool, ToolCall, + }, + tools::ToolExecutor, + }, +}; #[derive(Debug, Serialize, Deserialize)] struct SearchFilesParams { @@ -12,11 +29,103 @@ struct SearchFilesParams { #[derive(Debug, Serialize)] pub struct SearchFilesOutput { - success: bool, + message: String, + files: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FileSearchResult { + id: Uuid, + name: String, + file_type: String, + updated_at: DateTime, +} + +const FILE_SEARCH_PROMPT: &str = r#" +You are a file search assistant. You have access to a collection of metric and dashboard files. +Your task is to identify up to 10 most relevant files based on the following search queries: + +{queries_joined_with_newlines} + +Consider all queries collectively to determine relevance. + +You must return your response as a JSON array of objects, where each object has these exact fields: +- id: string (UUID) +- name: string +- file_type: string (either "metric" or "dashboard") +- updated_at: string (ISO timestamp) + +Available files: +{files_array_as_json} + +Requirements: +1. Return up to 10 most relevant files +2. Order results from most to least relevant +3. Only include the specified fields +4. Ensure all field types match the specified formats +"#; + +#[derive(Debug, Deserialize)] +struct LLMSearchResponse { + results: Vec, } pub struct SearchFilesTool; +impl SearchFilesTool { + fn format_search_prompt(query_params: &[String], files_array: &[Value]) -> Result { + let queries_joined = query_params.join("\n"); + let files_json = serde_json::to_string_pretty(&files_array)?; + + Ok(FILE_SEARCH_PROMPT + .replace("{queries_joined_with_newlines}", &queries_joined) + .replace("{files_array_as_json}", &files_json)) + } + + async fn perform_llm_search(prompt: String) -> Result { + debug!("Performing LLM search"); + + // Setup LiteLLM client + let llm_client = LiteLLMClient::new(None, None); + let request = ChatCompletionRequest { + model: "o3-mini".to_string(), + messages: vec![Message::User { + content: prompt, + name: None, + }], + response_format: Some(ResponseFormat { + type_: "json_object".to_string(), + json_schema: None, + }), + ..Default::default() + }; + + // Get response from LLM + let response = llm_client.chat_completion(request).await.map_err(|e| { + error!(error = %e, "Failed to get response from LLM"); + anyhow::anyhow!("Failed to get response from LLM: {}", e) + })?; + + // Parse LLM response + let content = match &response.choices[0].message { + Message::Assistant { + content: Some(content), + .. + } => content, + _ => { + error!("LLM response missing content"); + return Err(anyhow::anyhow!("LLM response missing content")); + } + }; + + // Parse into structured response + serde_json::from_str(content).map_err(|e| { + warn!(error = %e, "Failed to parse LLM response as JSON"); + anyhow::anyhow!("Failed to parse search results: {}", e) + }) + } +} + #[async_trait] impl ToolExecutor for SearchFilesTool { type Output = SearchFilesOutput; @@ -26,11 +135,62 @@ impl ToolExecutor for SearchFilesTool { } async fn execute(&self, tool_call: &ToolCall) -> Result { + debug!("Starting file search operation"); let params: SearchFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())?; - // TODO: Implement actual file search logic + + // Fetch all non-deleted records from both tables concurrently + let (metric_files, dashboard_files) = + tokio::try_join!(get_metric_files(), get_dashboard_files())?; + + // Format files for LLM + let files_array: Vec = metric_files + .iter() + .map(|f| { + json!({ + "id": f.id.to_string(), + "name": f.name, + "file_type": "metric", + "updated_at": f.updated_at.to_rfc3339(), + }) + }) + .chain(dashboard_files.iter().map(|f| { + json!({ + "id": f.id.to_string(), + "name": f.name, + "file_type": "dashboard", + "updated_at": f.updated_at.to_rfc3339(), + }) + })) + .collect(); + + // Format prompt and perform search + let prompt = Self::format_search_prompt(¶ms.query_params, &files_array)?; + let search_response = match Self::perform_llm_search(prompt).await { + Ok(response) => response, + Err(e) => { + return Ok(SearchFilesOutput { + message: format!("Search failed: {}", e), + files: vec![], + }); + } + }; + + let message = if search_response.results.is_empty() { + "No relevant files found".to_string() + } else { + format!("Found {} relevant files", search_response.results.len()) + }; + + info!( + query_count = params.query_params.len(), + result_count = search_response.results.len(), + "Completed file search operation" + ); + Ok(SearchFilesOutput { - success: true, + message, + files: search_response.results, }) } @@ -53,7 +213,79 @@ impl ToolExecutor for SearchFilesTool { }, "additionalProperties": false }, - "description": "Searches for metric and dashboard files using natural-language queries. Typically used if you suspect there might already be a relevant metric or dashboard in the repository. If results are found, you can then decide whether to open them with `open_files`." + "description": "Searches for metric and dashboard files using natural-language queries. Returns up to 10 most relevant files ordered by relevance." }) } } + +async fn get_metric_files() -> Result> { + debug!("Fetching metric files"); + let mut conn = get_pg_pool().get().await?; + + let files = metric_files::table + .filter(metric_files::deleted_at.is_null()) + .load::(&mut conn) + .await?; + + debug!(count = files.len(), "Successfully loaded metric files"); + Ok(files) +} + +async fn get_dashboard_files() -> Result> { + debug!("Fetching dashboard files"); + let mut conn = get_pg_pool().get().await?; + + let files = dashboard_files::table + .filter(dashboard_files::deleted_at.is_null()) + .load::(&mut conn) + .await?; + + debug!(count = files.len(), "Successfully loaded dashboard files"); + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::TimeZone; + + #[test] + fn test_parse_valid_search_result() { + let result = json!({ + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "Test File", + "file_type": "metric", + "updated_at": "2024-02-07T00:00:00Z" + }); + + let parsed = parse_search_result(&result).unwrap(); + assert_eq!(parsed.name, "Test File"); + assert_eq!(parsed.file_type, "metric"); + assert_eq!( + parsed.updated_at, + Utc.with_ymd_and_hms(2024, 2, 7, 0, 0, 0).unwrap() + ); + } + + #[test] + fn test_parse_invalid_search_result() { + let result = json!({ + "id": "invalid-uuid", + "name": "Test File", + "file_type": "metric", + "updated_at": "2024-02-07T00:00:00Z" + }); + + assert!(parse_search_result(&result).is_err()); + } + + #[test] + fn test_parse_missing_fields() { + let result = json!({ + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "Test File" + }); + + assert!(parse_search_result(&result).is_err()); + } +}