mirror of https://github.com/buster-so/buster.git
feat(data_catalog): Implement AI-powered dataset search tool
- Add comprehensive dataset search functionality using LLM for intelligent dataset matching - Implement search across datasets with relevance ranking based on YML content - Create structured search result output with dataset metadata - Add robust error handling, logging, and parsing for search operations - Include test coverage for search result validation - Enhance tool with flexible query parameter support and detailed response messages
This commit is contained in:
parent
372694bf1f
commit
adba6d6954
|
@ -1,35 +1,179 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{debug, error, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{clients::ai::litellm::ToolCall, tools::ToolExecutor};
|
||||
use crate::{
|
||||
database::{
|
||||
lib::get_pg_pool,
|
||||
schema::datasets,
|
||||
},
|
||||
utils::{
|
||||
clients::ai::litellm::{
|
||||
ChatCompletionRequest, LiteLLMClient, Message, ResponseFormat,
|
||||
Tool, ToolCall,
|
||||
},
|
||||
tools::ToolExecutor,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SearchDataCatalogParams {
|
||||
search_terms: Vec<String>,
|
||||
#[serde(default)]
|
||||
item_types: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct CatalogSearchResult {
|
||||
id: String,
|
||||
name: String,
|
||||
description: String,
|
||||
item_type: String,
|
||||
relevance_score: f32,
|
||||
metadata: Value,
|
||||
query_params: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchDataCatalogOutput {
|
||||
success: bool,
|
||||
results: Vec<CatalogSearchResult>,
|
||||
message: String,
|
||||
results: Vec<DatasetSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatasetSearchResult {
|
||||
id: Uuid,
|
||||
name: String,
|
||||
yml_content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RawLLMResponse {
|
||||
results: Vec<Value>,
|
||||
}
|
||||
|
||||
const CATALOG_SEARCH_PROMPT: &str = r#"
|
||||
You are a dataset search assistant. You have access to a collection of datasets with their YML content.
|
||||
Your task is to identify all relevant datasets based on the following search queries:
|
||||
|
||||
{queries_joined_with_newlines}
|
||||
|
||||
Consider all queries collectively to determine relevance. These queries describe different aspects of the problem or question that needs to be answered.
|
||||
The YML content contains important information about the dataset including its schema, description, and other metadata.
|
||||
Use this information to determine if the dataset would be relevant to answering the queries.
|
||||
|
||||
You must return your response as a JSON object with a 'results' array. Each result should have:
|
||||
- id: string (UUID)
|
||||
- name: string
|
||||
|
||||
Available datasets:
|
||||
{datasets_array_as_json}
|
||||
|
||||
Requirements:
|
||||
1. Return all relevant datasets (no limit)
|
||||
2. Order results from most to least relevant
|
||||
3. Only include id and name fields
|
||||
4. Ensure all field types match the specified formats
|
||||
5. If no datasets are relevant, return an empty results array
|
||||
"#;
|
||||
|
||||
pub struct SearchDataCatalogTool;
|
||||
|
||||
impl SearchDataCatalogTool {
|
||||
fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result<String> {
|
||||
let datasets_json = datasets
|
||||
.iter()
|
||||
.map(|d| d.to_llm_format())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(CATALOG_SEARCH_PROMPT
|
||||
.replace("{queries_joined_with_newlines}", &query_params.join("\n"))
|
||||
.replace("{datasets_array_as_json}", &serde_json::to_string_pretty(&datasets_json)?))
|
||||
}
|
||||
|
||||
async fn perform_llm_search(prompt: String) -> Result<Vec<DatasetSearchResult>> {
|
||||
debug!("Performing LLM search");
|
||||
|
||||
// Setup LiteLLM client
|
||||
let llm_client = LiteLLMClient::new(None, None);
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o3-mini".to_string(),
|
||||
messages: vec![Message::User {
|
||||
content: prompt,
|
||||
name: None,
|
||||
}],
|
||||
temperature: Some(0.0),
|
||||
response_format: Some(ResponseFormat {
|
||||
type_: "json_object".to_string(),
|
||||
json_schema: None,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Get response from LLM
|
||||
let response = llm_client.chat_completion(request).await.map_err(|e| {
|
||||
error!(error = %e, "Failed to get response from LLM");
|
||||
anyhow::anyhow!("Failed to get response from LLM: {}", e)
|
||||
})?;
|
||||
|
||||
// Parse LLM response
|
||||
let content = match &response.choices[0].message {
|
||||
Message::Assistant {
|
||||
content: Some(content),
|
||||
..
|
||||
} => content,
|
||||
_ => {
|
||||
error!("LLM response missing content");
|
||||
return Err(anyhow::anyhow!("LLM response missing content"));
|
||||
}
|
||||
};
|
||||
|
||||
// Parse into raw response first
|
||||
let raw_response: RawLLMResponse = serde_json::from_str(content).map_err(|e| {
|
||||
warn!(error = %e, "Failed to parse LLM response as JSON");
|
||||
anyhow::anyhow!("Failed to parse search results: {}", e)
|
||||
})?;
|
||||
|
||||
// Process each result, logging any invalid ones
|
||||
let mut valid_results = Vec::new();
|
||||
let mut invalid_count = 0;
|
||||
|
||||
for result in raw_response.results {
|
||||
match parse_search_result(&result) {
|
||||
Ok(result) => valid_results.push(result),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Invalid search result from LLM");
|
||||
invalid_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if invalid_count > 0 {
|
||||
warn!(count = invalid_count, "Found invalid search results");
|
||||
}
|
||||
|
||||
Ok(valid_results)
|
||||
}
|
||||
|
||||
async fn get_datasets() -> Result<Vec<DatasetRecord>> {
|
||||
debug!("Fetching datasets");
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
let datasets = datasets::table
|
||||
.select((
|
||||
datasets::id,
|
||||
datasets::name,
|
||||
datasets::yml_file,
|
||||
datasets::created_at,
|
||||
datasets::updated_at,
|
||||
datasets::deleted_at,
|
||||
))
|
||||
.filter(datasets::deleted_at.is_null())
|
||||
.load::<Dataset>(&mut conn)
|
||||
.await?;
|
||||
|
||||
debug!(count = datasets.len(), "Successfully loaded datasets");
|
||||
|
||||
// Convert to DatasetRecord format
|
||||
datasets.into_iter()
|
||||
.map(DatasetRecord::from_dataset)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for SearchDataCatalogTool {
|
||||
type Output = SearchDataCatalogOutput;
|
||||
|
@ -39,12 +183,43 @@ impl ToolExecutor for SearchDataCatalogTool {
|
|||
}
|
||||
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Self::Output> {
|
||||
let params: SearchDataCatalogParams =
|
||||
serde_json::from_str(&tool_call.function.arguments.clone())?;
|
||||
// TODO: Implement actual data catalog search logic
|
||||
debug!("Starting dataset search operation");
|
||||
let params: SearchDataCatalogParams = serde_json::from_str(&tool_call.function.arguments)?;
|
||||
|
||||
// Fetch all non-deleted datasets
|
||||
let datasets = Self::get_datasets().await?;
|
||||
if datasets.is_empty() {
|
||||
return Ok(SearchDataCatalogOutput {
|
||||
message: "No datasets available to search".to_string(),
|
||||
results: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Format prompt and perform search
|
||||
let prompt = Self::format_search_prompt(¶ms.query_params, &datasets)?;
|
||||
let search_results = match Self::perform_llm_search(prompt).await {
|
||||
Ok(results) => results,
|
||||
Err(e) => {
|
||||
return Ok(SearchDataCatalogOutput {
|
||||
message: format!("Search failed: {}", e),
|
||||
results: vec![],
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let message = if search_results.is_empty() {
|
||||
"No relevant datasets found".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"Found {} relevant datasets for {} queries",
|
||||
search_results.len(),
|
||||
params.query_params.len()
|
||||
)
|
||||
};
|
||||
|
||||
Ok(SearchDataCatalogOutput {
|
||||
success: true,
|
||||
results: vec![],
|
||||
message,
|
||||
results: search_results,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -54,29 +229,119 @@ impl ToolExecutor for SearchDataCatalogTool {
|
|||
"strict": true,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["search_terms"],
|
||||
"required": ["query_params"],
|
||||
"properties": {
|
||||
"search_terms": {
|
||||
"query_params": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"description": "A search term for finding relevant data catalog entries"
|
||||
"description": "A descriptive search query representing an aspect of the problem or question to be answered"
|
||||
},
|
||||
"description": "Array of strings representing the terms to search for in the data catalog"
|
||||
},
|
||||
"item_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": ["dataset", "metric", "business_term", "logic"],
|
||||
"description": "Type of catalog item to search for"
|
||||
},
|
||||
"description": "Optional filter to limit search to specific types of catalog items"
|
||||
"description": "Array of natural language queries that collectively describe the problem or question that needs to be answered"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
},
|
||||
"description": "Searches the data catalog for relevant items including datasets, metrics, business terms, and logic definitions. Returns structured results with relevance scores. Use this to find data assets and their documentation."
|
||||
"description": "Searches for datasets using multiple natural language queries that describe different aspects of the problem/question. Analyzes YML content for relevance and returns all relevant datasets ordered by relevance."
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper types and functions
|
||||
|
||||
#[derive(Queryable, Selectable)]
|
||||
#[diesel(table_name = datasets)]
|
||||
#[diesel(check_for_backend(diesel::pg::Pg))]
|
||||
struct Dataset {
|
||||
id: Uuid,
|
||||
name: String,
|
||||
yml_file: Option<String>,
|
||||
created_at: DateTime<Utc>,
|
||||
updated_at: DateTime<Utc>,
|
||||
deleted_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
struct DatasetRecord {
|
||||
id: Uuid,
|
||||
name: String,
|
||||
yml_content: String,
|
||||
}
|
||||
|
||||
impl DatasetRecord {
|
||||
fn to_llm_format(&self) -> Value {
|
||||
json!({
|
||||
"id": self.id.to_string(),
|
||||
"name": self.name,
|
||||
"content": self.yml_content,
|
||||
})
|
||||
}
|
||||
|
||||
fn from_dataset(dataset: Dataset) -> Result<Self> {
|
||||
Ok(Self {
|
||||
id: dataset.id,
|
||||
name: dataset.name,
|
||||
yml_content: dataset.yml_file.unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_search_result(result: &Value) -> Result<DatasetSearchResult> {
|
||||
Ok(DatasetSearchResult {
|
||||
id: Uuid::parse_str(
|
||||
result
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing id"))?,
|
||||
)?,
|
||||
name: result
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing name"))?
|
||||
.to_string(),
|
||||
yml_content: result
|
||||
.get("yml_content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing yml_content"))?
|
||||
.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
|
||||
#[test]
|
||||
fn test_parse_valid_search_result() {
|
||||
let result = json!({
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "Test Dataset",
|
||||
"yml_content": "description: Test dataset\nschema:\n - name: id\n type: uuid"
|
||||
});
|
||||
|
||||
let parsed = parse_search_result(&result).unwrap();
|
||||
assert_eq!(parsed.name, "Test Dataset");
|
||||
assert!(parsed.yml_content.contains("description: Test dataset"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_invalid_search_result() {
|
||||
let result = json!({
|
||||
"id": "invalid-uuid",
|
||||
"name": "Test Dataset",
|
||||
"yml_content": "test content"
|
||||
});
|
||||
|
||||
assert!(parse_search_result(&result).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_missing_fields() {
|
||||
let result = json!({
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "Test Dataset"
|
||||
});
|
||||
|
||||
assert!(parse_search_result(&result).is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -93,6 +93,7 @@ impl SearchFilesTool {
|
|||
content: prompt,
|
||||
name: None,
|
||||
}],
|
||||
temperature: Some(0.0),
|
||||
response_format: Some(ResponseFormat {
|
||||
type_: "json_object".to_string(),
|
||||
json_schema: None,
|
||||
|
@ -249,6 +250,34 @@ mod tests {
|
|||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
|
||||
fn parse_search_result(result: &Value) -> Result<FileSearchResult> {
|
||||
Ok(FileSearchResult {
|
||||
id: Uuid::parse_str(
|
||||
result
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing id"))?,
|
||||
)?,
|
||||
name: result
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing name"))?
|
||||
.to_string(),
|
||||
file_type: result
|
||||
.get("file_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing file_type"))?
|
||||
.to_string(),
|
||||
updated_at: DateTime::parse_from_rfc3339(
|
||||
result
|
||||
.get("updated_at")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing updated_at"))?,
|
||||
)?
|
||||
.with_timezone(&Utc),
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_valid_search_result() {
|
||||
let result = json!({
|
||||
|
|
Loading…
Reference in New Issue