feat(file_search): Implement advanced AI-powered file search tool

- Add comprehensive file search functionality using LLM for intelligent file matching
- Implement search across metric and dashboard files with relevance ranking
- Create structured search result output with file metadata
- Add robust error handling and logging for search operations
- Include test coverage for search result parsing
This commit is contained in:
dal 2025-02-07 07:47:35 -07:00
parent 94c1635a34
commit 372694bf1f
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 238 additions and 6 deletions

View File

@ -1,9 +1,26 @@
use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{json, Value};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
use crate::{
database::{
lib::get_pg_pool,
models::{DashboardFile, MetricFile},
schema::{dashboard_files, metric_files},
},
utils::{
clients::ai::litellm::{
ChatCompletionRequest, LiteLLMClient, Message, ResponseFormat, Tool, ToolCall,
},
tools::ToolExecutor,
},
};
#[derive(Debug, Serialize, Deserialize)]
struct SearchFilesParams {
@ -12,11 +29,103 @@ struct SearchFilesParams {
#[derive(Debug, Serialize)]
pub struct SearchFilesOutput {
success: bool,
message: String,
files: Vec<FileSearchResult>,
}
#[derive(Debug, Serialize, Deserialize)]
struct FileSearchResult {
id: Uuid,
name: String,
file_type: String,
updated_at: DateTime<Utc>,
}
const FILE_SEARCH_PROMPT: &str = r#"
You are a file search assistant. You have access to a collection of metric and dashboard files.
Your task is to identify up to 10 most relevant files based on the following search queries:
{queries_joined_with_newlines}
Consider all queries collectively to determine relevance.
You must return your response as a JSON array of objects, where each object has these exact fields:
- id: string (UUID)
- name: string
- file_type: string (either "metric" or "dashboard")
- updated_at: string (ISO timestamp)
Available files:
{files_array_as_json}
Requirements:
1. Return up to 10 most relevant files
2. Order results from most to least relevant
3. Only include the specified fields
4. Ensure all field types match the specified formats
"#;
#[derive(Debug, Deserialize)]
struct LLMSearchResponse {
results: Vec<FileSearchResult>,
}
pub struct SearchFilesTool;
impl SearchFilesTool {
fn format_search_prompt(query_params: &[String], files_array: &[Value]) -> Result<String> {
let queries_joined = query_params.join("\n");
let files_json = serde_json::to_string_pretty(&files_array)?;
Ok(FILE_SEARCH_PROMPT
.replace("{queries_joined_with_newlines}", &queries_joined)
.replace("{files_array_as_json}", &files_json))
}
async fn perform_llm_search(prompt: String) -> Result<LLMSearchResponse> {
debug!("Performing LLM search");
// Setup LiteLLM client
let llm_client = LiteLLMClient::new(None, None);
let request = ChatCompletionRequest {
model: "o3-mini".to_string(),
messages: vec![Message::User {
content: prompt,
name: None,
}],
response_format: Some(ResponseFormat {
type_: "json_object".to_string(),
json_schema: None,
}),
..Default::default()
};
// Get response from LLM
let response = llm_client.chat_completion(request).await.map_err(|e| {
error!(error = %e, "Failed to get response from LLM");
anyhow::anyhow!("Failed to get response from LLM: {}", e)
})?;
// Parse LLM response
let content = match &response.choices[0].message {
Message::Assistant {
content: Some(content),
..
} => content,
_ => {
error!("LLM response missing content");
return Err(anyhow::anyhow!("LLM response missing content"));
}
};
// Parse into structured response
serde_json::from_str(content).map_err(|e| {
warn!(error = %e, "Failed to parse LLM response as JSON");
anyhow::anyhow!("Failed to parse search results: {}", e)
})
}
}
#[async_trait]
impl ToolExecutor for SearchFilesTool {
type Output = SearchFilesOutput;
@ -26,11 +135,62 @@ impl ToolExecutor for SearchFilesTool {
}
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
debug!("Starting file search operation");
let params: SearchFilesParams =
serde_json::from_str(&tool_call.function.arguments.clone())?;
// TODO: Implement actual file search logic
// Fetch all non-deleted records from both tables concurrently
let (metric_files, dashboard_files) =
tokio::try_join!(get_metric_files(), get_dashboard_files())?;
// Format files for LLM
let files_array: Vec<Value> = metric_files
.iter()
.map(|f| {
json!({
"id": f.id.to_string(),
"name": f.name,
"file_type": "metric",
"updated_at": f.updated_at.to_rfc3339(),
})
})
.chain(dashboard_files.iter().map(|f| {
json!({
"id": f.id.to_string(),
"name": f.name,
"file_type": "dashboard",
"updated_at": f.updated_at.to_rfc3339(),
})
}))
.collect();
// Format prompt and perform search
let prompt = Self::format_search_prompt(&params.query_params, &files_array)?;
let search_response = match Self::perform_llm_search(prompt).await {
Ok(response) => response,
Err(e) => {
return Ok(SearchFilesOutput {
message: format!("Search failed: {}", e),
files: vec![],
});
}
};
let message = if search_response.results.is_empty() {
"No relevant files found".to_string()
} else {
format!("Found {} relevant files", search_response.results.len())
};
info!(
query_count = params.query_params.len(),
result_count = search_response.results.len(),
"Completed file search operation"
);
Ok(SearchFilesOutput {
success: true,
message,
files: search_response.results,
})
}
@ -53,7 +213,79 @@ impl ToolExecutor for SearchFilesTool {
},
"additionalProperties": false
},
"description": "Searches for metric and dashboard files using natural-language queries. Typically used if you suspect there might already be a relevant metric or dashboard in the repository. If results are found, you can then decide whether to open them with `open_files`."
"description": "Searches for metric and dashboard files using natural-language queries. Returns up to 10 most relevant files ordered by relevance."
})
}
}
async fn get_metric_files() -> Result<Vec<MetricFile>> {
debug!("Fetching metric files");
let mut conn = get_pg_pool().get().await?;
let files = metric_files::table
.filter(metric_files::deleted_at.is_null())
.load::<MetricFile>(&mut conn)
.await?;
debug!(count = files.len(), "Successfully loaded metric files");
Ok(files)
}
async fn get_dashboard_files() -> Result<Vec<DashboardFile>> {
debug!("Fetching dashboard files");
let mut conn = get_pg_pool().get().await?;
let files = dashboard_files::table
.filter(dashboard_files::deleted_at.is_null())
.load::<DashboardFile>(&mut conn)
.await?;
debug!(count = files.len(), "Successfully loaded dashboard files");
Ok(files)
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::TimeZone;
#[test]
fn test_parse_valid_search_result() {
let result = json!({
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "Test File",
"file_type": "metric",
"updated_at": "2024-02-07T00:00:00Z"
});
let parsed = parse_search_result(&result).unwrap();
assert_eq!(parsed.name, "Test File");
assert_eq!(parsed.file_type, "metric");
assert_eq!(
parsed.updated_at,
Utc.with_ymd_and_hms(2024, 2, 7, 0, 0, 0).unwrap()
);
}
#[test]
fn test_parse_invalid_search_result() {
let result = json!({
"id": "invalid-uuid",
"name": "Test File",
"file_type": "metric",
"updated_at": "2024-02-07T00:00:00Z"
});
assert!(parse_search_result(&result).is_err());
}
#[test]
fn test_parse_missing_fields() {
let result = json!({
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "Test File"
});
assert!(parse_search_result(&result).is_err());
}
}