From 400dae6e582cad9c1c7a176bb251f4a0bbe12a78 Mon Sep 17 00:00:00 2001
From: dal
Date: Mon, 31 Mar 2025 08:25:51 -0600
Subject: [PATCH] search data catalog endpoint

---
 api/src/routes/rest/routes/helpers/mod.rs     |   8 +
 .../routes/helpers/search_data_catalog.rs     | 392 ++++++++++++++++++
 api/src/routes/rest/routes/mod.rs             |   2 +
 3 files changed, 402 insertions(+)
 create mode 100644 api/src/routes/rest/routes/helpers/mod.rs
 create mode 100644 api/src/routes/rest/routes/helpers/search_data_catalog.rs

diff --git a/api/src/routes/rest/routes/helpers/mod.rs b/api/src/routes/rest/routes/helpers/mod.rs
new file mode 100644
index 000000000..1ec9eb2b0
--- /dev/null
+++ b/api/src/routes/rest/routes/helpers/mod.rs
@@ -0,0 +1,8 @@
+mod search_data_catalog;
+
+use axum::Router;
+
+pub fn router() -> Router {
+    Router::new()
+        .nest("/search_data_catalog", search_data_catalog::router())
+}
\ No newline at end of file
diff --git a/api/src/routes/rest/routes/helpers/search_data_catalog.rs b/api/src/routes/rest/routes/helpers/search_data_catalog.rs
new file mode 100644
index 000000000..255bc3b4a
--- /dev/null
+++ b/api/src/routes/rest/routes/helpers/search_data_catalog.rs
@@ -0,0 +1,392 @@
+use axum::{routing::post, Extension, Json, Router};
+use cohere_rust::{
+    api::rerank::{ReRankModel, ReRankRequest},
+    Cohere,
+};
+use database::{pool::get_pg_pool, schema::datasets};
+use diesel::prelude::*;
+use diesel_async::RunQueryDsl;
+use futures::{
+    future::join_all,
+    stream::{self, StreamExt},
+};
+use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat};
+use middleware::types::AuthenticatedUser;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::collections::{HashMap, HashSet};
+use tracing::{debug, error, info, warn};
+use uuid::Uuid;
+
+use crate::routes::rest::ApiResponse;
+
+#[derive(Debug, Deserialize)]
+pub struct SearchDataCatalogRequest {
+    queries: Vec<String>,
+    user_request: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct
SearchDataCatalogResponse {
+    results: Vec<DatasetResult>,
+}
+
+#[derive(Debug, Serialize, Clone, PartialEq, Eq, Hash)]
+pub struct DatasetResult {
+    id: Uuid,
+    name: Option<String>,
+    yml_content: Option<String>,
+}
+
+// Model representing a dataset from the database
+#[derive(Debug, Queryable, Selectable, Clone)]
+#[diesel(table_name = datasets)]
+#[diesel(check_for_backend(diesel::pg::Pg))]
+struct Dataset {
+    id: Uuid,
+    name: String,
+    #[diesel(column_name = "yml_file")]
+    yml_content: Option<String>,
+    // Other fields omitted for query
+}
+
+#[derive(Debug, Clone)]
+struct RankedDataset {
+    dataset: Dataset,
+    relevance_score: f64,
+}
+
+#[derive(Debug, Deserialize)]
+struct LLMFilterResponse {
+    results: Vec<FilteredDataset>,
+}
+
+#[derive(Debug, Deserialize)]
+struct FilteredDataset {
+    id: String,
+    reason: String,
+}
+
+const LLM_FILTER_PROMPT: &str = r#"
+You are a dataset relevance evaluator. Your task is to determine which datasets might contain information relevant to the user's query based on their structure and metadata.
+
+USER REQUEST: {user_request}
+SEARCH QUERY: {query}
+
+Below is a list of datasets that were identified as potentially relevant by an initial semantic ranking system.
+For each dataset, review its description in the YAML format and determine if its structure is suitable for the user's query.
+ONLY include datasets that you determine are relevant in your response.
+
+DATASETS:
+{datasets_json}
+
+Return a JSON response with the following structure:
+```json
+{
+  "results": [
+    {
+      "id": "dataset-uuid-here",
+      "reason": "Brief explanation of why this dataset's structure is relevant"
+    },
+    // ... more relevant datasets only
+  ]
+}
+```
+
+IMPORTANT GUIDELINES:
+1. DO NOT make assumptions about what specific values exist in the datasets
+2. Focus EXCLUSIVELY on identifying datasets with STRUCTURES that could reasonably contain the type of information requested
+3. For example, if a user asks about "red bull sales", consider datasets about products, sales, inventory, etc.
as potentially relevant - even if "red bull" is not explicitly mentioned
+4. Evaluate based on whether the dataset's schema, fields, or description indicates it COULD contain the relevant information
+5. Look for structural compatibility rather than exact matches in the content
+6. ONLY include datasets you find relevant in your response - omit any that aren't relevant
+7. Ensure the "id" field exactly matches the dataset's UUID
+8. Use both the USER REQUEST and SEARCH QUERY to understand the user's information needs - the USER REQUEST provides broader context while the SEARCH QUERY represents specific search intent
+9. Restrict your evaluation strictly to the defined elements in the dataset metadata:
+   - Column names and their data types
+   - Entity relationships
+   - Predefined metrics
+   - Table schemas
+   - Dimension hierarchies
+10. Do NOT make assumptions about what data might exist beyond what is explicitly defined in the metadata
+11. A dataset is relevant ONLY if its documented structure supports answering the query, not because you assume it might contain certain data
+"#;
+
+pub fn router() -> Router {
+    Router::new().route("/", post(handle_search_data_catalog))
+}
+
+async fn handle_search_data_catalog(
+    Extension(user): Extension<AuthenticatedUser>,
+    Json(request): Json<SearchDataCatalogRequest>,
+) -> ApiResponse<SearchDataCatalogResponse> {
+    // Basic validation
+    if request.queries.is_empty() {
+        return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
+    }
+
+    // Get the user's organization ID (using the first organization)
+    let org_id = match user.organizations.get(0) {
+        Some(org) => org.id,
+        None => {
+            error!("User has no organizations");
+            return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
+        }
+    };
+
+    // Retrieve datasets for the organization
+    let datasets = match get_datasets_for_organization(org_id).await {
+        Ok(datasets) => datasets,
+        Err(e) => {
+            error!("Failed to retrieve datasets: {}", e);
+            return ApiResponse::JsonData(SearchDataCatalogResponse { results:
vec![] });
+        }
+    };
+
+    if datasets.is_empty() {
+        return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
+    }
+
+    // Extract YML content for reranking
+    let documents: Vec<String> = datasets
+        .iter()
+        .filter_map(|dataset| dataset.yml_content.clone())
+        .collect();
+
+    if documents.is_empty() {
+        warn!("No datasets with YML content found");
+        return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
+    }
+
+    // Store user_request for passing to process_query
+    let user_request = request.user_request.clone();
+
+    // Process all queries concurrently using Cohere reranking
+    let ranked_datasets_futures = stream::iter(request.queries)
+        .map(|query| process_query(query, datasets.clone(), documents.clone(), &user, user_request.clone()))
+        .buffer_unordered(5) // Process up to 5 queries concurrently
+        .collect::<Vec<_>>()
+        .await;
+
+    // Combine and deduplicate results
+    let mut unique_datasets = HashSet::new();
+    let results = ranked_datasets_futures
+        .into_iter()
+        .flat_map(|result| match result {
+            Ok(datasets) => datasets,
+            Err(e) => {
+                error!("Failed to process query: {}", e);
+                vec![]
+            }
+        })
+        .filter(|result| unique_datasets.insert(result.clone()))
+        .collect();
+
+    ApiResponse::JsonData(SearchDataCatalogResponse { results })
+}
+
+async fn get_datasets_for_organization(org_id: Uuid) -> Result<Vec<Dataset>, anyhow::Error> {
+    use database::schema::datasets::dsl::*;
+
+    let mut conn = get_pg_pool().get().await?;
+
+    let results = datasets
+        .filter(organization_id.eq(org_id))
+        .filter(deleted_at.is_null())
+        .filter(yml_file.is_not_null())
+        .select((id, name, yml_file))
+        .load::<Dataset>(&mut conn)
+        .await?;
+
+    Ok(results)
+}
+
+async fn process_query(
+    query: String,
+    all_datasets: Vec<Dataset>,
+    documents: Vec<String>,
+    user: &AuthenticatedUser,
+    user_request: Option<String>,
+) -> Result<Vec<DatasetResult>, anyhow::Error> {
+    // Step 1: Rerank datasets using Cohere
+    let ranked_datasets = rerank_datasets(&query, &all_datasets, &documents).await?;
+
+    if ranked_datasets.is_empty() {
+
        info!(
+            "No datasets were relevant after reranking for query: '{}'",
+            query
+        );
+        return Ok(vec![]);
+    }
+
+    // Step 2: Filter with LLM for true relevance
+    let filtered_datasets = filter_datasets_with_llm(&query, ranked_datasets, user, user_request).await?;
+
+    Ok(filtered_datasets)
+}
+
+async fn rerank_datasets(
+    query: &str,
+    all_datasets: &[Dataset],
+    documents: &[String],
+) -> Result<Vec<RankedDataset>, anyhow::Error> {
+    // Initialize Cohere client
+    let co = Cohere::default();
+
+    // Create rerank request
+    let request = ReRankRequest {
+        query,
+        documents,
+        model: ReRankModel::EnglishV3,
+        top_n: Some(20), // Get top 20 results per query
+        ..Default::default()
+    };
+
+    // Get reranked results
+    let rerank_results = co.rerank(&request).await?;
+
+    // Map results back to datasets
+    let mut ranked_datasets = Vec::new();
+    for result in rerank_results {
+        if let Some(dataset) = all_datasets.get(result.index as usize) {
+            ranked_datasets.push(RankedDataset {
+                dataset: dataset.clone(),
+                relevance_score: result.relevance_score,
+            });
+        } else {
+            error!("Invalid dataset index from Cohere: {}", result.index);
+        }
+    }
+
+    // Sort by relevance score (highest first)
+    ranked_datasets.sort_by(|a, b| {
+        b.relevance_score
+            .partial_cmp(&a.relevance_score)
+            .unwrap_or(std::cmp::Ordering::Equal)
+    });
+
+    // Only keep results with meaningful relevance scores
+    // This threshold is arbitrary and may need tuning
+    let relevant_datasets = ranked_datasets.into_iter().collect::<Vec<_>>();
+
+    Ok(relevant_datasets)
+}
+
+async fn filter_datasets_with_llm(
+    query: &str,
+    ranked_datasets: Vec<RankedDataset>,
+    user: &AuthenticatedUser,
+    user_request: Option<String>,
+) -> Result<Vec<DatasetResult>, anyhow::Error> {
+    debug!(
+        "Filtering {} datasets with LLM for query: {}",
+        ranked_datasets.len(),
+        query
+    );
+
+    // Format datasets for LLM prompt
+    let datasets_json = ranked_datasets
+        .iter()
+        .map(|ranked| {
+            serde_json::json!({
+                "id": ranked.dataset.id.to_string(),
+                "name": ranked.dataset.name,
+                "yml_content":
ranked.dataset.yml_content.clone().unwrap_or_default(),
+                "relevance_score": ranked.relevance_score
+            })
+        })
+        .collect::<Vec<Value>>();
+
+    // Format the prompt
+    let user_request_text = user_request.unwrap_or_else(|| query.to_string());
+    let prompt = LLM_FILTER_PROMPT
+        .replace("{user_request}", &user_request_text)
+        .replace("{query}", query)
+        .replace(
+            "{datasets_json}",
+            &serde_json::to_string_pretty(&datasets_json)?,
+        );
+
+    // Initialize LiteLLM client
+    let llm_client = LiteLLMClient::new(None, None);
+
+    // Create the request
+    let request = ChatCompletionRequest {
+        model: "gemini-2.0-flash-001".to_string(), // Using a small model for cost efficiency
+        messages: vec![AgentMessage::User {
+            id: None,
+            content: prompt,
+            name: None,
+        }],
+        stream: Some(false),
+        response_format: Some(ResponseFormat {
+            type_: "json_object".to_string(),
+            json_schema: None,
+        }),
+        metadata: Some(Metadata {
+            generation_name: "filter_data_catalog".to_string(),
+            user_id: user.id.to_string(),
+            session_id: Uuid::new_v4().to_string(),
+            trace_id: Uuid::new_v4().to_string(),
+        }),
+        // reasoning_effort: Some(String::from("low")),
+        max_completion_tokens: Some(8096),
+        ..Default::default()
+    };
+
+    // Get response from LLM
+    let response = llm_client.chat_completion(request).await?;
+
+    // Parse LLM response
+    let content = match &response.choices[0].message {
+        AgentMessage::Assistant {
+            content: Some(content),
+            ..
+        } => content,
+        _ => {
+            error!("LLM response missing content");
+            return Err(anyhow::anyhow!("LLM response missing content"));
+        }
+    };
+
+    // Parse into typed response
+    let filter_response: LLMFilterResponse = match serde_json::from_str(content) {
+        Ok(response) => response,
+        Err(e) => {
+            error!("Failed to parse LLM response: {}", e);
+            return Err(anyhow::anyhow!("Failed to parse LLM response: {}", e));
+        }
+    };
+
+    // Create a map for quick lookups of dataset IDs
+    let dataset_map: HashMap<Uuid, &Dataset> = ranked_datasets
+        .iter()
+        .map(|ranked| (ranked.dataset.id, &ranked.dataset))
+        .collect();
+
+    // Convert filtered relevant datasets to DatasetResult
+    let filtered_datasets: Vec<DatasetResult> = filter_response
+        .results
+        .into_iter()
+        .filter_map(|result| {
+            // Parse the UUID
+            match Uuid::parse_str(&result.id) {
+                Ok(id) => {
+                    // Get the dataset
+                    dataset_map.get(&id).map(|dataset| DatasetResult {
+                        id: dataset.id,
+                        name: Some(dataset.name.clone()),
+                        yml_content: dataset.yml_content.clone(),
+                    })
+                }
+                Err(_) => None,
+            }
+        })
+        .collect();
+
+    debug!(
+        "LLM filtering complete, keeping {} relevant datasets",
+        filtered_datasets.len()
+    );
+    Ok(filtered_datasets)
+}
diff --git a/api/src/routes/rest/routes/mod.rs b/api/src/routes/rest/routes/mod.rs
index e7417e1ca..ced93cb5c 100644
--- a/api/src/routes/rest/routes/mod.rs
+++ b/api/src/routes/rest/routes/mod.rs
@@ -5,6 +5,7 @@ mod dashboards;
 mod data_sources;
 mod dataset_groups;
 mod datasets;
+mod helpers;
 mod logs;
 mod messages;
 mod metrics;
@@ -37,6 +38,7 @@ pub fn router() -> Router {
             .nest("/collections", collections::router())
             .nest("/logs", logs::router())
             .nest("/search", search::router())
+            .nest("/helpers", helpers::router())
             .route_layer(axum_middleware::from_fn(auth)),
     )
 }