search data catalog endpoint

commit 400dae6e58
parent 87cfb8d45d
Author: dal
Date:   2025-03-31 08:25:51 -06:00
3 changed files with 402 additions and 0 deletions


@@ -0,0 +1,8 @@
mod search_data_catalog;
use axum::Router;
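// This router is nested under "/helpers" by the REST router (see the routes
// diff below), so the endpoint resolves to POST /helpers/search_data_catalog/.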
pub fn router() -> Router {
Router::new()
.nest("/search_data_catalog", search_data_catalog::router())
}


@@ -0,0 +1,392 @@
use axum::{routing::post, Extension, Json, Router};
use cohere_rust::{
api::rerank::{ReRankModel, ReRankRequest},
Cohere,
};
use database::{pool::get_pg_pool, schema::datasets};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use futures::stream::{self, StreamExt};
use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat};
use middleware::types::AuthenticatedUser;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::routes::rest::ApiResponse;
#[derive(Debug, Deserialize)]
pub struct SearchDataCatalogRequest {
queries: Vec<String>,
user_request: Option<String>,
}
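// Illustrative request body (shape implied by the fields above):
// {
//   "queries": ["red bull sales by region"],
//   "user_request": "How are Red Bull sales trending this quarter?"
// }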
#[derive(Debug, Serialize)]
pub struct SearchDataCatalogResponse {
results: Vec<DatasetResult>,
}
#[derive(Debug, Serialize, Clone, PartialEq, Eq, Hash)]
pub struct DatasetResult {
id: Uuid,
name: Option<String>,
yml_content: Option<String>,
}
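// Illustrative response body (shape implied by the two structs above):
// {
//   "results": [
//     { "id": "<dataset-uuid>", "name": "<dataset name>", "yml_content": "<yaml>" }
//   ]
// }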
// Model representing a dataset from the database
#[derive(Debug, Queryable, Selectable, Clone)]
#[diesel(table_name = datasets)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct Dataset {
id: Uuid,
name: String,
#[diesel(column_name = "yml_file")]
yml_content: Option<String>,
// Other fields omitted for query
}
#[derive(Debug, Clone)]
struct RankedDataset {
dataset: Dataset,
relevance_score: f64,
}
#[derive(Debug, Deserialize)]
struct LLMFilterResponse {
results: Vec<FilteredDataset>,
}
#[derive(Debug, Deserialize)]
struct FilteredDataset {
id: String,
reason: String,
}
const LLM_FILTER_PROMPT: &str = r#"
You are a dataset relevance evaluator. Your task is to determine which datasets might contain information relevant to the user's query based on their structure and metadata.
USER REQUEST: {user_request}
SEARCH QUERY: {query}
Below is a list of datasets that were identified as potentially relevant by an initial semantic ranking system.
For each dataset, review its YAML description and determine whether its structure could support the user's query.
ONLY include datasets that you determine are relevant in your response.
DATASETS:
{datasets_json}
Return a JSON response with the following structure:
```json
{
"results": [
{
"id": "dataset-uuid-here",
"reason": "Brief explanation of why this dataset's structure is relevant"
},
// ... more relevant datasets only
]
}
```
IMPORTANT GUIDELINES:
1. DO NOT make assumptions about what specific values exist in the datasets
2. Focus EXCLUSIVELY on identifying datasets with STRUCTURES that could reasonably contain the type of information requested
3. For example, if a user asks about "red bull sales", consider datasets about products, sales, inventory, etc. as potentially relevant - even if "red bull" is not explicitly mentioned
4. Evaluate based on whether the dataset's schema, fields, or description indicates it COULD contain the relevant information
5. Look for structural compatibility rather than exact matches in the content
6. ONLY include datasets you find relevant in your response - omit any that aren't relevant
7. Ensure the "id" field exactly matches the dataset's UUID
8. Use both the USER REQUEST and SEARCH QUERY to understand the user's information needs - the USER REQUEST provides broader context while the SEARCH QUERY represents specific search intent
9. Restrict your evaluation strictly to the defined elements in the dataset metadata:
- Column names and their data types
- Entity relationships
- Predefined metrics
- Table schemas
- Dimension hierarchies
10. Do NOT make assumptions about what data might exist beyond what is explicitly defined in the metadata
11. A dataset is relevant ONLY if its documented structure supports answering the query, not because you assume it might contain certain data
"#;
pub fn router() -> Router {
Router::new().route("/", post(handle_search_data_catalog))
}
async fn handle_search_data_catalog(
Extension(user): Extension<AuthenticatedUser>,
Json(request): Json<SearchDataCatalogRequest>,
) -> ApiResponse<SearchDataCatalogResponse> {
// Basic validation
if request.queries.is_empty() {
return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
}
// Get the user's organization ID (using the first organization)
let org_id = match user.organizations.first() {
Some(org) => org.id,
None => {
error!("User has no organizations");
return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
}
};
// Retrieve datasets for the organization
let datasets = match get_datasets_for_organization(org_id).await {
Ok(datasets) => datasets,
Err(e) => {
error!("Failed to retrieve datasets: {}", e);
return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
}
};
if datasets.is_empty() {
return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
}
// Extract YML content for reranking
let documents: Vec<String> = datasets
.iter()
.filter_map(|dataset| dataset.yml_content.clone())
.collect();
if documents.is_empty() {
warn!("No datasets with YML content found");
return ApiResponse::JsonData(SearchDataCatalogResponse { results: vec![] });
}
// Store user_request for passing to process_query
let user_request = request.user_request.clone();
// Process all queries concurrently using Cohere reranking
let ranked_datasets_futures = stream::iter(request.queries)
.map(|query| process_query(query, datasets.clone(), documents.clone(), &user, user_request.clone()))
.buffer_unordered(5) // Process up to 5 queries concurrently
.collect::<Vec<_>>()
.await;
// Combine and deduplicate results
let mut unique_datasets = HashSet::new();
let results = ranked_datasets_futures
.into_iter()
.flat_map(|result| match result {
Ok(datasets) => datasets,
Err(e) => {
error!("Failed to process query: {}", e);
vec![]
}
})
.filter(|result| unique_datasets.insert(result.clone()))
.collect();
ApiResponse::JsonData(SearchDataCatalogResponse { results })
}
async fn get_datasets_for_organization(org_id: Uuid) -> Result<Vec<Dataset>, anyhow::Error> {
use database::schema::datasets::dsl::*;
let mut conn = get_pg_pool().get().await?;
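// Roughly the SQL this query builds (illustrative):
//   SELECT id, name, yml_file FROM datasets
//   WHERE organization_id = $1 AND deleted_at IS NULL AND yml_file IS NOT NULL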
let results = datasets
.filter(organization_id.eq(org_id))
.filter(deleted_at.is_null())
.filter(yml_file.is_not_null())
.select((id, name, yml_file))
.load::<Dataset>(&mut conn)
.await?;
Ok(results)
}
async fn process_query(
query: String,
all_datasets: Vec<Dataset>,
documents: Vec<String>,
user: &AuthenticatedUser,
user_request: Option<String>,
) -> Result<Vec<DatasetResult>, anyhow::Error> {
// Step 1: Rerank datasets using Cohere
let ranked_datasets = rerank_datasets(&query, &all_datasets, &documents).await?;
if ranked_datasets.is_empty() {
info!(
"No datasets were relevant after reranking for query: '{}'",
query
);
return Ok(vec![]);
}
// Step 2: Filter with LLM for true relevance
let filtered_datasets = filter_datasets_with_llm(&query, ranked_datasets, user, user_request).await?;
Ok(filtered_datasets)
}
async fn rerank_datasets(
query: &str,
all_datasets: &[Dataset],
documents: &[String],
) -> Result<Vec<RankedDataset>, anyhow::Error> {
// Initialize Cohere client
let co = Cohere::default();
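// Cohere::default() is assumed to read the API key from the environment,
// per cohere_rust's default client configuration.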
// Create rerank request
let request = ReRankRequest {
query,
documents,
model: ReRankModel::EnglishV3,
top_n: Some(20), // Get top 20 results per query
..Default::default()
};
// Get reranked results
let rerank_results = co.rerank(&request).await?;
// Map results back to datasets
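// NOTE: `result.index` points into `documents`, which is index-aligned with
// `all_datasets` because the DB query filters out rows with a NULL yml_file,
// so the handler's filter_map over yml_content drops nothing.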
let mut ranked_datasets = Vec::new();
for result in rerank_results {
if let Some(dataset) = all_datasets.get(result.index as usize) {
ranked_datasets.push(RankedDataset {
dataset: dataset.clone(),
relevance_score: result.relevance_score,
});
} else {
error!("Invalid dataset index from Cohere: {}", result.index);
}
}
// Sort by relevance score (highest first)
ranked_datasets.sort_by(|a, b| {
b.relevance_score
.partial_cmp(&a.relevance_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
// No relevance-score threshold is applied here; all reranked results
// (up to `top_n`) are passed through to the LLM filtering step
let relevant_datasets = ranked_datasets;
Ok(relevant_datasets)
}
async fn filter_datasets_with_llm(
query: &str,
ranked_datasets: Vec<RankedDataset>,
user: &AuthenticatedUser,
user_request: Option<String>,
) -> Result<Vec<DatasetResult>, anyhow::Error> {
debug!(
"Filtering {} datasets with LLM for query: {}",
ranked_datasets.len(),
query
);
// Format datasets for LLM prompt
let datasets_json = ranked_datasets
.iter()
.map(|ranked| {
serde_json::json!({
"id": ranked.dataset.id.to_string(),
"name": ranked.dataset.name,
"yml_content": ranked.dataset.yml_content.clone().unwrap_or_default(),
"relevance_score": ranked.relevance_score
})
})
.collect::<Vec<_>>();
// Format the prompt
let user_request_text = user_request.unwrap_or_else(|| query.to_string());
let prompt = LLM_FILTER_PROMPT
.replace("{user_request}", &user_request_text)
.replace("{query}", query)
.replace(
"{datasets_json}",
&serde_json::to_string_pretty(&datasets_json)?,
);
// Initialize LiteLLM client
let llm_client = LiteLLMClient::new(None, None);
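// The two `None`s are assumed to fall back to environment-provided API key
// and base URL, given the constructor's optional arguments.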
// Create the request
let request = ChatCompletionRequest {
model: "gemini-2.0-flash-001".to_string(), // Using a small model for cost efficiency
messages: vec![AgentMessage::User {
id: None,
content: prompt,
name: None,
}],
stream: Some(false),
response_format: Some(ResponseFormat {
type_: "json_object".to_string(),
json_schema: None,
}),
metadata: Some(Metadata {
generation_name: "filter_data_catalog".to_string(),
user_id: user.id.to_string(),
session_id: Uuid::new_v4().to_string(),
trace_id: Uuid::new_v4().to_string(),
}),
// reasoning_effort: Some(String::from("low")),
max_completion_tokens: Some(8096),
..Default::default()
};
// Get response from LLM
let response = llm_client.chat_completion(request).await?;
// Parse the LLM response, guarding against an empty choices list
// instead of indexing into it (which would panic)
let first_choice = response
.choices
.first()
.ok_or_else(|| anyhow::anyhow!("LLM response contained no choices"))?;
let content = match &first_choice.message {
AgentMessage::Assistant {
content: Some(content),
..
} => content,
_ => {
error!("LLM response missing content");
return Err(anyhow::anyhow!("LLM response missing content"));
}
};
// Parse into typed response
let filter_response: LLMFilterResponse = match serde_json::from_str(content) {
Ok(response) => response,
Err(e) => {
error!("Failed to parse LLM response: {}", e);
return Err(anyhow::anyhow!("Failed to parse LLM response: {}", e));
}
};
// Create a map for quick lookups of dataset IDs
let dataset_map: HashMap<Uuid, &Dataset> = ranked_datasets
.iter()
.map(|ranked| (ranked.dataset.id, &ranked.dataset))
.collect();
// Convert filtered relevant datasets to DatasetResult
let filtered_datasets: Vec<DatasetResult> = filter_response
.results
.into_iter()
.filter_map(|result| {
// Parse the UUID
match Uuid::parse_str(&result.id) {
Ok(id) => {
// Get the dataset
dataset_map.get(&id).map(|dataset| DatasetResult {
id: dataset.id,
name: Some(dataset.name.clone()),
yml_content: dataset.yml_content.clone(),
})
}
Err(_) => None,
}
})
.collect();
debug!(
"LLM filtering complete, keeping {} relevant datasets",
filtered_datasets.len()
);
Ok(filtered_datasets)
}
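
// A minimal test sketch (an illustrative addition; assumes only the serde
// derives above): checks that `user_request` is optional on deserialization.
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn user_request_is_optional() {
let with_context =
r#"{"queries":["red bull sales"],"user_request":"sales trends"}"#;
let parsed: SearchDataCatalogRequest =
serde_json::from_str(with_context).expect("valid request JSON");
assert_eq!(parsed.queries.len(), 1);

let without_context = r#"{"queries":["red bull sales"]}"#;
let parsed: SearchDataCatalogRequest =
serde_json::from_str(without_context).expect("user_request is optional");
assert!(parsed.user_request.is_none());
}
}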


@@ -5,6 +5,7 @@ mod dashboards;
mod data_sources;
mod dataset_groups;
mod datasets;
mod helpers;
mod logs;
mod messages;
mod metrics;
@@ -37,6 +38,7 @@ pub fn router() -> Router {
.nest("/collections", collections::router())
.nest("/logs", logs::router())
.nest("/search", search::router())
.nest("/helpers", helpers::router())
.route_layer(axum_middleware::from_fn(auth)),
)
}