mirror of https://github.com/buster-so/buster.git
feat(file_search): Implement advanced AI-powered file search tool
- Add comprehensive file search functionality using LLM for intelligent file matching - Implement search across metric and dashboard files with relevance ranking - Create structured search result output with file metadata - Add robust error handling and logging for search operations - Include test coverage for search result parsing
This commit is contained in:
parent
94c1635a34
commit
372694bf1f
|
@ -1,9 +1,26 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
|
||||
use crate::{
|
||||
database::{
|
||||
lib::get_pg_pool,
|
||||
models::{DashboardFile, MetricFile},
|
||||
schema::{dashboard_files, metric_files},
|
||||
},
|
||||
utils::{
|
||||
clients::ai::litellm::{
|
||||
ChatCompletionRequest, LiteLLMClient, Message, ResponseFormat, Tool, ToolCall,
|
||||
},
|
||||
tools::ToolExecutor,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SearchFilesParams {
|
||||
|
@ -12,11 +29,103 @@ struct SearchFilesParams {
|
|||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchFilesOutput {
|
||||
success: bool,
|
||||
message: String,
|
||||
files: Vec<FileSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FileSearchResult {
|
||||
id: Uuid,
|
||||
name: String,
|
||||
file_type: String,
|
||||
updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
const FILE_SEARCH_PROMPT: &str = r#"
|
||||
You are a file search assistant. You have access to a collection of metric and dashboard files.
|
||||
Your task is to identify up to 10 most relevant files based on the following search queries:
|
||||
|
||||
{queries_joined_with_newlines}
|
||||
|
||||
Consider all queries collectively to determine relevance.
|
||||
|
||||
You must return your response as a JSON array of objects, where each object has these exact fields:
|
||||
- id: string (UUID)
|
||||
- name: string
|
||||
- file_type: string (either "metric" or "dashboard")
|
||||
- updated_at: string (ISO timestamp)
|
||||
|
||||
Available files:
|
||||
{files_array_as_json}
|
||||
|
||||
Requirements:
|
||||
1. Return up to 10 most relevant files
|
||||
2. Order results from most to least relevant
|
||||
3. Only include the specified fields
|
||||
4. Ensure all field types match the specified formats
|
||||
"#;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LLMSearchResponse {
|
||||
results: Vec<FileSearchResult>,
|
||||
}
|
||||
|
||||
pub struct SearchFilesTool;
|
||||
|
||||
impl SearchFilesTool {
|
||||
fn format_search_prompt(query_params: &[String], files_array: &[Value]) -> Result<String> {
|
||||
let queries_joined = query_params.join("\n");
|
||||
let files_json = serde_json::to_string_pretty(&files_array)?;
|
||||
|
||||
Ok(FILE_SEARCH_PROMPT
|
||||
.replace("{queries_joined_with_newlines}", &queries_joined)
|
||||
.replace("{files_array_as_json}", &files_json))
|
||||
}
|
||||
|
||||
async fn perform_llm_search(prompt: String) -> Result<LLMSearchResponse> {
|
||||
debug!("Performing LLM search");
|
||||
|
||||
// Setup LiteLLM client
|
||||
let llm_client = LiteLLMClient::new(None, None);
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o3-mini".to_string(),
|
||||
messages: vec![Message::User {
|
||||
content: prompt,
|
||||
name: None,
|
||||
}],
|
||||
response_format: Some(ResponseFormat {
|
||||
type_: "json_object".to_string(),
|
||||
json_schema: None,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Get response from LLM
|
||||
let response = llm_client.chat_completion(request).await.map_err(|e| {
|
||||
error!(error = %e, "Failed to get response from LLM");
|
||||
anyhow::anyhow!("Failed to get response from LLM: {}", e)
|
||||
})?;
|
||||
|
||||
// Parse LLM response
|
||||
let content = match &response.choices[0].message {
|
||||
Message::Assistant {
|
||||
content: Some(content),
|
||||
..
|
||||
} => content,
|
||||
_ => {
|
||||
error!("LLM response missing content");
|
||||
return Err(anyhow::anyhow!("LLM response missing content"));
|
||||
}
|
||||
};
|
||||
|
||||
// Parse into structured response
|
||||
serde_json::from_str(content).map_err(|e| {
|
||||
warn!(error = %e, "Failed to parse LLM response as JSON");
|
||||
anyhow::anyhow!("Failed to parse search results: {}", e)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for SearchFilesTool {
|
||||
type Output = SearchFilesOutput;
|
||||
|
@ -26,11 +135,62 @@ impl ToolExecutor for SearchFilesTool {
|
|||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
debug!("Starting file search operation");
|
||||
let params: SearchFilesParams =
|
||||
serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual file search logic
|
||||
|
||||
// Fetch all non-deleted records from both tables concurrently
|
||||
let (metric_files, dashboard_files) =
|
||||
tokio::try_join!(get_metric_files(), get_dashboard_files())?;
|
||||
|
||||
// Format files for LLM
|
||||
let files_array: Vec<Value> = metric_files
|
||||
.iter()
|
||||
.map(|f| {
|
||||
json!({
|
||||
"id": f.id.to_string(),
|
||||
"name": f.name,
|
||||
"file_type": "metric",
|
||||
"updated_at": f.updated_at.to_rfc3339(),
|
||||
})
|
||||
})
|
||||
.chain(dashboard_files.iter().map(|f| {
|
||||
json!({
|
||||
"id": f.id.to_string(),
|
||||
"name": f.name,
|
||||
"file_type": "dashboard",
|
||||
"updated_at": f.updated_at.to_rfc3339(),
|
||||
})
|
||||
}))
|
||||
.collect();
|
||||
|
||||
// Format prompt and perform search
|
||||
let prompt = Self::format_search_prompt(¶ms.query_params, &files_array)?;
|
||||
let search_response = match Self::perform_llm_search(prompt).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
return Ok(SearchFilesOutput {
|
||||
message: format!("Search failed: {}", e),
|
||||
files: vec![],
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let message = if search_response.results.is_empty() {
|
||||
"No relevant files found".to_string()
|
||||
} else {
|
||||
format!("Found {} relevant files", search_response.results.len())
|
||||
};
|
||||
|
||||
info!(
|
||||
query_count = params.query_params.len(),
|
||||
result_count = search_response.results.len(),
|
||||
"Completed file search operation"
|
||||
);
|
||||
|
||||
Ok(SearchFilesOutput {
|
||||
success: true,
|
||||
message,
|
||||
files: search_response.results,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -53,7 +213,79 @@ impl ToolExecutor for SearchFilesTool {
|
|||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"description": "Searches for metric and dashboard files using natural-language queries. Typically used if you suspect there might already be a relevant metric or dashboard in the repository. If results are found, you can then decide whether to open them with `open_files`."
|
||||
"description": "Searches for metric and dashboard files using natural-language queries. Returns up to 10 most relevant files ordered by relevance."
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_metric_files() -> Result<Vec<MetricFile>> {
|
||||
debug!("Fetching metric files");
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
let files = metric_files::table
|
||||
.filter(metric_files::deleted_at.is_null())
|
||||
.load::<MetricFile>(&mut conn)
|
||||
.await?;
|
||||
|
||||
debug!(count = files.len(), "Successfully loaded metric files");
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
async fn get_dashboard_files() -> Result<Vec<DashboardFile>> {
|
||||
debug!("Fetching dashboard files");
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
let files = dashboard_files::table
|
||||
.filter(dashboard_files::deleted_at.is_null())
|
||||
.load::<DashboardFile>(&mut conn)
|
||||
.await?;
|
||||
|
||||
debug!(count = files.len(), "Successfully loaded dashboard files");
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
|
||||
#[test]
|
||||
fn test_parse_valid_search_result() {
|
||||
let result = json!({
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "Test File",
|
||||
"file_type": "metric",
|
||||
"updated_at": "2024-02-07T00:00:00Z"
|
||||
});
|
||||
|
||||
let parsed = parse_search_result(&result).unwrap();
|
||||
assert_eq!(parsed.name, "Test File");
|
||||
assert_eq!(parsed.file_type, "metric");
|
||||
assert_eq!(
|
||||
parsed.updated_at,
|
||||
Utc.with_ymd_and_hms(2024, 2, 7, 0, 0, 0).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_invalid_search_result() {
|
||||
let result = json!({
|
||||
"id": "invalid-uuid",
|
||||
"name": "Test File",
|
||||
"file_type": "metric",
|
||||
"updated_at": "2024-02-07T00:00:00Z"
|
||||
});
|
||||
|
||||
assert!(parse_search_result(&result).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_missing_fields() {
|
||||
let result = json!({
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "Test File"
|
||||
});
|
||||
|
||||
assert!(parse_search_result(&result).is_err());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue