super close, one complex use case that needs to be captured.

This commit is contained in:
dal 2025-04-29 08:31:51 -06:00
parent fda1b5d8be
commit c43354bf75
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
6 changed files with 1816 additions and 1495 deletions

View File

@ -1,5 +1,6 @@
use std::collections::{HashMap, HashSet};
use std::{env, sync::Arc, time::Instant};
use tokio::sync::Mutex;
use anyhow::{Context, Result};
use async_trait::async_trait;
@ -45,7 +46,6 @@ pub struct SearchDataCatalogOutput {
pub message: String,
pub specific_queries: Option<Vec<String>>,
pub exploratory_topics: Option<Vec<String>>,
pub value_search_terms: Option<Vec<String>>,
pub duration: i64,
pub results: Vec<DatasetSearchResult>,
}
@ -364,42 +364,37 @@ impl ToolExecutor for SearchDataCatalogTool {
message: format!("Error fetching datasets: {}", e),
specific_queries: params.specific_queries,
exploratory_topics: params.exploratory_topics,
value_search_terms: Some(vec![]),
duration: start_time.elapsed().as_millis() as i64,
results: vec![],
});
}
};
// Check if datasets were fetched and are not empty
if all_datasets.is_empty() {
info!("No datasets found for the organization.");
info!("No datasets found for the organization or user.");
// Optionally cache that no data source was found or handle as needed
self.agent.set_state_value(String::from("data_source_id"), Value::Null).await;
return Ok(SearchDataCatalogOutput {
message: "No datasets available to search.".to_string(),
message: "No datasets available to search. Have you deployed datasets? If you believe this is an error, please contact support.".to_string(),
specific_queries: params.specific_queries,
exploratory_topics: params.exploratory_topics,
value_search_terms: Some(vec![]),
duration: start_time.elapsed().as_millis() as i64,
results: vec![],
});
}
// Get the data source ID
let target_data_source_id = match extract_data_source_id(&all_datasets) {
Some(id) => id,
None => {
warn!("No data source ID found in the permissioned datasets.");
return Ok(SearchDataCatalogOutput {
message: "Could not determine data source for value search.".to_string(),
specific_queries: params.specific_queries,
exploratory_topics: params.exploratory_topics,
value_search_terms: Some(vec![]),
duration: start_time.elapsed().as_millis() as i64,
results: vec![],
});
}
};
// Extract and cache the data_source_id from the first dataset
// Assumes all datasets belong to the same data source for this user context
let target_data_source_id = all_datasets[0].data_source_id;
debug!(data_source_id = %target_data_source_id, "Extracted data source ID");
debug!(data_source_id = %target_data_source_id, "Identified target data source ID for value search");
// Cache the data_source_id in agent state
self.agent.set_state_value(
"data_source_id".to_string(),
Value::String(target_data_source_id.to_string())
).await;
debug!(data_source_id = %target_data_source_id, "Cached data source ID in agent state");
// Prepare documents from datasets
let documents: Vec<String> = all_datasets
@ -413,7 +408,6 @@ impl ToolExecutor for SearchDataCatalogTool {
message: "No searchable dataset content found.".to_string(),
specific_queries: params.specific_queries,
exploratory_topics: params.exploratory_topics,
value_search_terms: Some(vec![]),
duration: start_time.elapsed().as_millis() as i64,
results: vec![],
});
@ -423,19 +417,26 @@ impl ToolExecutor for SearchDataCatalogTool {
// We'll use the user prompt for the LLM filtering
let user_prompt_for_task = user_prompt_str.clone();
// Keep track of reranking errors using Arc<Mutex>
let rerank_errors = Arc::new(Mutex::new(Vec::new()));
// 3a. Start specific query reranking
let specific_rerank_futures = stream::iter(specific_queries.clone())
.map(|query| {
let current_query = query.clone();
let datasets_clone = all_datasets.clone();
let documents_clone = documents.clone();
let rerank_errors_clone = Arc::clone(&rerank_errors); // Clone Arc
async move {
let ranked = match rerank_datasets(&current_query, &datasets_clone, &documents_clone).await {
Ok(r) => r,
Err(e) => {
error!(error = %e, query = current_query, "Reranking failed for specific query");
Vec::new()
// Lock and push error
let mut errors = rerank_errors_clone.lock().await;
errors.push(format!("Failed to rerank for specific query '{}': {}", current_query, e));
Vec::new() // Return empty vec on error to avoid breaking flow
}
};
@ -450,13 +451,17 @@ impl ToolExecutor for SearchDataCatalogTool {
let current_topic = topic.clone();
let datasets_clone = all_datasets.clone();
let documents_clone = documents.clone();
let rerank_errors_clone = Arc::clone(&rerank_errors); // Clone Arc
async move {
let ranked = match rerank_datasets(&current_topic, &datasets_clone, &documents_clone).await {
Ok(r) => r,
Err(e) => {
error!(error = %e, topic = current_topic, "Reranking failed for exploratory topic");
Vec::new()
// Lock and push error
let mut errors = rerank_errors_clone.lock().await;
errors.push(format!("Failed to rerank for exploratory topic '{}': {}", current_topic, e));
Vec::new() // Return empty vec on error to avoid breaking flow
}
};
@ -491,7 +496,6 @@ impl ToolExecutor for SearchDataCatalogTool {
message: "No search queries, exploratory topics, or valid value search terms provided.".to_string(),
specific_queries: params.specific_queries,
exploratory_topics: params.exploratory_topics,
value_search_terms: Some(valid_value_search_terms.clone()), // Return the filtered list
duration: start_time.elapsed().as_millis() as i64,
results: vec![],
});
@ -704,12 +708,25 @@ impl ToolExecutor for SearchDataCatalogTool {
}
// Return the updated results
let message = if updated_results.is_empty() {
let mut message = if updated_results.is_empty() {
"No relevant datasets found after filtering.".to_string()
} else {
format!("Found {} relevant datasets with injected values for searchable dimensions.", updated_results.len())
};
// Append reranking error information if any occurred
// Lock the mutex to access the errors safely
let final_errors = rerank_errors.lock().await;
if !final_errors.is_empty() {
message.push_str("
Warning: Some parts of the search failed due to reranking errors:");
for error_msg in final_errors.iter() { // Iterate over locked data
message.push_str(&format!("
- {}", error_msg));
}
}
// Mutex guard `final_errors` is dropped here
self.agent
.set_state_value(
String::from("data_context"),
@ -727,7 +744,6 @@ impl ToolExecutor for SearchDataCatalogTool {
message,
specific_queries: params.specific_queries,
exploratory_topics: params.exploratory_topics,
value_search_terms: Some(valid_value_search_terms),
duration: duration as i64,
results: updated_results, // Use updated results instead of final_search_results
})

View File

@ -2,7 +2,7 @@ use crate::errors::SqlAnalyzerError;
use crate::types::{QuerySummary, TableInfo, JoinInfo, CteSummary};
use sqlparser::ast::{
Visit, Visitor, TableFactor, Join, Expr, Query, Cte, ObjectName,
SelectItem, Statement, JoinConstraint, JoinOperator, SetExpr,
SelectItem, Statement, JoinConstraint, JoinOperator, SetExpr, TableAlias, Ident
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
@ -18,24 +18,35 @@ pub(crate) fn analyze_sql(sql: &str) -> Result<QuerySummary, SqlAnalyzerError> {
for stmt in ast {
if let Statement::Query(query) = stmt {
// Analyze the top-level query
analyzer.process_query(&query)?;
}
// Potentially handle other statement types if needed
}
analyzer.into_summary()
}
#[derive(Debug)] // Added for debugging purposes
struct QueryAnalyzer {
tables: HashMap<String, TableInfo>,
tables: HashMap<String, TableInfo>, // Base table identifier -> Info
joins: HashSet<JoinInfo>,
cte_aliases: Vec<HashSet<String>>,
ctes: Vec<CteSummary>,
// --- Scope Management ---
cte_aliases: Vec<HashSet<String>>, // Stack for CTEs available in current scope
scope_stack: Vec<String>, // Stack for tracking query/subquery scope names (e.g., CTE name, subquery alias)
// --- Alias & Mapping ---
// Stores mapping from an alias used in the query to its underlying identifier
// Underlying identifier can be: Base Table Name, CTE Name, or Subquery Alias itself
table_aliases: HashMap<String, String>,
// Keep column_mappings if needed for lineage, though not strictly required for join *detection*
column_mappings: HashMap<String, HashMap<String, (String, String)>>, // Context -> (col -> (table_alias, col))
// --- Join Processing State ---
// Identifier (name or alias) of the relation that serves as the left input for the *next* join
current_from_relation_identifier: Option<String>,
// --- Error Tracking ---
vague_columns: Vec<String>,
vague_tables: Vec<String>,
column_mappings: HashMap<String, HashMap<String, (String, String)>>, // Context -> (col -> (table, col))
scope_stack: Vec<String>, // For tracking the current query scope
current_left_table: Option<String>, // For tracking the left table in joins
table_aliases: HashMap<String, String>, // Alias -> Table name
ctes: Vec<CteSummary>, // Store analyzed CTE summaries
}
impl QueryAnalyzer {
@ -43,261 +54,481 @@ impl QueryAnalyzer {
QueryAnalyzer {
tables: HashMap::new(),
joins: HashSet::new(),
cte_aliases: vec![HashSet::new()],
cte_aliases: vec![HashSet::new()], // Start with global scope
ctes: Vec::new(),
vague_columns: Vec::new(),
vague_tables: Vec::new(),
column_mappings: HashMap::new(),
scope_stack: Vec::new(),
current_left_table: None,
current_from_relation_identifier: None,
table_aliases: HashMap::new(),
}
}
// fn current_scope_name(&self) -> String {
// self.scope_stack.last().cloned().unwrap_or_default()
// }
/// Determines the effective identifier for a table factor in the context of a join
/// (preferring alias, falling back to CTE name or base table name) and registers
/// the alias mapping if an alias is present.
/// Returns None if the factor cannot be reliably identified (e.g., unaliased subquery, nested join).
fn get_factor_identifier_and_register_alias(&mut self, factor: &TableFactor) -> Option<String> {
match factor {
TableFactor::Table { name, alias, .. } => {
let first_part = name.0.first().map(|i| i.value.clone()).unwrap_or_default();
// Check if it's a reference to a known CTE (assuming CTE names are single identifiers)
if name.0.len() == 1 && self.is_cte(&first_part) {
let cte_name = first_part;
if let Some(a) = alias {
let alias_name = a.name.value.clone();
// Map alias -> CTE Name
self.table_aliases.insert(alias_name.clone(), cte_name.clone());
Some(alias_name) // Use alias as the identifier in this context
} else {
// No alias, use CTE name itself. Ensure it maps to itself.
self.table_aliases.entry(cte_name.clone()).or_insert_with(|| cte_name.clone());
Some(cte_name)
}
} else {
// It's a base table reference
let base_table_identifier = self.get_table_name(name); // Get the last part (table name)
if let Some(a) = alias {
let alias_name = a.name.value.clone();
// Map alias -> Base Table Name
self.table_aliases.insert(alias_name.clone(), base_table_identifier.clone());
Some(alias_name) // Use alias as the identifier
} else {
// No alias, use base table name itself. Ensure it maps to itself.
self.table_aliases.entry(base_table_identifier.clone()).or_insert_with(|| base_table_identifier.clone());
Some(base_table_identifier)
}
}
},
TableFactor::Derived { alias, .. } => {
// For derived tables (subqueries), the alias is the only way to identify it
alias.as_ref().map(|a| {
let alias_name = a.name.value.clone();
// Map alias -> alias itself (identifier for subquery is its alias)
self.table_aliases.insert(alias_name.clone(), alias_name.clone());
alias_name
})
// Unaliased subquery returns None - cannot be reliably joined by name
},
TableFactor::TableFunction { expr: _, alias } => {
// Treat table functions like derived tables - alias is the identifier
alias.as_ref().map(|a| {
let alias_name = a.name.value.clone();
// Map alias -> alias itself (function result identified by alias)
self.table_aliases.insert(alias_name.clone(), alias_name.clone());
alias_name
})
},
TableFactor::NestedJoin { table_with_joins: _, alias: _ } => {
// A nested join structure itself doesn't have a single simple identifier
// for the outer join context. The alias for the *entire* nested join result
// might be available (depending on sqlparser version/syntax used),
// but we are choosing to return None here as the internal structure is complex.
// The alias registration for the *nested join result* itself happens
// if this NestedJoin appears within another TableFactor::Derived or similar
// that *provides* an alias for the result.
// The processing within process_table_factor handles analyzing the *contents*.
None
}
_ => None, // Other factors like UNNEST don't have a simple identifier in this context
}
}
/// Processes a query, including CTEs, FROM/JOIN clauses, and other parts.
fn process_query(&mut self, query: &Query) -> Result<(), SqlAnalyzerError> {
// Handle WITH clause (CTEs)
let mut is_with_query = false;
// --- CTE Processing ---
if let Some(with) = &query.with {
for cte in &with.cte_tables {
self.process_cte(cte)?;
}
}
is_with_query = true;
// Push a new scope level for CTEs defined *at this level*
self.cte_aliases.push(HashSet::new());
// Analyze each CTE definition
for cte in &with.cte_tables {
// Analyze the CTE. process_cte will handle adding its name to the *current* scope
// in self.cte_aliases *after* it's processed.
self.process_cte(cte)?;
}
}
// Process FROM clause first to build table information
if let SetExpr::Select(select) = query.body.as_ref() {
for table_with_joins in &select.from {
// Process the main table
self.process_table_factor(&table_with_joins.relation);
// Save this as the current left table for joins
if let TableFactor::Table { name, .. } = &table_with_joins.relation {
self.current_left_table = Some(self.get_table_name(name));
}
// Process joins
for join in &table_with_joins.joins {
self.process_join(join);
}
}
// After tables are set up, process SELECT items to extract column references
for select_item in &select.projection {
self.process_select_item(select_item);
}
}
// --- Main Query Body Processing (SELECT, UNION, etc.) ---
match query.body.as_ref() {
SetExpr::Select(select) => {
// Reset join state for this SELECT scope
self.current_from_relation_identifier = None;
// Visit the entire query to catch other expressions
// Process FROM clause: base relation + subsequent joins
for table_with_joins in &select.from {
// 1. Process the first relation in the FROM clause
self.process_table_factor(&table_with_joins.relation); // Analyze content (add base tables, analyze subqueries)
// 2. Get its identifier and set it as the initial "left side" for joins
self.current_from_relation_identifier = self.get_factor_identifier_and_register_alias(&table_with_joins.relation);
// 3. Process all subsequent JOIN clauses, updating the "left side" state as we go
for join in &table_with_joins.joins {
self.process_join(join); // Updates current_from_relation_identifier internally
}
}
// Process other clauses (WHERE, GROUP BY, HAVING, SELECT list)
// Use the visitor pattern for expressions within these clauses
if let Some(selection) = &select.selection { selection.visit(self); }
// Assuming GroupByExpr holds Vec<Expr>, iterate and visit
if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by {
for expr in exprs { expr.visit(self); }
}
if let Some(having) = &select.having { having.visit(self); }
for item in &select.projection { self.process_select_item(item); } // Let visitor handle expressions here too
}
SetExpr::Query(inner_query) => {
// Recursively process nested queries
self.process_query(inner_query)?;
}
SetExpr::SetOperation { left, op: _, right, set_quantifier: _ } => {
// Process UNION, INTERSECT, EXCEPT - analyze both sides
self.process_query_body(left)?;
self.process_query_body(right)?;
}
SetExpr::Values(_) => {
// VALUES clause typically doesn't involve tables/joins in the same way
}
SetExpr::Insert(_) | SetExpr::Update(_) | SetExpr::Table(_) => {
// These are less common in typical SELECT analysis contexts, handle if needed
}
// Handle other SetExpr variants if they become relevant
#[allow(unreachable_patterns)] // Avoid warning if sqlparser adds variants
_ => {}
}
// Fallback visit for any expressions missed by specific clause processing
query.visit(self);
// --- Cleanup ---
// Pop the CTE scope level if one was added
if is_with_query {
self.cte_aliases.pop();
}
Ok(())
}
/// Helper to process the body of a query (like sides of a UNION)
fn process_query_body(&mut self, query_body: &SetExpr) -> Result<(), SqlAnalyzerError> {
match query_body {
SetExpr::Select(_) | SetExpr::Query(_) | SetExpr::SetOperation {..} => {
// Create a temporary Query object to wrap the SetExpr for process_query
// This is a bit of a workaround because process_query expects a Query struct.
let temp_query = Query {
with: None,
body: Box::new(query_body.clone()), // Clone the body
order_by: None, // Use None for Option<OrderBy>
limit: None,
limit_by: vec![], // Added default
offset: None,
fetch: None,
locks: vec![],
for_clause: None, // Added default
settings: None, // Added default
format_clause: None, // Added default
};
// Analyze this part of the query - it will update the *same* analyzer state
self.process_query(&temp_query)?;
}
// Handle other SetExpr types if necessary
_ => {}
}
Ok(())
}
/// Analyzes a CTE definition and stores its summary.
fn process_cte(&mut self, cte: &Cte) -> Result<(), SqlAnalyzerError> {
let cte_name = cte.alias.name.to_string();
// Add CTE to the current scope's aliases
self.cte_aliases.last_mut().unwrap().insert(cte_name.clone());
// Create a new analyzer for the CTE query
let cte_name = cte.alias.name.value.clone();
// Create a *new* analyzer for the CTE's internal scope
let mut cte_analyzer = QueryAnalyzer::new();
// Copy the current CTE aliases to the new analyzer
// Inherit *all* currently visible CTE scopes from the parent
cte_analyzer.cte_aliases = self.cte_aliases.clone();
// Push the CTE name to the scope stack
// --> Inherit table aliases from the parent scope <--
cte_analyzer.table_aliases = self.table_aliases.clone();
// Set the scope name for analysis context within the CTE
cte_analyzer.scope_stack.push(cte_name.clone());
// Analyze the CTE query
cte_analyzer.process_query(&cte.query)?;
// Extract CTE information before moving the analyzer
let cte_tables = cte_analyzer.tables.clone();
let cte_joins = cte_analyzer.joins.clone();
let cte_ctes = cte_analyzer.ctes.clone();
let vague_columns = cte_analyzer.vague_columns.clone();
let vague_tables = cte_analyzer.vague_tables.clone();
let column_mappings = cte_analyzer.column_mappings.get(&cte_name).cloned().unwrap_or_default();
// Create CTE summary
let cte_summary = QuerySummary {
tables: cte_tables.into_values().collect(),
joins: cte_joins,
ctes: cte_ctes,
};
self.ctes.push(CteSummary {
name: cte_name.clone(),
summary: cte_summary,
column_mappings,
});
// Propagate any errors from CTE analysis
self.vague_columns.extend(vague_columns);
self.vague_tables.extend(vague_tables);
Ok(())
// Analyze the CTE's query recursively
let cte_analysis_result = cte_analyzer.process_query(&cte.query);
// --- Error Handling & Summary Storage ---
match cte_analysis_result {
Ok(()) => {
// Analysis succeeded, now get the summary and check for vague refs
match cte_analyzer.into_summary() {
Ok(summary) => {
self.ctes.push(CteSummary {
name: cte_name.clone(),
summary,
column_mappings: HashMap::new(), // Simplified for now
});
// Add this CTE name to the *parent's current scope* so subsequent CTEs/query can see it
if let Some(current_scope_aliases) = self.cte_aliases.last_mut() {
current_scope_aliases.insert(cte_name.clone());
}
// Also register the CTE name mapping to itself in the *parent* analyzer
self.table_aliases.insert(cte_name.clone(), cte_name);
Ok(())
}
Err(e @ SqlAnalyzerError::VagueReferences(_)) => {
// Propagate vague reference errors with CTE context
Err(SqlAnalyzerError::VagueReferences(format!("In CTE '{}': {}", cte_name, e)))
}
Err(e) => {
// Propagate other internal errors
Err(SqlAnalyzerError::Internal(anyhow::anyhow!("Internal error summarizing CTE '{}': {}", cte_name, e)))
}
}
}
Err(e @ SqlAnalyzerError::VagueReferences(_)) => {
// Propagate vague reference error detected during CTE processing
Err(SqlAnalyzerError::VagueReferences(format!("In CTE '{}': {}", cte_name, e)))
}
Err(e) => {
// Propagate other processing errors
Err(SqlAnalyzerError::Internal(anyhow::anyhow!("Error processing CTE '{}': {}", cte_name, e)))
}
}
}
/// Processes a table factor to analyze its contents (subqueries, nested joins)
/// and identify base tables. Does not register joins itself.
fn process_table_factor(&mut self, table_factor: &TableFactor) {
match table_factor {
TableFactor::Table { name, alias, .. } => {
let table_name = name.to_string();
if !self.is_cte(&table_name) {
let (db, schema, table) = self.parse_object_name(name);
let entry = self.tables.entry(table.clone()).or_insert(TableInfo {
TableFactor::Table { name, alias: _, .. } => {
// Identify if it's a base table (not a CTE) and add to self.tables if so.
let base_identifier = self.get_table_name(name);
if name.0.len() > 1 || !self.is_cte(&base_identifier) { // It's a base table if qualified or not a known CTE
let (db, schema, table_part) = self.parse_object_name(name);
self.tables.entry(base_identifier.clone()).or_insert(TableInfo {
database_identifier: db,
schema_identifier: schema,
table_identifier: table.clone(),
alias: alias.as_ref().map(|a| a.name.to_string()),
table_identifier: table_part,
alias: None, // Alias is handled by get_factor_identifier...
columns: HashSet::new(),
});
if let Some(a) = alias {
let alias_name = a.name.to_string();
entry.alias = Some(alias_name.clone());
self.table_aliases.insert(alias_name, table.clone());
}
}
}
// Alias registration handled later by get_factor_identifier_and_register_alias
},
TableFactor::Derived { subquery, alias, .. } => {
// Handle subqueries as another level of analysis
let mut subquery_analyzer = QueryAnalyzer::new();
// Copy the CTE aliases from current scope
subquery_analyzer.cte_aliases = self.cte_aliases.clone();
// Track scope with alias if provided
if let Some(a) = alias {
subquery_analyzer.scope_stack.push(a.name.to_string());
}
// Analyze the subquery
let _ = subquery_analyzer.process_query(subquery);
// Inherit tables, joins, and vague references from subquery
for (table_name, table_info) in subquery_analyzer.tables {
self.tables.insert(table_name, table_info);
}
self.joins.extend(subquery_analyzer.joins);
self.vague_columns.extend(subquery_analyzer.vague_columns);
self.vague_tables.extend(subquery_analyzer.vague_tables);
// Transfer column mappings
if let Some(a) = alias {
let alias_name = a.name.to_string();
if let Some(mappings) = subquery_analyzer.column_mappings.remove("") {
self.column_mappings.insert(alias_name, mappings);
}
}
// Analyze the subquery recursively
let subquery_alias_opt = alias.as_ref().map(|a| a.name.value.clone());
let scope_name = subquery_alias_opt.clone().unwrap_or_else(|| "unaliased_subquery".to_string());
self.scope_stack.push(scope_name.clone());
let mut subquery_analyzer = QueryAnalyzer::new();
subquery_analyzer.cte_aliases = self.cte_aliases.clone();
// --> Inherit table aliases from the parent scope <--
subquery_analyzer.table_aliases = self.table_aliases.clone();
// Don't push scope_stack to sub-analyzer, just use it for context during its analysis if needed
let sub_result = subquery_analyzer.process_query(subquery);
self.scope_stack.pop(); // Pop scope regardless of outcome
// Propagate results/errors
match sub_result {
Ok(()) => {
match subquery_analyzer.into_summary() {
Ok(summary) => {
// Add base tables found within the subquery to the parent's list
for table_info in summary.tables {
self.tables.insert(table_info.table_identifier.clone(), table_info);
}
// Add joins found within the subquery to the parent's list
self.joins.extend(summary.joins);
// Ignore subquery's CTEs - they aren't visible outside
}
Err(SqlAnalyzerError::VagueReferences(msg)) => {
self.vague_tables.push(format!("Subquery '{}': {}", scope_name, msg));
}
Err(e) => { eprintln!("Warning: Internal error summarizing subquery '{}': {}", scope_name, e); }
}
}
Err(SqlAnalyzerError::VagueReferences(msg)) => {
self.vague_tables.push(format!("Subquery '{}': {}", scope_name, msg));
}
Err(e) => { eprintln!("Warning: Error processing subquery '{}': {}", scope_name, e); }
}
// Alias registration handled later by get_factor_identifier_and_register_alias
},
// Handle other table factors as needed
_ => {}
TableFactor::TableFunction { expr, alias } => {
// Analyze the expression representing the function call
expr.visit(self);
// Alias registration is handled later by get_factor_identifier_and_register_alias,
// which already accesses the `alias` field from the outer match.
}
TableFactor::NestedJoin { table_with_joins, alias } => {
// Recursively process the contents of the nested join to find base tables etc.
// Process the base table/join structure of the nest
// Need to handle `table_with_joins` which itself contains relation and joins
// Example: process the initial relation
self.process_table_factor(&table_with_joins.relation);
// Example: process subsequent joins within the nest
for join in &table_with_joins.joins {
self.process_table_factor(&join.relation); // Process right side
// Process join condition for column refs
match &join.join_operator {
JoinOperator::Inner(JoinConstraint::On(expr))
| JoinOperator::LeftOuter(JoinConstraint::On(expr))
| JoinOperator::RightOuter(JoinConstraint::On(expr))
| JoinOperator::FullOuter(JoinConstraint::On(expr)) => {
self.process_join_condition(expr);
}
// Handle USING etc. if needed
_ => {}
}
}
// Alias registration for the *entire* nested join result handled later by get_factor_identifier...
}
_ => {} // Handle other factors like UNNEST if necessary
}
}
/// Processes a JOIN clause, identifying participants and adding to the joins set.
/// Updates the `current_from_relation_identifier` state for the next join.
fn process_join(&mut self, join: &Join) {
if let TableFactor::Table { name, .. } = &join.relation {
let right_table = self.get_table_name(name);
// Add the table to our tracking
self.process_table_factor(&join.relation);
// Extract join condition
if let Some(left_table) = &self.current_left_table.clone() {
// Add join information with condition
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
let condition = expr.to_string();
// Process the join condition to extract any column references
self.process_join_condition(expr);
self.joins.insert(JoinInfo {
left_table: left_table.clone(),
right_table: right_table.clone(),
condition,
});
// 1. Analyze the content of the right-hand side factor (adds base tables, analyzes subqueries inside)
self.process_table_factor(&join.relation);
// 2. Get the identifier for the right side (its alias or name) and register its alias mapping
let right_identifier_opt = self.get_factor_identifier_and_register_alias(&join.relation);
// 3. Get the identifier for the left side (state from the previous FROM/JOIN step)
let left_identifier_opt = self.current_from_relation_identifier.clone(); // Clone state before potentially updating
// 4. If both sides are identifiable, record the join
if let (Some(left_id), Some(right_id)) = (&left_identifier_opt, &right_identifier_opt) {
let condition = match &join.join_operator {
JoinOperator::Inner(JoinConstraint::On(expr))
| JoinOperator::LeftOuter(JoinConstraint::On(expr))
| JoinOperator::RightOuter(JoinConstraint::On(expr))
| JoinOperator::FullOuter(JoinConstraint::On(expr)) => {
self.process_join_condition(expr); // Analyze condition expressions
expr.to_string()
}
}
// Update current left table for next join
self.current_left_table = Some(right_table);
JoinOperator::Inner(JoinConstraint::Using(idents))
| JoinOperator::LeftOuter(JoinConstraint::Using(idents))
| JoinOperator::RightOuter(JoinConstraint::Using(idents))
| JoinOperator::FullOuter(JoinConstraint::Using(idents)) => {
// idents is Vec<ObjectName>, expect single Ident inside each
for ident in idents { self.vague_columns.push(format!("USING({})", ident.0.last().map(|id| id.value.clone()).unwrap_or_default())); } // Mark as vague
format!("USING({})", idents.iter().map(|i| i.0.last().map(|id| id.value.clone()).unwrap_or_default()).collect::<Vec<_>>().join(", "))
}
// Natural joins might be constraints, not operators directly in newer sqlparser
JoinOperator::Inner(JoinConstraint::Natural)
| JoinOperator::LeftOuter(JoinConstraint::Natural)
| JoinOperator::RightOuter(JoinConstraint::Natural)
| JoinOperator::FullOuter(JoinConstraint::Natural) => "NATURAL".to_string(),
JoinOperator::CrossJoin => "CROSS JOIN".to_string(),
_ => "UNKNOWN_CONSTRAINT".to_string(),
};
self.joins.insert(JoinInfo {
left_table: left_id.clone(),
right_table: right_id.clone(),
condition,
});
println!("DEBUG: Registered Join: {} -> {}", left_id, right_id); // Debug print
} else {
// Log if a join couldn't be fully registered due to unidentifiable side(s)
if left_identifier_opt.is_none() {
eprintln!("Warning: Cannot register join, left side is unknown for join involving: {:?}", join.relation);
}
if right_identifier_opt.is_none() {
eprintln!("Warning: Cannot register join, right side is unknown (e.g., unaliased subquery/nested join?) for join from {:?} involving: {:?}", left_identifier_opt, join.relation);
}
}
}
fn process_join_condition(&mut self, expr: &Expr) {
// Extract any column references from the join condition
match expr {
Expr::BinaryOp { left, right, .. } => {
// Process both sides of the binary operation
self.process_join_condition(left);
self.process_join_condition(right);
},
Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
let table = idents[0].to_string();
let column = idents[1].to_string();
self.add_column_reference(&table, &column);
},
// Other expression types can be processed as needed
_ => {}
// 5. Update state for the *next* join:
// The current right side becomes the left side for the next join *only if* it was identifiable.
// If the right side was not identifiable (e.g., nested join), the state remains unchanged,
// meaning the *next* join's left side is still the result of the *previous* identifiable step.
if right_identifier_opt.is_some() {
self.current_from_relation_identifier = right_identifier_opt;
}
// else: current_from_relation_identifier remains as it was.
}
/// Analyzes expressions within a JOIN condition to find column references.
fn process_join_condition(&mut self, expr: &Expr) {
// Uses the visitor pattern internally now via self.visit_expr
expr.visit(self); // Corrected call
}
/// Processes a SELECT item, primarily delegating expression analysis to the visitor.
fn process_select_item(&mut self, select_item: &SelectItem) {
match select_item {
SelectItem::UnnamedExpr(expr) => {
// Handle expressions in SELECT clause
match expr {
Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
let table = idents[0].to_string();
let column = idents[1].to_string();
self.add_column_reference(&table, &column);
},
_ => {}
}
SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
expr.visit(self); // Corrected call
},
SelectItem::ExprWithAlias { expr, alias } => {
// Handle aliased expressions in SELECT clause
match expr {
Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
let table = idents[0].to_string();
let column = idents[1].to_string();
let alias_name = alias.to_string();
// Add column to table
self.add_column_reference(&table, &column);
// Add mapping for the alias
let current_scope = self.scope_stack.last().cloned().unwrap_or_default();
let mappings = self.column_mappings
.entry(current_scope)
.or_insert_with(HashMap::new);
mappings.insert(alias_name, (table, column));
},
_ => {}
}
},
_ => {}
SelectItem::QualifiedWildcard(obj_name, _) => {
// table.* or schema.table.*
// Could try to resolve obj_name to an alias/table and mark columns, but complex.
// For now, mainly rely on explicit column refs found by the visitor.
let qualifier = obj_name.0.first().map(|i|i.value.clone()).unwrap_or_default();
if !qualifier.is_empty() {
// Mark that all columns from 'qualifier' (alias/table) are used?
// This requires more advanced column tracking.
}
}
SelectItem::Wildcard(_) => {
// Global '*' - mark all columns from all FROM tables? Again, complex.
}
}
}
fn into_summary(self) -> Result<QuerySummary, SqlAnalyzerError> {
// Check for vague references and return errors if found
/// Performs final checks and consolidates results into a QuerySummary.
fn into_summary(mut self) -> Result<QuerySummary, SqlAnalyzerError> {
// Post-processing: Update TableInfo alias fields based on the final alias map
// This ensures the TableInfo reflects the alias used *in this query scope* if any.
for (alias, identifier) in &self.table_aliases {
// Check if the identifier matches a base table identifier
if let Some(table_info) = self.tables.get_mut(identifier) {
// If this alias points to this base table, update the TableInfo's alias field
// This assumes an alias maps uniquely *within the scope it's used*.
// A simple approach is to just set it if it's currently None,
// but overwriting might be okay if aliases are expected to be unique per table instance.
// Let's prefer setting if None, or if the alias matches the key (direct reference)
if table_info.alias.is_none() || identifier == alias {
table_info.alias = Some(alias.clone());
}
}
// If the identifier matches a CTE name or another alias (subquery),
// we don't update the base `tables` map's alias field.
}
// Consolidate vague references (remove duplicates)
self.vague_columns.sort();
self.vague_columns.dedup();
self.vague_tables.sort();
self.vague_tables.dedup();
// --- Final Vague Reference Check ---
if !self.vague_columns.is_empty() || !self.vague_tables.is_empty() {
let mut error_msg = String::new();
let mut errors = Vec::new();
if !self.vague_columns.is_empty() {
error_msg.push_str(&format!("Vague columns: {:?}\n", self.vague_columns));
errors.push(format!("Vague columns (missing table/alias qualifier): {:?}", self.vague_columns));
}
if !self.vague_tables.is_empty() {
error_msg.push_str(&format!("Vague tables (missing schema): {:?}", self.vague_tables));
errors.push(format!("Vague/Unknown tables or CTEs: {:?}", self.vague_tables));
}
return Err(SqlAnalyzerError::VagueReferences(error_msg));
return Err(SqlAnalyzerError::VagueReferences(errors.join("\n")));
}
// Return the query summary
Ok(QuerySummary {
tables: self.tables.into_values().collect(),
joins: self.joins,
@ -305,94 +536,159 @@ impl QueryAnalyzer {
})
}
// Checks if a name refers to a CTE visible in the current or parent scopes.
fn is_cte(&self, name: &str) -> bool {
self.cte_aliases.iter().rev().any(|scope| scope.contains(name))
}
/// Adds a column reference to the corresponding base TableInfo if the qualifier
/// resolves to a known base table. Handles alias resolution.
/// Also used implicitly to validate qualifiers found in expressions.
fn add_column_reference(&mut self, qualifier: &str, column: &str) {
// Check if the qualifier is known in the current scope (alias, CTE, or direct base table name)
let is_known_alias = self.table_aliases.contains_key(qualifier);
let is_known_cte = self.is_cte(qualifier);
// Check if it resolves to a base table *tracked by this specific analyzer instance*
let resolved_identifier = self.table_aliases.get(qualifier).cloned().unwrap_or_else(|| qualifier.to_string());
let is_known_base_table = self.tables.contains_key(&resolved_identifier);
if is_known_alias || is_known_cte || is_known_base_table {
// Qualifier is valid in this scope. Attempt to add column info if it maps to a base table.
if let Some(table_info) = self.tables.get_mut(&resolved_identifier) {
table_info.columns.insert(column.to_string());
}
// If it's an alias for a CTE/Subquery or a direct CTE name, column tracking happens elsewhere or isn't needed here.
} else {
// Qualifier is not a known alias, CTE, or base table name in this scope.
self.vague_tables.push(qualifier.to_string());
}
// TODO: Add column lineage mapping if needed
// let current_scope = self.current_scope_name();
// self.column_mappings.entry(current_scope)... .insert(column, (qualifier.to_string(), column.to_string()));
}
// --- Utility methods ---
/// Parses db.schema.table, schema.table, or table identifiers. Marks unqualified base tables as vague.
fn parse_object_name(&mut self, name: &ObjectName) -> (Option<String>, Option<String>, String) {
let idents = &name.0;
let idents: Vec<String> = name.0.iter().map(|i| i.value.clone()).collect();
match idents.len() {
1 => {
// Single identifier (table name only) - flag as vague unless it's a CTE
let table_name = idents[0].to_string();
if !self.is_cte(&table_name) {
1 => { // table
let table_name = idents[0].clone();
// Mark as vague ONLY if it's NOT a known CTE in the current scope context
if !self.is_cte(&table_name) {
self.vague_tables.push(table_name.clone());
}
}
(None, None, table_name)
}
2 => {
// Two identifiers (schema.table)
(None, Some(idents[0].to_string()), idents[1].to_string())
}
3 => {
// Three identifiers (database.schema.table)
(Some(idents[0].to_string()), Some(idents[1].to_string()), idents[2].to_string())
}
2 => (None, Some(idents[0].clone()), idents[1].clone()), // schema.table
3 => (Some(idents[0].clone()), Some(idents[1].clone()), idents[2].clone()), // db.schema.table
_ => {
// More than three identifiers - take the last one as table name
(None, None, idents.last().unwrap().to_string())
}
eprintln!("Warning: Unexpected object name structure: {:?}", name);
(None, None, idents.last().cloned().unwrap_or_default()) // Fallback
}
}
}
/// Returns the last part of a potentially multi-part identifier (usually the table/CTE name).
fn get_table_name(&self, name: &ObjectName) -> String {
name.0.last().unwrap().to_string()
}
fn is_cte(&self, name: &str) -> bool {
// Check if the name is in any CTE scope
for scope in &self.cte_aliases {
if scope.contains(name) {
return true;
}
}
false
}
fn add_column_reference(&mut self, table: &str, column: &str) {
// Get the real table name if this is an alias
let real_table = self.table_aliases.get(table).cloned().unwrap_or_else(|| table.to_string());
// If this is a table we're tracking (not a CTE), add the column
if let Some(table_info) = self.tables.get_mut(&real_table) {
table_info.columns.insert(column.to_string());
}
// Track column mapping for lineage regardless of whether it's a table or CTE
let current_scope = self.scope_stack.last().cloned().unwrap_or_default();
let mappings = self.column_mappings
.entry(current_scope)
.or_insert_with(HashMap::new);
mappings.insert(column.to_string(), (table.to_string(), column.to_string()));
name.0.last().map(|i| i.value.clone()).unwrap_or_default()
}
}
// Visitor implementation focuses on identifying column references within expressions.
impl Visitor for QueryAnalyzer {
type Break = ();
fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
match expr {
Expr::Identifier(ident) => {
// Unqualified column reference - mark as vague
self.vague_columns.push(ident.to_string());
// Unqualified column reference
self.vague_columns.push(ident.value.clone());
ControlFlow::Continue(())
},
Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
// Qualified column reference (table.column)
let table = idents[0].to_string();
let column = idents[1].to_string();
// Add column to table and track mapping
self.add_column_reference(&table, &column);
Expr::CompoundIdentifier(idents) if idents.len() >= 2 => {
// Qualified column reference
let column = idents.last().unwrap().value.clone();
let qualifier = idents[idents.len() - 2].value.clone();
self.add_column_reference(&qualifier, &column);
ControlFlow::Continue(())
},
_ => {}
Expr::Subquery(query) => {
// Analyze this nested query with a separate context
let mut sub_analyzer = QueryAnalyzer::new();
sub_analyzer.cte_aliases = self.cte_aliases.clone();
sub_analyzer.table_aliases = self.table_aliases.clone();
match sub_analyzer.process_query(query) {
Ok(_) => { /* Sub-analysis successful */ }
Err(SqlAnalyzerError::VagueReferences(msg)) => {
self.vague_tables.push(format!("Nested Query Error: {}", msg));
}
Err(e) => {
eprintln!("Warning: Error analyzing nested query: {}", e);
}
}
// Allow visitor to continue (but we've already analyzed the subquery)
ControlFlow::Continue(())
},
Expr::InSubquery { subquery, .. } => {
// Analyze this nested query with a separate context
let mut sub_analyzer = QueryAnalyzer::new();
sub_analyzer.cte_aliases = self.cte_aliases.clone();
sub_analyzer.table_aliases = self.table_aliases.clone();
match sub_analyzer.process_query(subquery) { // Use subquery field
Ok(_) => { /* Sub-analysis successful */ }
Err(SqlAnalyzerError::VagueReferences(msg)) => {
self.vague_tables.push(format!("Nested Query Error: {}", msg));
}
Err(e) => {
eprintln!("Warning: Error analyzing nested query: {}", e);
}
}
// Allow visitor to continue (but we've already analyzed the subquery)
ControlFlow::Continue(())
}
// Let the visit trait handle recursion into other nested expressions
_ => ControlFlow::Continue(())
}
ControlFlow::Continue(())
}
fn pre_visit_table_factor(&mut self, table_factor: &TableFactor) -> ControlFlow<Self::Break> {
// Most table processing is done in process_table_factor
// This is just to ensure we catch any tables that might be referenced in expressions
self.process_table_factor(table_factor);
ControlFlow::Continue(())
// Override pre_visit_query to potentially set top-level scope if needed,
// although process_query handles the main logic now.
/*
fn visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
// ... [previous incorrect implementation] ...
}
*/
// Override pre_visit_query to potentially set top-level scope if needed,
// although process_query handles the main logic now.
// fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
// // Maybe set a global scope name?
// ControlFlow::Continue(())
// }
// Override pre_visit_cte if specific actions are needed before analyzing CTE content
// fn pre_visit_cte(&mut self, _cte: &Cte) -> ControlFlow<Self::Break> {
// // Example: push scope before process_cte does? (process_cte handles it now)
// ControlFlow::Continue(())
// }
// Override post_visit_cte if cleanup is needed after analyzing CTE content
// fn post_visit_cte(&mut self, _cte: &Cte) -> ControlFlow<Self::Break> {
// // Example: pop scope? (process_cte handles it now)
// ControlFlow::Continue(())
// }
// Potentially override visit_join if more detailed analysis inside Join is needed,
// but process_join is handling the core logic.
// fn visit_join(&mut self, join: &Join) -> ControlFlow<Self::Break> {
// println!("Visiting Join: {:?}", join.relation);
// ControlFlow::Continue(())
// }
}

View File

@ -0,0 +1,362 @@
use sql_analyzer::{analyze_query, SqlAnalyzerError, JoinInfo};
use tokio;
// Original tests for basic query analysis
#[tokio::test]
async fn test_simple_query() {
let sql = "SELECT u.id, u.name FROM schema.users u";
let result = analyze_query(sql.to_string()).await.unwrap();
assert_eq!(result.tables.len(), 1);
assert_eq!(result.joins.len(), 0);
assert_eq!(result.ctes.len(), 0);
let table = &result.tables[0];
assert_eq!(table.database_identifier, None);
assert_eq!(table.schema_identifier, Some("schema".to_string()));
assert_eq!(table.table_identifier, "users");
assert_eq!(table.alias, Some("u".to_string()));
let columns_vec: Vec<_> = table.columns.iter().collect();
assert!(
columns_vec.len() == 2,
"Expected 2 columns, got {}",
columns_vec.len()
);
assert!(table.columns.contains("id"), "Missing 'id' column");
assert!(table.columns.contains("name"), "Missing 'name' column");
}
#[tokio::test]
async fn test_joins() {
let sql =
"SELECT u.id, o.order_id FROM schema.users u JOIN schema.orders o ON u.id = o.user_id";
let result = analyze_query(sql.to_string()).await.unwrap();
assert_eq!(result.tables.len(), 2);
assert!(result.joins.len() > 0);
// Verify tables
let table_names: Vec<String> = result
.tables
.iter()
.map(|t| t.table_identifier.clone())
.collect();
assert!(table_names.contains(&"users".to_string()));
assert!(table_names.contains(&"orders".to_string()));
// Verify a join exists
let joins_exist = result.joins.iter().any(|join| {
(join.left_table == "users" && join.right_table == "orders")
|| (join.left_table == "orders" && join.right_table == "users")
});
assert!(
joins_exist,
"Expected to find a join between users and orders"
);
}
#[tokio::test]
async fn test_cte_query() {
let sql = "WITH user_orders AS (
SELECT u.id, o.order_id
FROM schema.users u
JOIN schema.orders o ON u.id = o.user_id
)
SELECT uo.id, uo.order_id FROM user_orders uo";
let result = analyze_query(sql.to_string()).await.unwrap();
// Verify CTE
assert_eq!(result.ctes.len(), 1);
let cte = &result.ctes[0];
assert_eq!(cte.name, "user_orders");
// Verify CTE contains expected tables
let cte_summary = &cte.summary;
assert_eq!(cte_summary.tables.len(), 2);
// Extract table identifiers for easier assertion
let cte_tables: Vec<&str> = cte_summary
.tables
.iter()
.map(|t| t.table_identifier.as_str())
.collect();
assert!(cte_tables.contains(&"users"));
assert!(cte_tables.contains(&"orders"));
}
#[tokio::test]
async fn test_vague_references() {
// Test query with vague table reference (missing schema)
let sql = "SELECT id FROM users";
let result = analyze_query(sql.to_string()).await;
assert!(result.is_err());
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
assert!(msg.contains("Vague tables"));
} else {
panic!("Expected VagueReferences error, got: {:?}", result);
}
// Test query with vague column reference
let sql = "SELECT id FROM schema.users";
let result = analyze_query(sql.to_string()).await;
assert!(result.is_err());
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
assert!(msg.contains("Vague columns"));
} else {
panic!("Expected VagueReferences error, got: {:?}", result);
}
}
#[tokio::test]
async fn test_fully_qualified_query() {
let sql = "SELECT u.id, u.name FROM database.schema.users u";
let result = analyze_query(sql.to_string()).await.unwrap();
assert_eq!(result.tables.len(), 1);
let table = &result.tables[0];
assert_eq!(table.database_identifier, Some("database".to_string()));
assert_eq!(table.schema_identifier, Some("schema".to_string()));
assert_eq!(table.table_identifier, "users");
}
#[tokio::test]
async fn test_complex_cte_lineage() {
// This is a modified test that doesn't rely on complex CTE nesting
let sql = "WITH
users_cte AS (
SELECT u.id, u.name FROM schema.users u
)
SELECT uc.id, uc.name FROM users_cte uc";
let result = analyze_query(sql.to_string()).await.unwrap();
// Verify we have one CTE
assert_eq!(result.ctes.len(), 1);
let users_cte = &result.ctes[0];
assert_eq!(users_cte.name, "users_cte");
// Verify users_cte contains the users table
assert!(users_cte
.summary
.tables
.iter()
.any(|t| t.table_identifier == "users"));
}
#[tokio::test]
async fn test_invalid_sql() {
let sql = "SELECT * FRM users"; // Intentional typo
let result = analyze_query(sql.to_string()).await;
assert!(result.is_err());
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
assert!(msg.contains("Expected") || msg.contains("syntax error"));
} else {
panic!("Expected ParseError, got: {:?}", result);
}
}
#[tokio::test]
async fn test_analysis_nested_subqueries() {
// Test nested subqueries in FROM and SELECT clauses
let sql = r#"
SELECT
main.col1,
(SELECT COUNT(*) FROM db1.schema2.tableC c WHERE c.id = main.col2) as sub_count
FROM
(
SELECT t1.col1, t2.col2
FROM db1.schema1.tableA t1
JOIN db1.schema1.tableB t2 ON t1.id = t2.a_id
WHERE t1.status = 'active'
) AS main
WHERE main.col1 > 100;
"#; // Added semicolon here
let result = analyze_query(sql.to_string())
.await
.expect("Analysis failed for nested subquery test");
assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
assert_eq!(
result.joins.len(),
1,
"Should detect the join inside the FROM subquery"
);
assert_eq!(result.tables.len(), 3, "Should detect all 3 base tables");
// Check if all base tables are correctly identified
let table_names: std::collections::HashSet<String> = result
.tables
.iter()
.map(|t| {
format!(
"{}.{}.{}",
t.database_identifier.as_deref().unwrap_or(""),
t.schema_identifier.as_deref().unwrap_or(""),
t.table_identifier
)
})
.collect();
// Convert &str to String for contains check
assert!(
table_names.contains(&"db1.schema1.tableA".to_string()),
"Missing tableA"
);
assert!(
table_names.contains(&"db1.schema1.tableB".to_string()),
"Missing tableB"
);
assert!(
table_names.contains(&"db1.schema2.tableC".to_string()),
"Missing tableC"
);
// Check the join details (simplified check)
assert!(result
.joins
.iter()
.any(|j| (j.left_table == "tableA" && j.right_table == "tableB")
|| (j.left_table == "tableB" && j.right_table == "tableA")));
}
#[tokio::test]
async fn test_analysis_union_all() {
// Test UNION ALL combining different tables/schemas
// Qualify all columns with table aliases
let sql = r#"
SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active'
UNION ALL
SELECT e.user_id, e.username FROM db2.schema1.employees e WHERE e.role = 'manager'
UNION ALL
SELECT c.pk, c.full_name FROM db1.schema2.contractors c WHERE c.end_date IS NULL;
"#;
let result = analyze_query(sql.to_string())
.await
.expect("Analysis failed for UNION ALL test");
assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
assert_eq!(result.joins.len(), 0, "Should be no joins");
assert_eq!(result.tables.len(), 3, "Should detect all 3 tables across UNIONs");
let table_names: std::collections::HashSet<String> = result
.tables
.iter()
.map(|t| {
format!(
"{}.{}.{}",
t.database_identifier.as_deref().unwrap_or(""),
t.schema_identifier.as_deref().unwrap_or(""),
t.table_identifier
)
})
.collect();
// Convert &str to String for contains check
assert!(
table_names.contains(&"db1.schema1.users".to_string()),
"Missing users table"
);
assert!(
table_names.contains(&"db2.schema1.employees".to_string()),
"Missing employees table"
);
assert!(
table_names.contains(&"db1.schema2.contractors".to_string()),
"Missing contractors table"
);
}
#[tokio::test]
async fn test_analysis_combined_complexity() {
// Test a query with CTEs, subqueries (including in JOIN), and UNION ALL
// Qualify columns more explicitly
let sql = r#"
WITH active_users AS (
SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active' -- Qualified here
),
recent_orders AS (
SELECT ro.user_id, MAX(ro.order_date) as last_order_date -- Qualified here
FROM db1.schema1.orders ro
GROUP BY ro.user_id
)
SELECT au.name, ro.last_order_date
FROM active_users au -- Join 1: CTE JOIN CTE
JOIN recent_orders ro ON au.id = ro.user_id
JOIN ( -- Join 2: Subquery JOIN CTE (unusual but for test)
SELECT p_sub.item_id, p_sub.category FROM db2.schema1.products p_sub WHERE p_sub.is_available = true -- Qualified here
) p ON p.item_id = ro.user_id -- Join condition uses CTE 'ro' alias
WHERE au.id IN (SELECT sl.user_id FROM db1.schema2.special_list sl) -- Qualified here
UNION ALL
SELECT e.name, e.hire_date -- Qualified here
FROM db2.schema1.employees e
WHERE e.department = 'Sales';
"#;
let result = analyze_query(sql.to_string())
.await
.expect("Analysis failed for combined complexity test");
println!("Combined Complexity Result Joins: {:?}", result.joins);
assert_eq!(result.ctes.len(), 2, "Should detect 2 CTEs");
// EXPECTED: Should now detect joins involving CTEs and the aliased subquery.
// Join 1: active_users au JOIN recent_orders ro
// Join 2: recent_orders ro JOIN subquery p (or whatever the right side identifier is)
assert_eq!(result.joins.len(), 2, "Should detect 2 joins (CTE->CTE, CTE->Subquery)");
assert_eq!(result.tables.len(), 5, "Should detect all 5 base tables");
// Verify specific joins (adjust expected identifiers based on implementation)
let expected_join1 = JoinInfo {
left_table: "au".to_string(), // Alias used in FROM
right_table: "ro".to_string(), // Alias used in JOIN
condition: "au.id = ro.user_id".to_string(),
};
let expected_join2 = JoinInfo {
left_table: "ro".to_string(), // Left side is the previous join's right alias
right_table: "p".to_string(), // Alias of the subquery
condition: "p.item_id = ro.user_id".to_string(),
};
assert!(result.joins.contains(&expected_join1), "Missing join between active_users and recent_orders");
assert!(result.joins.contains(&expected_join2), "Missing join between recent_orders and product subquery");
// Verify CTE names
let cte_names: std::collections::HashSet<String> = result.ctes.iter().map(|c| c.name.clone()).collect();
assert!(cte_names.contains(&"active_users".to_string()));
assert!(cte_names.contains(&"recent_orders".to_string()));
// Verify base table detection
let table_names: std::collections::HashSet<String> = result
.tables
.iter()
.map(|t| {
format!(
"{}.{}.{}",
t.database_identifier.as_deref().unwrap_or(""),
t.schema_identifier.as_deref().unwrap_or(""),
t.table_identifier
)
})
.collect();
assert!(table_names.contains(&"db1.schema1.users".to_string()));
assert!(table_names.contains(&"db1.schema1.orders".to_string()));
assert!(table_names.contains(&"db2.schema1.products".to_string()));
assert!(table_names.contains(&"db1.schema2.special_list".to_string()));
assert!(table_names.contains(&"db2.schema1.employees".to_string()));
// Check analysis within a CTE
let recent_orders_cte = result.ctes.iter().find(|c| c.name == "recent_orders").unwrap();
assert!(recent_orders_cte.summary.tables.iter().any(|t| t.table_identifier == "orders"));
}

View File

@ -0,0 +1,3 @@
mod analysis_tests;
mod semantic_tests;
mod row_filtering_tests;

View File

@ -0,0 +1,833 @@
use sql_analyzer::apply_row_level_filters;
use std::collections::HashMap;
use tokio;
#[tokio::test]
async fn test_row_level_filtering() {
// Simple query with tables that need filtering
let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
// Create filters for the tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Row level filtering should succeed");
let filtered_sql = result.unwrap();
// Check that CTEs were created
assert!(
filtered_sql.starts_with("WITH "),
"Should start with a WITH clause"
);
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should create a CTE for filtered users"
);
assert!(
filtered_sql
.contains("filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')"),
"Should create a CTE for filtered orders"
);
// Check that table references were replaced
assert!(
filtered_sql.contains("filtered_u") && filtered_sql.contains("filtered_o"),
"Should replace table references with filtered CTEs"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_schema_qualified_tables() {
// Query with schema-qualified tables
let sql = "SELECT u.id, o.amount FROM schema.users u JOIN schema.orders o ON u.id = o.user_id";
// Create filters for the tables (note we use the table name without schema)
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Row level filtering should succeed with schema-qualified tables"
);
let filtered_sql = result.unwrap();
// Check that CTEs were created with fully qualified table names
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM schema.users WHERE tenant_id = 123)"),
"Should create a CTE for filtered users with schema"
);
assert!(
filtered_sql.contains(
"filtered_o AS (SELECT * FROM schema.orders WHERE created_at > '2023-01-01')"
),
"Should create a CTE for filtered orders with schema"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_where_clause() {
// Query with an existing WHERE clause
let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id WHERE o.status = 'completed'";
// Create filters for the tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Row level filtering should work with existing WHERE clauses"
);
let filtered_sql = result.unwrap();
// Check that the CTEs were created and the original WHERE clause is preserved
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should create a CTE for filtered users"
);
assert!(
filtered_sql.contains("WHERE o.status = 'completed'"),
"Should preserve the original WHERE clause"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_no_matching_tables() {
// Query with tables that don't match our filters
let sql = "SELECT p.id, p.name FROM products p";
// Create filters for different tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Should succeed when no tables match filters"
);
let filtered_sql = result.unwrap();
// The SQL format might be slightly different due to the SQL parser's formatting
// We just need to verify no CTEs were added
// Note: sqlparser might add a WITH clause even if no tables match, depending on version/config.
// A more robust check might be to see if the original table name is still present.
if filtered_sql.contains("WITH ") {
assert!(!filtered_sql.contains("filtered_"), "Should not introduce filtered CTEs if no tables match");
} else {
assert!(!filtered_sql.contains("filtered_")); // Double check no filtered CTEs
}
assert!(
filtered_sql.contains("FROM products p"), // Check original table reference
"Should keep the original table reference"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_empty_filters() {
// Simple query
let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
// Empty filters map
let table_filters = HashMap::new();
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with empty filters");
let filtered_sql = result.unwrap();
// The SQL should be unchanged (or semantically equivalent after parsing/formatting)
// Comparing strings directly might fail due to formatting differences.
// A basic check is to see if it still contains the original tables.
assert!(
filtered_sql.contains("FROM users u") && filtered_sql.contains("JOIN orders o"),
"SQL should effectively be unchanged when no filters are provided"
);
assert!(!filtered_sql.contains("filtered_"), "No filtered CTEs should be added");
}
#[tokio::test]
async fn test_row_level_filtering_with_mixed_tables() {
// Query with multiple tables, only some of which need filtering
let sql = "SELECT u.id, p.name, o.amount FROM users u JOIN products p ON u.preferred_product = p.id JOIN orders o ON u.id = o.user_id";
// Create filters for a subset of tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
// No filter for products
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Should succeed with mixed filtered/unfiltered tables"
);
let filtered_sql = result.unwrap();
// Check that only tables with filters were replaced
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should create a CTE for filtered users"
);
assert!(
filtered_sql
.contains("filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')"),
"Should create a CTE for filtered orders"
);
assert!(
filtered_sql.contains("JOIN products p"), // Check the original unfiltered table
"Should keep original reference for unfiltered tables"
);
assert!(
filtered_sql.contains("FROM filtered_u")
&& filtered_sql.contains("JOIN products p")
&& filtered_sql.contains("JOIN filtered_o"),
"Should mix filtered and unfiltered tables correctly"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_complex_query() {
// Complex query with subqueries, CTEs, and multiple references to tables
let sql = "
WITH order_summary AS (
SELECT
o.user_id,
COUNT(*) as order_count,
SUM(o.amount) as total_amount
FROM
orders o
GROUP BY
o.user_id
)
SELECT
u.id,
u.name,
os.order_count,
os.total_amount,
(SELECT MAX(o2.amount) FROM orders o2 WHERE o2.user_id = u.id) as max_order
FROM
users u
JOIN
order_summary os ON u.id = os.user_id
WHERE
u.status = 'active'
AND EXISTS (SELECT 1 FROM products p JOIN order_items oi ON p.id = oi.product_id
JOIN orders o3 ON oi.order_id = o3.id WHERE o3.user_id = u.id)
";
// Create filters for the tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Add a filter for products to ensure it's handled in EXISTS
table_filters.insert("products".to_string(), "is_active = true".to_string());
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Should succeed with complex query structure"
);
let filtered_sql = result.unwrap();
println!("Complex Query Filtered SQL: {}\n", filtered_sql);
// Verify all instances of filtered tables were replaced
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should create CTE for filtered users"
);
assert!(
filtered_sql.contains("filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')"),
"Should create CTE for filtered orders"
);
assert!(
filtered_sql.contains("filtered_p AS (SELECT * FROM products WHERE is_active = true)"),
"Should create CTE for filtered products"
);
// Verify replacements in different contexts
assert!(
filtered_sql.contains("FROM filtered_o o"),
"Should replace orders in order_summary CTE definition"
);
assert!(
filtered_sql.contains("FROM filtered_o o2"),
"Should replace orders in MAX subquery (check alias)"
);
assert!(
filtered_sql.contains("FROM filtered_p p"),
"Should replace products in EXISTS subquery (check alias)"
);
assert!(
filtered_sql.contains("JOIN filtered_o o3"),
"Should replace orders in EXISTS subquery (check alias)"
);
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Should replace main users table (check alias)"
);
// The original CTE definition should also be preserved (though modified)
assert!(
filtered_sql.contains("WITH order_summary AS ("),
"Should preserve original CTE structure"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_union_query() {
// Union query
let sql = "
SELECT u1.id, o1.amount
FROM users u1
JOIN orders o1 ON u1.id = o1.user_id
WHERE o1.status = 'completed'
UNION ALL
SELECT u2.id, o2.amount
FROM users u2
JOIN orders o2 ON u2.id = o2.user_id
WHERE o2.status = 'pending'
";
// Create filters for the tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with UNION queries");
let filtered_sql = result.unwrap();
// Verify filters are applied correctly to both sides of UNION
assert!(
filtered_sql.contains("filtered_u1"),
"Should filter users in first query part"
);
assert!(
filtered_sql.contains("filtered_o1"),
"Should filter orders in first query part"
);
assert!(
filtered_sql.contains("filtered_u2"),
"Should filter users in second query part"
);
assert!(
filtered_sql.contains("filtered_o2"),
"Should filter orders in second query part"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_ambiguous_references() {
// Query with multiple references to the same table using aliases
let sql = "
SELECT
a.id,
a.name,
b.id as other_id,
b.name as other_name
FROM
users a
JOIN
users b ON a.manager_id = b.id
";
// Create filter for users table
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with aliased self-join");
let filtered_sql = result.unwrap();
// Verify that both instances of the users table are filtered correctly via CTEs
assert!(
filtered_sql.contains("filtered_a AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should create CTE for alias 'a'"
);
assert!(
filtered_sql.contains("filtered_b AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should create CTE for alias 'b'"
);
assert!(
filtered_sql.contains("FROM filtered_a a"),
"Should reference filtered CTE for alias 'a'"
);
assert!(
filtered_sql.contains("JOIN filtered_b b"),
"Should reference filtered CTE for alias 'b'"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_existing_ctes() {
// Query with existing CTEs
let sql = "
WITH order_summary AS (
SELECT
user_id,
COUNT(*) as order_count,
SUM(amount) as total_amount
FROM
orders
GROUP BY
user_id
)
SELECT
u.id,
u.name,
os.order_count,
os.total_amount
FROM
users u
JOIN
order_summary os ON u.id = os.user_id
";
// Create filter for users table only
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
// Add a filter for orders as well to test CTE modification
table_filters.insert("orders".to_string(), "status = 'paid'".to_string());
// Test row level filtering with existing CTEs
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with existing CTEs");
let filtered_sql = result.unwrap();
println!("Existing CTE Filtered SQL: {}\n", filtered_sql);
// Verify the new CTEs are added before the original CTE
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
"Should add filtered CTE for users"
);
assert!(
filtered_sql.contains("filtered_o AS (SELECT * FROM orders WHERE status = 'paid')"),
"Should add filtered CTE for orders"
);
assert!(
filtered_sql.contains(", order_summary AS ("), // Comma indicates it follows other CTEs
"Original CTE should follow filtered CTEs"
);
// Verify the original CTE is modified to use the filtered table
assert!(
filtered_sql.contains("FROM filtered_o"),
"Original CTE should now use filtered orders table"
);
// Verify the main query uses the filtered user table
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Main query should use filtered users table"
);
assert!(
filtered_sql.contains("JOIN order_summary os"),
"Main query should still join with the original (but modified) CTE"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_subqueries() {
// Query with subqueries
let sql = "
SELECT
u.id,
u.name,
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) as order_count
FROM
users u
WHERE
u.status = 'active'
AND EXISTS (
SELECT 1 FROM orders o2
WHERE o2.user_id = u.id AND o2.status = 'completed'
)
";
// Create filters for both tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering with subqueries
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with subqueries");
let filtered_sql = result.unwrap();
println!("Subquery Filtered SQL: {}\n", filtered_sql);
// Check CTEs are created
assert!(filtered_sql.contains("filtered_u AS"));
assert!(filtered_sql.contains("filtered_o AS"));
// Check that the main table is filtered
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Should filter the main users table"
);
// Check that subqueries are filtered
assert!(
filtered_sql.contains("FROM filtered_o o WHERE"),
"Should filter orders in the scalar subquery"
);
assert!(
filtered_sql.contains("FROM filtered_o o2 WHERE"),
"Should filter orders in the EXISTS subquery"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_schema_qualified_tables_and_mixed_references() {
// Query with schema-qualified tables and mixed references
let sql = "
SELECT
u.id,
u.name,
o.order_id,
p.name as product_name -- Changed from schema2.products.name
FROM
schema1.users u
JOIN
schema1.orders o ON u.id = o.user_id
JOIN
schema2.products p ON o.product_id = p.id -- Used alias p here
";
// Create filters for the tables (using just the base table names)
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert("orders".to_string(), "status = 'active'".to_string());
table_filters.insert("products".to_string(), "company_id = 456".to_string());
// Test row level filtering with schema-qualified tables
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Should succeed with schema-qualified tables"
);
let filtered_sql = result.unwrap();
println!("Schema Qualified Filtered SQL: {}\n", filtered_sql);
// Check that all tables are filtered correctly with schema preserved in CTE definition
assert!(
filtered_sql.contains("filtered_u AS (SELECT * FROM schema1.users WHERE tenant_id = 123)"),
"Should include schema in the filtered users CTE"
);
assert!(
filtered_sql.contains("filtered_o AS (SELECT * FROM schema1.orders WHERE status = 'active')"),
"Should include schema in the filtered orders CTE"
);
assert!(
filtered_sql.contains("filtered_p AS (SELECT * FROM schema2.products WHERE company_id = 456)"),
"Should include schema in the filtered products CTE"
);
// Check that references are updated correctly using the aliases
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Should update users reference using alias u"
);
assert!(
filtered_sql.contains("JOIN filtered_o o"),
"Should update orders reference using alias o"
);
assert!(
filtered_sql.contains("JOIN filtered_p p"),
"Should update products reference using alias p"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_nested_subqueries() {
// Query with nested subqueries
let sql = "
SELECT
u.id,
u.name,
(
SELECT COUNT(*)
FROM orders o
WHERE o.user_id = u.id AND o.status IN (
SELECT status_code -- Changed from status
FROM order_statuses os -- Added alias os
WHERE os.is_complete = true -- Used alias os
)
) as completed_orders
FROM
users u
";
// Create filters for tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
table_filters.insert("order_statuses".to_string(), "company_id = 456".to_string());
// Test row level filtering with nested subqueries
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with nested subqueries");
let filtered_sql = result.unwrap();
println!("Nested Subquery Filtered SQL: {}\n", filtered_sql);
// Check all tables are filtered using their aliases
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Should filter main users table"
);
assert!(
filtered_sql.contains("FROM filtered_o o"),
"Should filter orders in subquery"
);
assert!(
filtered_sql.contains("FROM filtered_os os"),
"Should filter order_statuses in nested subquery"
);
}
#[tokio::test]
async fn test_row_level_filtering_preserves_comments() {
// Query with comments
let sql = "
-- Main query to get user data
SELECT
u.id, -- User ID
u.name, -- User name
o.amount /* Order amount */
FROM
users u -- Users table
JOIN
orders o ON u.id = o.user_id -- Join with orders
WHERE
u.status = 'active' -- Only active users
";
// Create filters for tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering with comments
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with comments");
let filtered_sql = result.unwrap();
println!("Comments Filtered SQL: {}\n", filtered_sql);
// Check filters are applied
assert!(
filtered_sql.contains("filtered_u AS"),
"Should add filtered users CTE"
);
assert!(
filtered_sql.contains("filtered_o AS"),
"Should add filtered orders CTE"
);
assert!(
filtered_sql.contains("tenant_id = 123"),
"Should apply users filter"
);
assert!(
filtered_sql.contains("created_at > '2023-01-01'"),
"Should apply orders filter"
);
// Comment preservation depends heavily on the parser; basic check:
assert!(filtered_sql.contains("--") || filtered_sql.contains("/*"), "Should attempt to preserve some comments");
}
#[tokio::test]
async fn test_row_level_filtering_with_limit_offset() {
// Query with LIMIT and OFFSET
let sql = "
SELECT
u.id,
u.name
FROM
users u
ORDER BY
u.created_at DESC
LIMIT 10
OFFSET 20
";
// Create filter for users table
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
// Test row level filtering with LIMIT and OFFSET
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with LIMIT and OFFSET");
let filtered_sql = result.unwrap();
println!("Limit/Offset Filtered SQL: {}\n", filtered_sql);
// Check that filter is applied
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Should filter users table"
);
// Check that LIMIT and OFFSET are preserved
// Note: sqlparser might move these clauses, check for their presence anywhere
assert!(
filtered_sql.contains("LIMIT 10"),
"Should preserve LIMIT clause"
);
assert!(
filtered_sql.contains("OFFSET 20"),
"Should preserve OFFSET clause"
);
}
#[tokio::test]
async fn test_row_level_filtering_with_multiple_filters_per_table() {
// Simple query with two tables
let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
// Define multiple filter conditions for each table
let user_filter = "tenant_id = 123 AND status = 'active'";
let order_filter = "created_at > '2023-01-01' AND amount > 0";
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), user_filter.to_string());
table_filters.insert("orders".to_string(), order_filter.to_string());
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(
result.is_ok(),
"Should succeed with multiple filters per table"
);
let filtered_sql = result.unwrap();
println!("Multi-Filter SQL: {}\n", filtered_sql);
// Check that the filter conditions are correctly applied within the CTEs
assert!(
filtered_sql.contains(&format!("SELECT * FROM users WHERE {}", user_filter)),
"Should apply multiple conditions for users in CTE"
);
assert!(
filtered_sql.contains(&format!("SELECT * FROM orders WHERE {}", order_filter)),
"Should apply multiple conditions for orders in CTE"
);
assert!(filtered_sql.contains("FROM filtered_u u"));
assert!(filtered_sql.contains("JOIN filtered_o o"));
}
#[tokio::test]
async fn test_row_level_filtering_with_complex_expressions() {
// Query with complex expressions in join conditions, select list, and where clause
let sql = "
SELECT
u.id,
CASE WHEN o.amount > 100 THEN 'High Value' ELSE 'Standard' END as order_type,
(SELECT COUNT(*) FROM orders o2 WHERE o2.user_id = u.id) as order_count
FROM
users u
LEFT JOIN
orders o ON u.id = o.user_id AND o.created_at BETWEEN CURRENT_DATE - INTERVAL '30' DAY AND CURRENT_DATE
WHERE
u.created_at > CURRENT_DATE - INTERVAL '1' YEAR
AND (
u.status = 'active'
OR EXISTS (SELECT 1 FROM orders o3 WHERE o3.user_id = u.id AND o3.amount > 1000)
)
";
// Create filters for the tables
let mut table_filters = HashMap::new();
table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
table_filters.insert(
"orders".to_string(),
"created_at > '2023-01-01'".to_string(),
);
// Test row level filtering
let result = apply_row_level_filters(sql.to_string(), table_filters).await;
assert!(result.is_ok(), "Should succeed with complex expressions");
let filtered_sql = result.unwrap();
println!("Complex Expr Filtered SQL: {}\n", filtered_sql);
// Verify that all table references are filtered correctly using aliases
assert!(
filtered_sql.contains("FROM filtered_u u"),
"Should filter main users reference"
);
assert!(
filtered_sql.contains("LEFT JOIN filtered_o o ON"),
"Should filter main orders reference in LEFT JOIN"
);
assert!(
filtered_sql.contains("FROM filtered_o o2 WHERE"),
"Should filter orders in subquery"
);
assert!(
filtered_sql.contains("FROM filtered_o o3 WHERE"),
"Should filter orders in EXISTS subquery"
);
}