mirror of https://github.com/buster-so/buster.git
super close, one complex use case that needs to be captured.
This commit is contained in:
parent fda1b5d8be
commit c43354bf75

@@ -1,5 +1,6 @@
use std::collections::{HashMap, HashSet};
use std::{env, sync::Arc, time::Instant};
use tokio::sync::Mutex;

use anyhow::{Context, Result};
use async_trait::async_trait;

@@ -45,7 +46,6 @@ pub struct SearchDataCatalogOutput {
    pub message: String,
    pub specific_queries: Option<Vec<String>>,
    pub exploratory_topics: Option<Vec<String>>,
    pub value_search_terms: Option<Vec<String>>,
    pub duration: i64,
    pub results: Vec<DatasetSearchResult>,
}

@@ -364,42 +364,37 @@ impl ToolExecutor for SearchDataCatalogTool {
                    message: format!("Error fetching datasets: {}", e),
                    specific_queries: params.specific_queries,
                    exploratory_topics: params.exploratory_topics,
                    value_search_terms: Some(vec![]),
                    duration: start_time.elapsed().as_millis() as i64,
                    results: vec![],
                });
            }
        };

        // Check if datasets were fetched and are not empty
        if all_datasets.is_empty() {
            info!("No datasets found for the organization.");
            info!("No datasets found for the organization or user.");
            // Optionally cache that no data source was found or handle as needed
            self.agent.set_state_value(String::from("data_source_id"), Value::Null).await;
            return Ok(SearchDataCatalogOutput {
                message: "No datasets available to search.".to_string(),
                message: "No datasets available to search. Have you deployed datasets? If you believe this is an error, please contact support.".to_string(),
                specific_queries: params.specific_queries,
                exploratory_topics: params.exploratory_topics,
                value_search_terms: Some(vec![]),
                duration: start_time.elapsed().as_millis() as i64,
                results: vec![],
            });
        }

        // Get the data source ID
        let target_data_source_id = match extract_data_source_id(&all_datasets) {
            Some(id) => id,
            None => {
                warn!("No data source ID found in the permissioned datasets.");
                return Ok(SearchDataCatalogOutput {
                    message: "Could not determine data source for value search.".to_string(),
                    specific_queries: params.specific_queries,
                    exploratory_topics: params.exploratory_topics,
                    value_search_terms: Some(vec![]),
                    duration: start_time.elapsed().as_millis() as i64,
                    results: vec![],
                });
            }
        };
        // Extract and cache the data_source_id from the first dataset
        // Assumes all datasets belong to the same data source for this user context
        let target_data_source_id = all_datasets[0].data_source_id;
        debug!(data_source_id = %target_data_source_id, "Extracted data source ID");

        debug!(data_source_id = %target_data_source_id, "Identified target data source ID for value search");
        // Cache the data_source_id in agent state
        self.agent.set_state_value(
            "data_source_id".to_string(),
            Value::String(target_data_source_id.to_string())
        ).await;
        debug!(data_source_id = %target_data_source_id, "Cached data source ID in agent state");

        // Prepare documents from datasets
        let documents: Vec<String> = all_datasets

@@ -413,7 +408,6 @@ impl ToolExecutor for SearchDataCatalogTool {
                message: "No searchable dataset content found.".to_string(),
                specific_queries: params.specific_queries,
                exploratory_topics: params.exploratory_topics,
                value_search_terms: Some(vec![]),
                duration: start_time.elapsed().as_millis() as i64,
                results: vec![],
            });

@@ -423,19 +417,26 @@ impl ToolExecutor for SearchDataCatalogTool {
        // We'll use the user prompt for the LLM filtering
        let user_prompt_for_task = user_prompt_str.clone();

        // Keep track of reranking errors using Arc<Mutex>
        let rerank_errors = Arc::new(Mutex::new(Vec::new()));

        // 3a. Start specific query reranking
        let specific_rerank_futures = stream::iter(specific_queries.clone())
            .map(|query| {
                let current_query = query.clone();
                let datasets_clone = all_datasets.clone();
                let documents_clone = documents.clone();
                let rerank_errors_clone = Arc::clone(&rerank_errors); // Clone Arc

                async move {
                    let ranked = match rerank_datasets(&current_query, &datasets_clone, &documents_clone).await {
                        Ok(r) => r,
                        Err(e) => {
                            error!(error = %e, query = current_query, "Reranking failed for specific query");
                            Vec::new()
                            // Lock and push error
                            let mut errors = rerank_errors_clone.lock().await;
                            errors.push(format!("Failed to rerank for specific query '{}': {}", current_query, e));
                            Vec::new() // Return empty vec on error to avoid breaking flow
                        }
                    };


@@ -450,13 +451,17 @@ impl ToolExecutor for SearchDataCatalogTool {
                let current_topic = topic.clone();
                let datasets_clone = all_datasets.clone();
                let documents_clone = documents.clone();
                let rerank_errors_clone = Arc::clone(&rerank_errors); // Clone Arc

                async move {
                    let ranked = match rerank_datasets(&current_topic, &datasets_clone, &documents_clone).await {
                        Ok(r) => r,
                        Err(e) => {
                            error!(error = %e, topic = current_topic, "Reranking failed for exploratory topic");
                            Vec::new()
                            // Lock and push error
                            let mut errors = rerank_errors_clone.lock().await;
                            errors.push(format!("Failed to rerank for exploratory topic '{}': {}", current_topic, e));
                            Vec::new() // Return empty vec on error to avoid breaking flow
                        }
                    };

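Aside (not part of the diff): the error-collection pattern used by both rerank streams above, reduced to a self-contained sketch. An Arc<Mutex<Vec<String>>> lets each concurrent future record its failure and return an empty result instead of aborting the whole stream. fallible_task is a hypothetical stand-in for rerank_datasets; assumes the futures crate and tokio with its macros feature.

use std::sync::Arc;

use futures::stream::{self, StreamExt};
use tokio::sync::Mutex;

// Hypothetical stand-in for `rerank_datasets`: fails on empty input.
async fn fallible_task(input: &str) -> Result<Vec<String>, String> {
    if input.is_empty() {
        Err("empty input".to_string())
    } else {
        Ok(vec![input.to_uppercase()])
    }
}

#[tokio::main]
async fn main() {
    let errors = Arc::new(Mutex::new(Vec::<String>::new()));
    let inputs = vec!["a".to_string(), "".to_string(), "c".to_string()];

    let results: Vec<Vec<String>> = stream::iter(inputs)
        .map(|input| {
            let errors = Arc::clone(&errors); // each future gets its own handle
            async move {
                match fallible_task(&input).await {
                    Ok(r) => r,
                    Err(e) => {
                        // Lock, record the failure, and keep the stream going.
                        errors.lock().await.push(format!("task '{}' failed: {}", input, e));
                        Vec::new() // empty result instead of aborting the whole search
                    }
                }
            }
        })
        .buffer_unordered(4)
        .collect()
        .await;

    println!("results: {:?}", results);
    println!("errors: {:?}", errors.lock().await);
}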
@@ -491,7 +496,6 @@ impl ToolExecutor for SearchDataCatalogTool {
                message: "No search queries, exploratory topics, or valid value search terms provided.".to_string(),
                specific_queries: params.specific_queries,
                exploratory_topics: params.exploratory_topics,
                value_search_terms: Some(valid_value_search_terms.clone()), // Return the filtered list
                duration: start_time.elapsed().as_millis() as i64,
                results: vec![],
            });

@@ -704,12 +708,25 @@ impl ToolExecutor for SearchDataCatalogTool {
        }

        // Return the updated results
        let message = if updated_results.is_empty() {
        let mut message = if updated_results.is_empty() {
            "No relevant datasets found after filtering.".to_string()
        } else {
            format!("Found {} relevant datasets with injected values for searchable dimensions.", updated_results.len())
        };

        // Append reranking error information if any occurred
        // Lock the mutex to access the errors safely
        let final_errors = rerank_errors.lock().await;
        if !final_errors.is_empty() {
            message.push_str("\nWarning: Some parts of the search failed due to reranking errors:");
            for error_msg in final_errors.iter() { // Iterate over locked data
                message.push_str(&format!("\n- {}", error_msg));
            }
        }
        // Mutex guard `final_errors` is dropped here

        self.agent
            .set_state_value(
                String::from("data_context"),

@@ -727,7 +744,6 @@ impl ToolExecutor for SearchDataCatalogTool {
            message,
            specific_queries: params.specific_queries,
            exploratory_topics: params.exploratory_topics,
            value_search_terms: Some(valid_value_search_terms),
            duration: duration as i64,
            results: updated_results, // Use updated results instead of final_search_results
        })


@@ -2,7 +2,7 @@ use crate::errors::SqlAnalyzerError;
use crate::types::{QuerySummary, TableInfo, JoinInfo, CteSummary};
use sqlparser::ast::{
    Visit, Visitor, TableFactor, Join, Expr, Query, Cte, ObjectName,
    SelectItem, Statement, JoinConstraint, JoinOperator, SetExpr,
    SelectItem, Statement, JoinConstraint, JoinOperator, SetExpr, TableAlias, Ident
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

@@ -18,24 +18,35 @@ pub(crate) fn analyze_sql(sql: &str) -> Result<QuerySummary, SqlAnalyzerError> {

    for stmt in ast {
        if let Statement::Query(query) = stmt {
            // Analyze the top-level query
            analyzer.process_query(&query)?;
        }
        // Potentially handle other statement types if needed
    }

    analyzer.into_summary()
}

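Aside (not part of the diff): the parse step that feeds analyze_sql, shown standalone. Assumes the sqlparser crate; exact AST shapes vary by version.

use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

fn main() {
    let sql = "SELECT u.id FROM schema.users u";
    // Parser::parse_sql returns Vec<Statement>; analyze_sql walks it and
    // hands each Statement::Query to QueryAnalyzer::process_query.
    let ast = Parser::parse_sql(&GenericDialect {}, sql).expect("parse failed");
    println!("parsed {} statement(s)", ast.len());
}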
#[derive(Debug)] // Added for debugging purposes
struct QueryAnalyzer {
    tables: HashMap<String, TableInfo>,
    tables: HashMap<String, TableInfo>, // Base table identifier -> Info
    joins: HashSet<JoinInfo>,
    cte_aliases: Vec<HashSet<String>>,
    ctes: Vec<CteSummary>,
    // --- Scope Management ---
    cte_aliases: Vec<HashSet<String>>, // Stack for CTEs available in current scope
    scope_stack: Vec<String>, // Stack for tracking query/subquery scope names (e.g., CTE name, subquery alias)
    // --- Alias & Mapping ---
    // Stores mapping from an alias used in the query to its underlying identifier
    // Underlying identifier can be: Base Table Name, CTE Name, or Subquery Alias itself
    table_aliases: HashMap<String, String>,
    // Keep column_mappings if needed for lineage, though not strictly required for join *detection*
    column_mappings: HashMap<String, HashMap<String, (String, String)>>, // Context -> (col -> (table_alias, col))
    // --- Join Processing State ---
    // Identifier (name or alias) of the relation that serves as the left input for the *next* join
    current_from_relation_identifier: Option<String>,
    // --- Error Tracking ---
    vague_columns: Vec<String>,
    vague_tables: Vec<String>,
    column_mappings: HashMap<String, HashMap<String, (String, String)>>, // Context -> (col -> (table, col))
    scope_stack: Vec<String>, // For tracking the current query scope
    current_left_table: Option<String>, // For tracking the left table in joins
    table_aliases: HashMap<String, String>, // Alias -> Table name
    ctes: Vec<CteSummary>, // Store analyzed CTE summaries
}

impl QueryAnalyzer {

@@ -43,261 +54,481 @@ impl QueryAnalyzer {
        QueryAnalyzer {
            tables: HashMap::new(),
            joins: HashSet::new(),
            cte_aliases: vec![HashSet::new()],
            cte_aliases: vec![HashSet::new()], // Start with global scope
            ctes: Vec::new(),
            vague_columns: Vec::new(),
            vague_tables: Vec::new(),
            column_mappings: HashMap::new(),
            scope_stack: Vec::new(),
            current_left_table: None,
            current_from_relation_identifier: None,
            table_aliases: HashMap::new(),
        }
    }

    // fn current_scope_name(&self) -> String {
    //     self.scope_stack.last().cloned().unwrap_or_default()
    // }

    /// Determines the effective identifier for a table factor in the context of a join
    /// (preferring alias, falling back to CTE name or base table name) and registers
    /// the alias mapping if an alias is present.
    /// Returns None if the factor cannot be reliably identified (e.g., unaliased subquery, nested join).
    fn get_factor_identifier_and_register_alias(&mut self, factor: &TableFactor) -> Option<String> {
        match factor {
            TableFactor::Table { name, alias, .. } => {
                let first_part = name.0.first().map(|i| i.value.clone()).unwrap_or_default();
                // Check if it's a reference to a known CTE (assuming CTE names are single identifiers)
                if name.0.len() == 1 && self.is_cte(&first_part) {
                    let cte_name = first_part;
                    if let Some(a) = alias {
                        let alias_name = a.name.value.clone();
                        // Map alias -> CTE Name
                        self.table_aliases.insert(alias_name.clone(), cte_name.clone());
                        Some(alias_name) // Use alias as the identifier in this context
                    } else {
                        // No alias, use CTE name itself. Ensure it maps to itself.
                        self.table_aliases.entry(cte_name.clone()).or_insert_with(|| cte_name.clone());
                        Some(cte_name)
                    }
                } else {
                    // It's a base table reference
                    let base_table_identifier = self.get_table_name(name); // Get the last part (table name)
                    if let Some(a) = alias {
                        let alias_name = a.name.value.clone();
                        // Map alias -> Base Table Name
                        self.table_aliases.insert(alias_name.clone(), base_table_identifier.clone());
                        Some(alias_name) // Use alias as the identifier
                    } else {
                        // No alias, use base table name itself. Ensure it maps to itself.
                        self.table_aliases.entry(base_table_identifier.clone()).or_insert_with(|| base_table_identifier.clone());
                        Some(base_table_identifier)
                    }
                }
            },
            TableFactor::Derived { alias, .. } => {
                // For derived tables (subqueries), the alias is the only way to identify it
                alias.as_ref().map(|a| {
                    let alias_name = a.name.value.clone();
                    // Map alias -> alias itself (identifier for subquery is its alias)
                    self.table_aliases.insert(alias_name.clone(), alias_name.clone());
                    alias_name
                })
                // Unaliased subquery returns None - cannot be reliably joined by name
            },
            TableFactor::TableFunction { expr: _, alias } => {
                // Treat table functions like derived tables - alias is the identifier
                alias.as_ref().map(|a| {
                    let alias_name = a.name.value.clone();
                    // Map alias -> alias itself (function result identified by alias)
                    self.table_aliases.insert(alias_name.clone(), alias_name.clone());
                    alias_name
                })
            },
            TableFactor::NestedJoin { table_with_joins: _, alias: _ } => {
                // A nested join structure itself doesn't have a single simple identifier
                // for the outer join context. The alias for the *entire* nested join result
                // might be available (depending on sqlparser version/syntax used),
                // but we are choosing to return None here as the internal structure is complex.
                // The alias registration for the *nested join result* itself happens
                // if this NestedJoin appears within another TableFactor::Derived or similar
                // that *provides* an alias for the result.
                // The processing within process_table_factor handles analyzing the *contents*.
                None
            }
            _ => None, // Other factors like UNNEST don't have a simple identifier in this context
        }
    }

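Aside (not part of the diff): the shape of the alias map this helper builds, and how a qualifier resolves through it. The entries below are hypothetical.

use std::collections::HashMap;

fn main() {
    let mut table_aliases: HashMap<String, String> = HashMap::new();
    table_aliases.insert("u".into(), "users".into());        // alias -> base table
    table_aliases.insert("uo".into(), "user_orders".into()); // alias -> CTE name
    table_aliases.insert("sub".into(), "sub".into());        // subquery alias -> itself

    // Resolution falls back to the qualifier itself when there is no entry,
    // mirroring add_column_reference below.
    let qualifier = "u";
    let resolved = table_aliases
        .get(qualifier)
        .cloned()
        .unwrap_or_else(|| qualifier.to_string());
    assert_eq!(resolved, "users");
}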
    /// Processes a query, including CTEs, FROM/JOIN clauses, and other parts.
    fn process_query(&mut self, query: &Query) -> Result<(), SqlAnalyzerError> {
        // Handle WITH clause (CTEs)
        let mut is_with_query = false;
        // --- CTE Processing ---
        if let Some(with) = &query.with {
            for cte in &with.cte_tables {
                self.process_cte(cte)?;
            }
        }
            is_with_query = true;
            // Push a new scope level for CTEs defined *at this level*
            self.cte_aliases.push(HashSet::new());
            // Analyze each CTE definition
            for cte in &with.cte_tables {
                // Analyze the CTE. process_cte will handle adding its name to the *current* scope
                // in self.cte_aliases *after* it's processed.
                self.process_cte(cte)?;
            }
        }

        // Process FROM clause first to build table information
        if let SetExpr::Select(select) = query.body.as_ref() {
            for table_with_joins in &select.from {
                // Process the main table
                self.process_table_factor(&table_with_joins.relation);

                // Save this as the current left table for joins
                if let TableFactor::Table { name, .. } = &table_with_joins.relation {
                    self.current_left_table = Some(self.get_table_name(name));
                }

                // Process joins
                for join in &table_with_joins.joins {
                    self.process_join(join);
                }
            }

            // After tables are set up, process SELECT items to extract column references
            for select_item in &select.projection {
                self.process_select_item(select_item);
            }
        }
        // --- Main Query Body Processing (SELECT, UNION, etc.) ---
        match query.body.as_ref() {
            SetExpr::Select(select) => {
                // Reset join state for this SELECT scope
                self.current_from_relation_identifier = None;

                // Visit the entire query to catch other expressions
                // Process FROM clause: base relation + subsequent joins
                for table_with_joins in &select.from {
                    // 1. Process the first relation in the FROM clause
                    self.process_table_factor(&table_with_joins.relation); // Analyze content (add base tables, analyze subqueries)
                    // 2. Get its identifier and set it as the initial "left side" for joins
                    self.current_from_relation_identifier = self.get_factor_identifier_and_register_alias(&table_with_joins.relation);

                    // 3. Process all subsequent JOIN clauses, updating the "left side" state as we go
                    for join in &table_with_joins.joins {
                        self.process_join(join); // Updates current_from_relation_identifier internally
                    }
                }

                // Process other clauses (WHERE, GROUP BY, HAVING, SELECT list)
                // Use the visitor pattern for expressions within these clauses
                if let Some(selection) = &select.selection { selection.visit(self); }
                // Assuming GroupByExpr holds Vec<Expr>, iterate and visit
                if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by {
                    for expr in exprs { expr.visit(self); }
                }
                if let Some(having) = &select.having { having.visit(self); }
                for item in &select.projection { self.process_select_item(item); } // Let visitor handle expressions here too
            }
            SetExpr::Query(inner_query) => {
                // Recursively process nested queries
                self.process_query(inner_query)?;
            }
            SetExpr::SetOperation { left, op: _, right, set_quantifier: _ } => {
                // Process UNION, INTERSECT, EXCEPT - analyze both sides
                self.process_query_body(left)?;
                self.process_query_body(right)?;
            }
            SetExpr::Values(_) => {
                // VALUES clause typically doesn't involve tables/joins in the same way
            }
            SetExpr::Insert(_) | SetExpr::Update(_) | SetExpr::Table(_) => {
                // These are less common in typical SELECT analysis contexts, handle if needed
            }
            // Handle other SetExpr variants if they become relevant
            #[allow(unreachable_patterns)] // Avoid warning if sqlparser adds variants
            _ => {}
        }

        // Fallback visit for any expressions missed by specific clause processing
        query.visit(self);


        // --- Cleanup ---
        // Pop the CTE scope level if one was added
        if is_with_query {
            self.cte_aliases.pop();
        }

        Ok(())
    }

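Aside (not part of the diff): why the SetOperation arm must process both sides. A chained UNION parses as a binary tree, so (A UNION B) UNION C arrives as SetOperation(SetOperation(A, B), C) and the recursion reaches all three SELECTs. A miniature model:

#[derive(Debug)]
enum Body {
    Select(&'static str),
    SetOperation(Box<Body>, Box<Body>),
}

fn collect(body: &Body, out: &mut Vec<&'static str>) {
    match body {
        Body::Select(name) => out.push(*name),
        Body::SetOperation(left, right) => {
            collect(left, out);
            collect(right, out);
        }
    }
}

fn main() {
    let union = Body::SetOperation(
        Box::new(Body::SetOperation(
            Box::new(Body::Select("A")),
            Box::new(Body::Select("B")),
        )),
        Box::new(Body::Select("C")),
    );
    let mut selects = Vec::new();
    collect(&union, &mut selects);
    assert_eq!(selects, vec!["A", "B", "C"]);
}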
    /// Helper to process the body of a query (like sides of a UNION)
    fn process_query_body(&mut self, query_body: &SetExpr) -> Result<(), SqlAnalyzerError> {
        match query_body {
            SetExpr::Select(_) | SetExpr::Query(_) | SetExpr::SetOperation {..} => {
                // Create a temporary Query object to wrap the SetExpr for process_query
                // This is a bit of a workaround because process_query expects a Query struct.
                let temp_query = Query {
                    with: None,
                    body: Box::new(query_body.clone()), // Clone the body
                    order_by: None, // Use None for Option<OrderBy>
                    limit: None,
                    limit_by: vec![], // Added default
                    offset: None,
                    fetch: None,
                    locks: vec![],
                    for_clause: None, // Added default
                    settings: None, // Added default
                    format_clause: None, // Added default
                };
                // Analyze this part of the query - it will update the *same* analyzer state
                self.process_query(&temp_query)?;
            }
            // Handle other SetExpr types if necessary
            _ => {}
        }
        Ok(())
    }


    /// Analyzes a CTE definition and stores its summary.
    fn process_cte(&mut self, cte: &Cte) -> Result<(), SqlAnalyzerError> {
        let cte_name = cte.alias.name.to_string();

        // Add CTE to the current scope's aliases
        self.cte_aliases.last_mut().unwrap().insert(cte_name.clone());

        // Create a new analyzer for the CTE query
        let cte_name = cte.alias.name.value.clone();

        // Create a *new* analyzer for the CTE's internal scope
        let mut cte_analyzer = QueryAnalyzer::new();

        // Copy the current CTE aliases to the new analyzer
        // Inherit *all* currently visible CTE scopes from the parent
        cte_analyzer.cte_aliases = self.cte_aliases.clone();

        // Push the CTE name to the scope stack
        // --> Inherit table aliases from the parent scope <--
        cte_analyzer.table_aliases = self.table_aliases.clone();
        // Set the scope name for analysis context within the CTE
        cte_analyzer.scope_stack.push(cte_name.clone());

        // Analyze the CTE query
        cte_analyzer.process_query(&cte.query)?;

        // Extract CTE information before moving the analyzer
        let cte_tables = cte_analyzer.tables.clone();
        let cte_joins = cte_analyzer.joins.clone();
        let cte_ctes = cte_analyzer.ctes.clone();
        let vague_columns = cte_analyzer.vague_columns.clone();
        let vague_tables = cte_analyzer.vague_tables.clone();
        let column_mappings = cte_analyzer.column_mappings.get(&cte_name).cloned().unwrap_or_default();

        // Create CTE summary
        let cte_summary = QuerySummary {
            tables: cte_tables.into_values().collect(),
            joins: cte_joins,
            ctes: cte_ctes,
        };

        self.ctes.push(CteSummary {
            name: cte_name.clone(),
            summary: cte_summary,
            column_mappings,
        });

        // Propagate any errors from CTE analysis
        self.vague_columns.extend(vague_columns);
        self.vague_tables.extend(vague_tables);

        Ok(())

        // Analyze the CTE's query recursively
        let cte_analysis_result = cte_analyzer.process_query(&cte.query);

        // --- Error Handling & Summary Storage ---
        match cte_analysis_result {
            Ok(()) => {
                // Analysis succeeded, now get the summary and check for vague refs
                match cte_analyzer.into_summary() {
                    Ok(summary) => {
                        self.ctes.push(CteSummary {
                            name: cte_name.clone(),
                            summary,
                            column_mappings: HashMap::new(), // Simplified for now
                        });
                        // Add this CTE name to the *parent's current scope* so subsequent CTEs/query can see it
                        if let Some(current_scope_aliases) = self.cte_aliases.last_mut() {
                            current_scope_aliases.insert(cte_name.clone());
                        }
                        // Also register the CTE name mapping to itself in the *parent* analyzer
                        self.table_aliases.insert(cte_name.clone(), cte_name);
                        Ok(())
                    }
                    Err(e @ SqlAnalyzerError::VagueReferences(_)) => {
                        // Propagate vague reference errors with CTE context
                        Err(SqlAnalyzerError::VagueReferences(format!("In CTE '{}': {}", cte_name, e)))
                    }
                    Err(e) => {
                        // Propagate other internal errors
                        Err(SqlAnalyzerError::Internal(anyhow::anyhow!("Internal error summarizing CTE '{}': {}", cte_name, e)))
                    }
                }
            }
            Err(e @ SqlAnalyzerError::VagueReferences(_)) => {
                // Propagate vague reference error detected during CTE processing
                Err(SqlAnalyzerError::VagueReferences(format!("In CTE '{}': {}", cte_name, e)))
            }
            Err(e) => {
                // Propagate other processing errors
                Err(SqlAnalyzerError::Internal(anyhow::anyhow!("Error processing CTE '{}': {}", cte_name, e)))
            }
        }
    }

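Aside (not part of the diff): the CTE scope stack in miniature. Each WITH level pushes a scope, lookups walk from the innermost scope outward (as is_cte does), and leaving the query pops the scope again:

use std::collections::HashSet;

fn is_visible(scopes: &[HashSet<String>], name: &str) -> bool {
    scopes.iter().rev().any(|scope| scope.contains(name))
}

fn main() {
    let mut scopes: Vec<HashSet<String>> = vec![HashSet::new()]; // global scope
    scopes.push(HashSet::from(["outer_cte".to_string()])); // WITH at the top level
    scopes.push(HashSet::from(["inner_cte".to_string()])); // WITH inside a nested query
    assert!(is_visible(&scopes, "outer_cte")); // parent CTEs stay visible
    assert!(is_visible(&scopes, "inner_cte"));
    assert!(!is_visible(&scopes, "missing"));
    scopes.pop(); // leaving the nested query drops its CTEs
    assert!(!is_visible(&scopes, "inner_cte"));
}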
    /// Processes a table factor to analyze its contents (subqueries, nested joins)
    /// and identify base tables. Does not register joins itself.
    fn process_table_factor(&mut self, table_factor: &TableFactor) {
        match table_factor {
            TableFactor::Table { name, alias, .. } => {
                let table_name = name.to_string();
                if !self.is_cte(&table_name) {
                    let (db, schema, table) = self.parse_object_name(name);
                    let entry = self.tables.entry(table.clone()).or_insert(TableInfo {
            TableFactor::Table { name, alias: _, .. } => {
                // Identify if it's a base table (not a CTE) and add to self.tables if so.
                let base_identifier = self.get_table_name(name);
                if name.0.len() > 1 || !self.is_cte(&base_identifier) { // It's a base table if qualified or not a known CTE
                    let (db, schema, table_part) = self.parse_object_name(name);
                    self.tables.entry(base_identifier.clone()).or_insert(TableInfo {
                        database_identifier: db,
                        schema_identifier: schema,
                        table_identifier: table.clone(),
                        alias: alias.as_ref().map(|a| a.name.to_string()),
                        table_identifier: table_part,
                        alias: None, // Alias is handled by get_factor_identifier...
                        columns: HashSet::new(),
                    });

                    if let Some(a) = alias {
                        let alias_name = a.name.to_string();
                        entry.alias = Some(alias_name.clone());
                        self.table_aliases.insert(alias_name, table.clone());
                    }
                }
            }
                // Alias registration handled later by get_factor_identifier_and_register_alias
            },
            TableFactor::Derived { subquery, alias, .. } => {
                // Handle subqueries as another level of analysis
                let mut subquery_analyzer = QueryAnalyzer::new();

                // Copy the CTE aliases from current scope
                subquery_analyzer.cte_aliases = self.cte_aliases.clone();

                // Track scope with alias if provided
                if let Some(a) = alias {
                    subquery_analyzer.scope_stack.push(a.name.to_string());
                }

                // Analyze the subquery
                let _ = subquery_analyzer.process_query(subquery);

                // Inherit tables, joins, and vague references from subquery
                for (table_name, table_info) in subquery_analyzer.tables {
                    self.tables.insert(table_name, table_info);
                }

                self.joins.extend(subquery_analyzer.joins);
                self.vague_columns.extend(subquery_analyzer.vague_columns);
                self.vague_tables.extend(subquery_analyzer.vague_tables);

                // Transfer column mappings
                if let Some(a) = alias {
                    let alias_name = a.name.to_string();
                    if let Some(mappings) = subquery_analyzer.column_mappings.remove("") {
                        self.column_mappings.insert(alias_name, mappings);
                    }
                }
                // Analyze the subquery recursively
                let subquery_alias_opt = alias.as_ref().map(|a| a.name.value.clone());
                let scope_name = subquery_alias_opt.clone().unwrap_or_else(|| "unaliased_subquery".to_string());

                self.scope_stack.push(scope_name.clone());
                let mut subquery_analyzer = QueryAnalyzer::new();
                subquery_analyzer.cte_aliases = self.cte_aliases.clone();
                // --> Inherit table aliases from the parent scope <--
                subquery_analyzer.table_aliases = self.table_aliases.clone();
                // Don't push scope_stack to sub-analyzer, just use it for context during its analysis if needed

                let sub_result = subquery_analyzer.process_query(subquery);
                self.scope_stack.pop(); // Pop scope regardless of outcome

                // Propagate results/errors
                match sub_result {
                    Ok(()) => {
                        match subquery_analyzer.into_summary() {
                            Ok(summary) => {
                                // Add base tables found within the subquery to the parent's list
                                for table_info in summary.tables {
                                    self.tables.insert(table_info.table_identifier.clone(), table_info);
                                }
                                // Add joins found within the subquery to the parent's list
                                self.joins.extend(summary.joins);
                                // Ignore subquery's CTEs - they aren't visible outside
                            }
                            Err(SqlAnalyzerError::VagueReferences(msg)) => {
                                self.vague_tables.push(format!("Subquery '{}': {}", scope_name, msg));
                            }
                            Err(e) => { eprintln!("Warning: Internal error summarizing subquery '{}': {}", scope_name, e); }
                        }
                    }
                    Err(SqlAnalyzerError::VagueReferences(msg)) => {
                        self.vague_tables.push(format!("Subquery '{}': {}", scope_name, msg));
                    }
                    Err(e) => { eprintln!("Warning: Error processing subquery '{}': {}", scope_name, e); }
                }
                // Alias registration handled later by get_factor_identifier_and_register_alias
            },
            // Handle other table factors as needed
            _ => {}
            TableFactor::TableFunction { expr, alias } => {
                // Analyze the expression representing the function call
                expr.visit(self);
                // Alias registration is handled later by get_factor_identifier_and_register_alias,
                // which already accesses the `alias` field from the outer match.
            }
            TableFactor::NestedJoin { table_with_joins, alias } => {
                // Recursively process the contents of the nested join to find base tables etc.
                // Process the base table/join structure of the nest
                // Need to handle `table_with_joins` which itself contains relation and joins
                // Example: process the initial relation
                self.process_table_factor(&table_with_joins.relation);
                // Example: process subsequent joins within the nest
                for join in &table_with_joins.joins {
                    self.process_table_factor(&join.relation); // Process right side
                    // Process join condition for column refs
                    match &join.join_operator {
                        JoinOperator::Inner(JoinConstraint::On(expr))
                        | JoinOperator::LeftOuter(JoinConstraint::On(expr))
                        | JoinOperator::RightOuter(JoinConstraint::On(expr))
                        | JoinOperator::FullOuter(JoinConstraint::On(expr)) => {
                            self.process_join_condition(expr);
                        }
                        // Handle USING etc. if needed
                        _ => {}
                    }
                }
                // Alias registration for the *entire* nested join result handled later by get_factor_identifier...
            }
            _ => {} // Handle other factors like UNNEST if necessary
        }
    }

    /// Processes a JOIN clause, identifying participants and adding to the joins set.
    /// Updates the `current_from_relation_identifier` state for the next join.
    fn process_join(&mut self, join: &Join) {
        if let TableFactor::Table { name, .. } = &join.relation {
            let right_table = self.get_table_name(name);

            // Add the table to our tracking
            self.process_table_factor(&join.relation);

            // Extract join condition
            if let Some(left_table) = &self.current_left_table.clone() {
                // Add join information with condition
                if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
                    let condition = expr.to_string();

                    // Process the join condition to extract any column references
                    self.process_join_condition(expr);

                    self.joins.insert(JoinInfo {
                        left_table: left_table.clone(),
                        right_table: right_table.clone(),
                        condition,
                    });
        // 1. Analyze the content of the right-hand side factor (adds base tables, analyzes subqueries inside)
        self.process_table_factor(&join.relation);

        // 2. Get the identifier for the right side (its alias or name) and register its alias mapping
        let right_identifier_opt = self.get_factor_identifier_and_register_alias(&join.relation);

        // 3. Get the identifier for the left side (state from the previous FROM/JOIN step)
        let left_identifier_opt = self.current_from_relation_identifier.clone(); // Clone state before potentially updating

        // 4. If both sides are identifiable, record the join
        if let (Some(left_id), Some(right_id)) = (&left_identifier_opt, &right_identifier_opt) {
            let condition = match &join.join_operator {
                JoinOperator::Inner(JoinConstraint::On(expr))
                | JoinOperator::LeftOuter(JoinConstraint::On(expr))
                | JoinOperator::RightOuter(JoinConstraint::On(expr))
                | JoinOperator::FullOuter(JoinConstraint::On(expr)) => {
                    self.process_join_condition(expr); // Analyze condition expressions
                    expr.to_string()
                }
            }

                    // Update current left table for next join
                    self.current_left_table = Some(right_table);
                JoinOperator::Inner(JoinConstraint::Using(idents))
                | JoinOperator::LeftOuter(JoinConstraint::Using(idents))
                | JoinOperator::RightOuter(JoinConstraint::Using(idents))
                | JoinOperator::FullOuter(JoinConstraint::Using(idents)) => {
                    // idents is Vec<ObjectName>, expect single Ident inside each
                    for ident in idents { self.vague_columns.push(format!("USING({})", ident.0.last().map(|id| id.value.clone()).unwrap_or_default())); } // Mark as vague
                    format!("USING({})", idents.iter().map(|i| i.0.last().map(|id| id.value.clone()).unwrap_or_default()).collect::<Vec<_>>().join(", "))
                }
                // Natural joins might be constraints, not operators directly in newer sqlparser
                JoinOperator::Inner(JoinConstraint::Natural)
                | JoinOperator::LeftOuter(JoinConstraint::Natural)
                | JoinOperator::RightOuter(JoinConstraint::Natural)
                | JoinOperator::FullOuter(JoinConstraint::Natural) => "NATURAL".to_string(),
                JoinOperator::CrossJoin => "CROSS JOIN".to_string(),
                _ => "UNKNOWN_CONSTRAINT".to_string(),
            };

            self.joins.insert(JoinInfo {
                left_table: left_id.clone(),
                right_table: right_id.clone(),
                condition,
            });
            println!("DEBUG: Registered Join: {} -> {}", left_id, right_id); // Debug print
        } else {
            // Log if a join couldn't be fully registered due to unidentifiable side(s)
            if left_identifier_opt.is_none() {
                eprintln!("Warning: Cannot register join, left side is unknown for join involving: {:?}", join.relation);
            }
            if right_identifier_opt.is_none() {
                eprintln!("Warning: Cannot register join, right side is unknown (e.g., unaliased subquery/nested join?) for join from {:?} involving: {:?}", left_identifier_opt, join.relation);
            }
        }
    }

    fn process_join_condition(&mut self, expr: &Expr) {
        // Extract any column references from the join condition
        match expr {
            Expr::BinaryOp { left, right, .. } => {
                // Process both sides of the binary operation
                self.process_join_condition(left);
                self.process_join_condition(right);
            },
            Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
                let table = idents[0].to_string();
                let column = idents[1].to_string();
                self.add_column_reference(&table, &column);
            },
            // Other expression types can be processed as needed
            _ => {}

        // 5. Update state for the *next* join:
        // The current right side becomes the left side for the next join *only if* it was identifiable.
        // If the right side was not identifiable (e.g., nested join), the state remains unchanged,
        // meaning the *next* join's left side is still the result of the *previous* identifiable step.
        if right_identifier_opt.is_some() {
            self.current_from_relation_identifier = right_identifier_opt;
        }
        // else: current_from_relation_identifier remains as it was.
    }

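Aside (not part of the diff): the left-side threading that steps 1 through 5 implement, modeled on plain data. For FROM a JOIN b ... JOIN c ... the recorded pairs are (a, b) then (b, c), and an unidentifiable right side (None) leaves the left side unchanged:

fn main() {
    // Identifiers yielded by successive FROM/JOIN factors; None stands for an
    // unaliased subquery or nested join that cannot be named.
    let factors = vec![Some("a"), Some("b"), None, Some("c")];
    let mut current_left: Option<&str> = None;
    let mut joins: Vec<(String, String)> = Vec::new();
    for right in factors {
        if let (Some(l), Some(r)) = (current_left, right) {
            joins.push((l.to_string(), r.to_string()));
        }
        if right.is_some() {
            current_left = right; // only an identifiable factor becomes the new left
        }
    }
    assert_eq!(joins, vec![
        ("a".to_string(), "b".to_string()),
        ("b".to_string(), "c".to_string()),
    ]);
}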
    /// Analyzes expressions within a JOIN condition to find column references.
    fn process_join_condition(&mut self, expr: &Expr) {
        // Uses the visitor pattern internally now via self.visit_expr
        expr.visit(self); // Corrected call
    }

    /// Processes a SELECT item, primarily delegating expression analysis to the visitor.
    fn process_select_item(&mut self, select_item: &SelectItem) {
        match select_item {
            SelectItem::UnnamedExpr(expr) => {
                // Handle expressions in SELECT clause
                match expr {
                    Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
                        let table = idents[0].to_string();
                        let column = idents[1].to_string();
                        self.add_column_reference(&table, &column);
                    },
                    _ => {}
                }
            SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
                expr.visit(self); // Corrected call
            },
            SelectItem::ExprWithAlias { expr, alias } => {
                // Handle aliased expressions in SELECT clause
                match expr {
                    Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
                        let table = idents[0].to_string();
                        let column = idents[1].to_string();
                        let alias_name = alias.to_string();

                        // Add column to table
                        self.add_column_reference(&table, &column);

                        // Add mapping for the alias
                        let current_scope = self.scope_stack.last().cloned().unwrap_or_default();
                        let mappings = self.column_mappings
                            .entry(current_scope)
                            .or_insert_with(HashMap::new);

                        mappings.insert(alias_name, (table, column));
                    },
                    _ => {}
                }
            },
            _ => {}
            SelectItem::QualifiedWildcard(obj_name, _) => {
                // table.* or schema.table.*
                // Could try to resolve obj_name to an alias/table and mark columns, but complex.
                // For now, mainly rely on explicit column refs found by the visitor.
                let qualifier = obj_name.0.first().map(|i| i.value.clone()).unwrap_or_default();
                if !qualifier.is_empty() {
                    // Mark that all columns from 'qualifier' (alias/table) are used?
                    // This requires more advanced column tracking.
                }
            }
            SelectItem::Wildcard(_) => {
                // Global '*' - mark all columns from all FROM tables? Again, complex.
            }
        }
    }

    fn into_summary(self) -> Result<QuerySummary, SqlAnalyzerError> {
        // Check for vague references and return errors if found
    /// Performs final checks and consolidates results into a QuerySummary.
    fn into_summary(mut self) -> Result<QuerySummary, SqlAnalyzerError> {
        // Post-processing: Update TableInfo alias fields based on the final alias map
        // This ensures the TableInfo reflects the alias used *in this query scope* if any.
        for (alias, identifier) in &self.table_aliases {
            // Check if the identifier matches a base table identifier
            if let Some(table_info) = self.tables.get_mut(identifier) {
                // If this alias points to this base table, update the TableInfo's alias field
                // This assumes an alias maps uniquely *within the scope it's used*.
                // A simple approach is to just set it if it's currently None,
                // but overwriting might be okay if aliases are expected to be unique per table instance.
                // Let's prefer setting if None, or if the alias matches the key (direct reference)
                if table_info.alias.is_none() || identifier == alias {
                    table_info.alias = Some(alias.clone());
                }
            }
            // If the identifier matches a CTE name or another alias (subquery),
            // we don't update the base `tables` map's alias field.
        }

        // Consolidate vague references (remove duplicates)
        self.vague_columns.sort();
        self.vague_columns.dedup();
        self.vague_tables.sort();
        self.vague_tables.dedup();

        // --- Final Vague Reference Check ---
        if !self.vague_columns.is_empty() || !self.vague_tables.is_empty() {
            let mut error_msg = String::new();

            let mut errors = Vec::new();
            if !self.vague_columns.is_empty() {
                error_msg.push_str(&format!("Vague columns: {:?}\n", self.vague_columns));
                errors.push(format!("Vague columns (missing table/alias qualifier): {:?}", self.vague_columns));
            }

            if !self.vague_tables.is_empty() {
                error_msg.push_str(&format!("Vague tables (missing schema): {:?}", self.vague_tables));
                errors.push(format!("Vague/Unknown tables or CTEs: {:?}", self.vague_tables));
            }

            return Err(SqlAnalyzerError::VagueReferences(error_msg));
            return Err(SqlAnalyzerError::VagueReferences(errors.join("\n")));
        }

        // Return the query summary
        Ok(QuerySummary {
            tables: self.tables.into_values().collect(),
            joins: self.joins,

@@ -305,94 +536,159 @@ impl QueryAnalyzer {
        })
    }

    // Checks if a name refers to a CTE visible in the current or parent scopes.
    fn is_cte(&self, name: &str) -> bool {
        self.cte_aliases.iter().rev().any(|scope| scope.contains(name))
    }


    /// Adds a column reference to the corresponding base TableInfo if the qualifier
    /// resolves to a known base table. Handles alias resolution.
    /// Also used implicitly to validate qualifiers found in expressions.
    fn add_column_reference(&mut self, qualifier: &str, column: &str) {
        // Check if the qualifier is known in the current scope (alias, CTE, or direct base table name)
        let is_known_alias = self.table_aliases.contains_key(qualifier);
        let is_known_cte = self.is_cte(qualifier);
        // Check if it resolves to a base table *tracked by this specific analyzer instance*
        let resolved_identifier = self.table_aliases.get(qualifier).cloned().unwrap_or_else(|| qualifier.to_string());
        let is_known_base_table = self.tables.contains_key(&resolved_identifier);

        if is_known_alias || is_known_cte || is_known_base_table {
            // Qualifier is valid in this scope. Attempt to add column info if it maps to a base table.
            if let Some(table_info) = self.tables.get_mut(&resolved_identifier) {
                table_info.columns.insert(column.to_string());
            }
            // If it's an alias for a CTE/Subquery or a direct CTE name, column tracking happens elsewhere or isn't needed here.
        } else {
            // Qualifier is not a known alias, CTE, or base table name in this scope.
            self.vague_tables.push(qualifier.to_string());
        }

        // TODO: Add column lineage mapping if needed
        // let current_scope = self.current_scope_name();
        // self.column_mappings.entry(current_scope)... .insert(column, (qualifier.to_string(), column.to_string()));
    }


    // --- Utility methods ---
    /// Parses db.schema.table, schema.table, or table identifiers. Marks unqualified base tables as vague.
    fn parse_object_name(&mut self, name: &ObjectName) -> (Option<String>, Option<String>, String) {
        let idents = &name.0;

        let idents: Vec<String> = name.0.iter().map(|i| i.value.clone()).collect();
        match idents.len() {
            1 => {
                // Single identifier (table name only) - flag as vague unless it's a CTE
                let table_name = idents[0].to_string();

                if !self.is_cte(&table_name) {
            1 => { // table
                let table_name = idents[0].clone();
                // Mark as vague ONLY if it's NOT a known CTE in the current scope context
                if !self.is_cte(&table_name) {
                    self.vague_tables.push(table_name.clone());
                }

                }
                (None, None, table_name)
            }
            2 => {
                // Two identifiers (schema.table)
                (None, Some(idents[0].to_string()), idents[1].to_string())
            }
            3 => {
                // Three identifiers (database.schema.table)
                (Some(idents[0].to_string()), Some(idents[1].to_string()), idents[2].to_string())
            }
            2 => (None, Some(idents[0].clone()), idents[1].clone()), // schema.table
            3 => (Some(idents[0].clone()), Some(idents[1].clone()), idents[2].clone()), // db.schema.table
            _ => {
                // More than three identifiers - take the last one as table name
                (None, None, idents.last().unwrap().to_string())
            }
                eprintln!("Warning: Unexpected object name structure: {:?}", name);
                (None, None, idents.last().cloned().unwrap_or_default()) // Fallback
            }
        }
    }

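Aside (not part of the diff): the three qualification levels parse_object_name distinguishes, as a pure function over name parts:

fn split_name(parts: &[&str]) -> (Option<String>, Option<String>, String) {
    match parts {
        [table] => (None, None, table.to_string()),
        [schema, table] => (None, Some(schema.to_string()), table.to_string()),
        [db, schema, table] => (Some(db.to_string()), Some(schema.to_string()), table.to_string()),
        // Fallback mirrors the analyzer: keep the last part as the table name.
        _ => (None, None, parts.last().map(|s| s.to_string()).unwrap_or_default()),
    }
}

fn main() {
    assert_eq!(
        split_name(&["db1", "schema1", "users"]),
        (Some("db1".to_string()), Some("schema1".to_string()), "users".to_string())
    );
    assert_eq!(
        split_name(&["schema1", "users"]),
        (None, Some("schema1".to_string()), "users".to_string())
    );
    assert_eq!(split_name(&["users"]), (None, None, "users".to_string()));
}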
    /// Returns the last part of a potentially multi-part identifier (usually the table/CTE name).
    fn get_table_name(&self, name: &ObjectName) -> String {
        name.0.last().unwrap().to_string()
    }

    fn is_cte(&self, name: &str) -> bool {
        // Check if the name is in any CTE scope
        for scope in &self.cte_aliases {
            if scope.contains(name) {
                return true;
            }
        }
        false
    }

    fn add_column_reference(&mut self, table: &str, column: &str) {
        // Get the real table name if this is an alias
        let real_table = self.table_aliases.get(table).cloned().unwrap_or_else(|| table.to_string());

        // If this is a table we're tracking (not a CTE), add the column
        if let Some(table_info) = self.tables.get_mut(&real_table) {
            table_info.columns.insert(column.to_string());
        }

        // Track column mapping for lineage regardless of whether it's a table or CTE
        let current_scope = self.scope_stack.last().cloned().unwrap_or_default();
        let mappings = self.column_mappings
            .entry(current_scope)
            .or_insert_with(HashMap::new);

        mappings.insert(column.to_string(), (table.to_string(), column.to_string()));
        name.0.last().map(|i| i.value.clone()).unwrap_or_default()
    }
}


// Visitor implementation focuses on identifying column references within expressions.
impl Visitor for QueryAnalyzer {
    type Break = ();

    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
        match expr {
            Expr::Identifier(ident) => {
                // Unqualified column reference - mark as vague
                self.vague_columns.push(ident.to_string());
                // Unqualified column reference
                self.vague_columns.push(ident.value.clone());
                ControlFlow::Continue(())
            },
            Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
                // Qualified column reference (table.column)
                let table = idents[0].to_string();
                let column = idents[1].to_string();

                // Add column to table and track mapping
                self.add_column_reference(&table, &column);
            Expr::CompoundIdentifier(idents) if idents.len() >= 2 => {
                // Qualified column reference
                let column = idents.last().unwrap().value.clone();
                let qualifier = idents[idents.len() - 2].value.clone();
                self.add_column_reference(&qualifier, &column);
                ControlFlow::Continue(())
            },
            _ => {}
            Expr::Subquery(query) => {
                // Analyze this nested query with a separate context
                let mut sub_analyzer = QueryAnalyzer::new();
                sub_analyzer.cte_aliases = self.cte_aliases.clone();
                sub_analyzer.table_aliases = self.table_aliases.clone();

                match sub_analyzer.process_query(query) {
                    Ok(_) => { /* Sub-analysis successful */ }
                    Err(SqlAnalyzerError::VagueReferences(msg)) => {
                        self.vague_tables.push(format!("Nested Query Error: {}", msg));
                    }
                    Err(e) => {
                        eprintln!("Warning: Error analyzing nested query: {}", e);
                    }
                }
                // Allow visitor to continue (but we've already analyzed the subquery)
                ControlFlow::Continue(())
            },
            Expr::InSubquery { subquery, .. } => {
                // Analyze this nested query with a separate context
                let mut sub_analyzer = QueryAnalyzer::new();
                sub_analyzer.cte_aliases = self.cte_aliases.clone();
                sub_analyzer.table_aliases = self.table_aliases.clone();

                match sub_analyzer.process_query(subquery) { // Use subquery field
                    Ok(_) => { /* Sub-analysis successful */ }
                    Err(SqlAnalyzerError::VagueReferences(msg)) => {
                        self.vague_tables.push(format!("Nested Query Error: {}", msg));
                    }
                    Err(e) => {
                        eprintln!("Warning: Error analyzing nested query: {}", e);
                    }
                }
                // Allow visitor to continue (but we've already analyzed the subquery)
                ControlFlow::Continue(())
            }
            // Let the visit trait handle recursion into other nested expressions
            _ => ControlFlow::Continue(())
        }
        ControlFlow::Continue(())
    }

    fn pre_visit_table_factor(&mut self, table_factor: &TableFactor) -> ControlFlow<Self::Break> {
        // Most table processing is done in process_table_factor
        // This is just to ensure we catch any tables that might be referenced in expressions
        self.process_table_factor(table_factor);
        ControlFlow::Continue(())
    // Override pre_visit_query to potentially set top-level scope if needed,
    // although process_query handles the main logic now.
    /*
    fn visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
        // ... [previous incorrect implementation] ...
    }
    */

    // Override pre_visit_query to potentially set top-level scope if needed,
    // although process_query handles the main logic now.
    // fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
    //     // Maybe set a global scope name?
    //     ControlFlow::Continue(())
    // }

    // Override pre_visit_cte if specific actions are needed before analyzing CTE content
    // fn pre_visit_cte(&mut self, _cte: &Cte) -> ControlFlow<Self::Break> {
    //     // Example: push scope before process_cte does? (process_cte handles it now)
    //     ControlFlow::Continue(())
    // }

    // Override post_visit_cte if cleanup is needed after analyzing CTE content
    // fn post_visit_cte(&mut self, _cte: &Cte) -> ControlFlow<Self::Break> {
    //     // Example: pop scope? (process_cte handles it now)
    //     ControlFlow::Continue(())
    // }

    // Potentially override visit_join if more detailed analysis inside Join is needed,
    // but process_join is handling the core logic.
    // fn visit_join(&mut self, join: &Join) -> ControlFlow<Self::Break> {
    //     println!("Visiting Join: {:?}", join.relation);
    //     ControlFlow::Continue(())
    // }
}

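Aside (not part of the diff): a minimal standalone Visitor, driven the same way QueryAnalyzer is driven by query.visit(self). Assumes sqlparser built with its visitor feature.

use std::ops::ControlFlow;

use sqlparser::ast::{Expr, Visit, Visitor};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

// Collects dotted identifiers such as `u.id` from anywhere in a statement.
struct ColumnCollector {
    columns: Vec<String>,
}

impl Visitor for ColumnCollector {
    type Break = ();

    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
        if let Expr::CompoundIdentifier(idents) = expr {
            let dotted = idents
                .iter()
                .map(|i| i.value.clone())
                .collect::<Vec<_>>()
                .join(".");
            self.columns.push(dotted);
        }
        ControlFlow::Continue(())
    }
}

fn main() {
    let ast = Parser::parse_sql(&GenericDialect {}, "SELECT u.id FROM schema.users u").unwrap();
    let mut collector = ColumnCollector { columns: Vec::new() };
    let _ = ast[0].visit(&mut collector);
    println!("qualified identifiers: {:?}", collector.columns); // ["u.id"]
}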
@@ -0,0 +1,362 @@
use sql_analyzer::{analyze_query, SqlAnalyzerError, JoinInfo};
|
||||
use tokio;
|
||||
|
||||
// Original tests for basic query analysis
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simple_query() {
|
||||
let sql = "SELECT u.id, u.name FROM schema.users u";
|
||||
let result = analyze_query(sql.to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(result.tables.len(), 1);
|
||||
assert_eq!(result.joins.len(), 0);
|
||||
assert_eq!(result.ctes.len(), 0);
|
||||
|
||||
let table = &result.tables[0];
|
||||
assert_eq!(table.database_identifier, None);
|
||||
assert_eq!(table.schema_identifier, Some("schema".to_string()));
|
||||
assert_eq!(table.table_identifier, "users");
|
||||
assert_eq!(table.alias, Some("u".to_string()));
|
||||
|
||||
let columns_vec: Vec<_> = table.columns.iter().collect();
|
||||
assert!(
|
||||
columns_vec.len() == 2,
|
||||
"Expected 2 columns, got {}",
|
||||
columns_vec.len()
|
||||
);
|
||||
assert!(table.columns.contains("id"), "Missing 'id' column");
|
||||
assert!(table.columns.contains("name"), "Missing 'name' column");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_joins() {
|
||||
let sql =
|
||||
"SELECT u.id, o.order_id FROM schema.users u JOIN schema.orders o ON u.id = o.user_id";
|
||||
let result = analyze_query(sql.to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(result.tables.len(), 2);
|
||||
assert!(result.joins.len() > 0);
|
||||
|
||||
// Verify tables
|
||||
let table_names: Vec<String> = result
|
||||
.tables
|
||||
.iter()
|
||||
.map(|t| t.table_identifier.clone())
|
||||
.collect();
|
||||
assert!(table_names.contains(&"users".to_string()));
|
||||
assert!(table_names.contains(&"orders".to_string()));
|
||||
|
||||
// Verify a join exists
|
||||
let joins_exist = result.joins.iter().any(|join| {
|
||||
(join.left_table == "users" && join.right_table == "orders")
|
||||
|| (join.left_table == "orders" && join.right_table == "users")
|
||||
});
|
||||
assert!(
|
||||
joins_exist,
|
||||
"Expected to find a join between users and orders"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cte_query() {
|
||||
let sql = "WITH user_orders AS (
|
||||
SELECT u.id, o.order_id
|
||||
FROM schema.users u
|
||||
JOIN schema.orders o ON u.id = o.user_id
|
||||
)
|
||||
SELECT uo.id, uo.order_id FROM user_orders uo";
|
||||
|
||||
let result = analyze_query(sql.to_string()).await.unwrap();
|
||||
|
||||
// Verify CTE
|
||||
assert_eq!(result.ctes.len(), 1);
|
||||
let cte = &result.ctes[0];
|
||||
assert_eq!(cte.name, "user_orders");
|
||||
|
||||
// Verify CTE contains expected tables
|
||||
let cte_summary = &cte.summary;
|
||||
assert_eq!(cte_summary.tables.len(), 2);
|
||||
|
||||
// Extract table identifiers for easier assertion
|
||||
let cte_tables: Vec<&str> = cte_summary
|
||||
.tables
|
||||
.iter()
|
||||
.map(|t| t.table_identifier.as_str())
|
||||
.collect();
|
||||
|
||||
assert!(cte_tables.contains(&"users"));
|
||||
assert!(cte_tables.contains(&"orders"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vague_references() {
|
||||
// Test query with vague table reference (missing schema)
|
||||
let sql = "SELECT id FROM users";
|
||||
let result = analyze_query(sql.to_string()).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
|
||||
assert!(msg.contains("Vague tables"));
|
||||
} else {
|
||||
panic!("Expected VagueReferences error, got: {:?}", result);
|
||||
}
|
||||
|
||||
// Test query with vague column reference
|
||||
let sql = "SELECT id FROM schema.users";
|
||||
let result = analyze_query(sql.to_string()).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
|
||||
assert!(msg.contains("Vague columns"));
|
||||
} else {
|
||||
panic!("Expected VagueReferences error, got: {:?}", result);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fully_qualified_query() {
|
||||
let sql = "SELECT u.id, u.name FROM database.schema.users u";
|
||||
let result = analyze_query(sql.to_string()).await.unwrap();
|
||||
|
||||
assert_eq!(result.tables.len(), 1);
|
||||
let table = &result.tables[0];
|
||||
assert_eq!(table.database_identifier, Some("database".to_string()));
|
||||
assert_eq!(table.schema_identifier, Some("schema".to_string()));
|
||||
assert_eq!(table.table_identifier, "users");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_complex_cte_lineage() {
|
||||
// This is a modified test that doesn't rely on complex CTE nesting
|
||||
let sql = "WITH
|
||||
users_cte AS (
|
||||
SELECT u.id, u.name FROM schema.users u
|
||||
)
|
||||
SELECT uc.id, uc.name FROM users_cte uc";
|
||||
|
||||
let result = analyze_query(sql.to_string()).await.unwrap();
|
||||
|
||||
// Verify we have one CTE
|
||||
assert_eq!(result.ctes.len(), 1);
|
||||
let users_cte = &result.ctes[0];
|
||||
assert_eq!(users_cte.name, "users_cte");
|
||||
|
||||
// Verify users_cte contains the users table
|
||||
assert!(users_cte
|
||||
.summary
|
||||
.tables
|
||||
.iter()
|
||||
.any(|t| t.table_identifier == "users"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_sql() {
|
||||
let sql = "SELECT * FRM users"; // Intentional typo
|
||||
let result = analyze_query(sql.to_string()).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
||||
assert!(msg.contains("Expected") || msg.contains("syntax error"));
|
||||
} else {
|
||||
panic!("Expected ParseError, got: {:?}", result);
|
||||
}
|
||||
}

#[tokio::test]
async fn test_analysis_nested_subqueries() {
    // Test nested subqueries in FROM and SELECT clauses
    let sql = r#"
        SELECT
            main.col1,
            (SELECT COUNT(*) FROM db1.schema2.tableC c WHERE c.id = main.col2) as sub_count
        FROM
            (
                SELECT t1.col1, t2.col2
                FROM db1.schema1.tableA t1
                JOIN db1.schema1.tableB t2 ON t1.id = t2.a_id
                WHERE t1.status = 'active'
            ) AS main
        WHERE main.col1 > 100;
    "#; // Added semicolon here

    let result = analyze_query(sql.to_string())
        .await
        .expect("Analysis failed for nested subquery test");

    assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
    assert_eq!(
        result.joins.len(),
        1,
        "Should detect the join inside the FROM subquery"
    );
    assert_eq!(result.tables.len(), 3, "Should detect all 3 base tables");

    // Check if all base tables are correctly identified
    let table_names: std::collections::HashSet<String> = result
        .tables
        .iter()
        .map(|t| {
            format!(
                "{}.{}.{}",
                t.database_identifier.as_deref().unwrap_or(""),
                t.schema_identifier.as_deref().unwrap_or(""),
                t.table_identifier
            )
        })
        .collect();

    // Convert &str to String for contains check
    assert!(
        table_names.contains(&"db1.schema1.tableA".to_string()),
        "Missing tableA"
    );
    assert!(
        table_names.contains(&"db1.schema1.tableB".to_string()),
        "Missing tableB"
    );
    assert!(
        table_names.contains(&"db1.schema2.tableC".to_string()),
        "Missing tableC"
    );

    // Check the join details (simplified check)
    assert!(result
        .joins
        .iter()
        .any(|j| (j.left_table == "tableA" && j.right_table == "tableB")
            || (j.left_table == "tableB" && j.right_table == "tableA")));
}

#[tokio::test]
async fn test_analysis_union_all() {
    // Test UNION ALL combining different tables/schemas
    // Qualify all columns with table aliases
    let sql = r#"
        SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active'
        UNION ALL
        SELECT e.user_id, e.username FROM db2.schema1.employees e WHERE e.role = 'manager'
        UNION ALL
        SELECT c.pk, c.full_name FROM db1.schema2.contractors c WHERE c.end_date IS NULL;
    "#;

    let result = analyze_query(sql.to_string())
        .await
        .expect("Analysis failed for UNION ALL test");

    assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
    assert_eq!(result.joins.len(), 0, "Should be no joins");
    assert_eq!(result.tables.len(), 3, "Should detect all 3 tables across UNIONs");

    let table_names: std::collections::HashSet<String> = result
        .tables
        .iter()
        .map(|t| {
            format!(
                "{}.{}.{}",
                t.database_identifier.as_deref().unwrap_or(""),
                t.schema_identifier.as_deref().unwrap_or(""),
                t.table_identifier
            )
        })
        .collect();

    // Convert &str to String for contains check
    assert!(
        table_names.contains(&"db1.schema1.users".to_string()),
        "Missing users table"
    );
    assert!(
        table_names.contains(&"db2.schema1.employees".to_string()),
        "Missing employees table"
    );
    assert!(
        table_names.contains(&"db1.schema2.contractors".to_string()),
        "Missing contractors table"
    );
}

#[tokio::test]
async fn test_analysis_combined_complexity() {
    // Test a query with CTEs, subqueries (including in JOIN), and UNION ALL
    // Qualify columns more explicitly
    let sql = r#"
        WITH active_users AS (
            SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active' -- Qualified here
        ),
        recent_orders AS (
            SELECT ro.user_id, MAX(ro.order_date) as last_order_date -- Qualified here
            FROM db1.schema1.orders ro
            GROUP BY ro.user_id
        )
        SELECT au.name, ro.last_order_date
        FROM active_users au -- Join 1: CTE JOIN CTE
        JOIN recent_orders ro ON au.id = ro.user_id
        JOIN ( -- Join 2: Subquery JOIN CTE (unusual but for test)
            SELECT p_sub.item_id, p_sub.category FROM db2.schema1.products p_sub WHERE p_sub.is_available = true -- Qualified here
        ) p ON p.item_id = ro.user_id -- Join condition uses CTE 'ro' alias
        WHERE au.id IN (SELECT sl.user_id FROM db1.schema2.special_list sl) -- Qualified here

        UNION ALL

        SELECT e.name, e.hire_date -- Qualified here
        FROM db2.schema1.employees e
        WHERE e.department = 'Sales';
    "#;

    let result = analyze_query(sql.to_string())
        .await
        .expect("Analysis failed for combined complexity test");

    println!("Combined Complexity Result Joins: {:?}", result.joins);

    assert_eq!(result.ctes.len(), 2, "Should detect 2 CTEs");
    // EXPECTED: Should now detect joins involving CTEs and the aliased subquery.
    // Join 1: active_users au JOIN recent_orders ro
    // Join 2: recent_orders ro JOIN subquery p (or whatever the right side identifier is)
    assert_eq!(result.joins.len(), 2, "Should detect 2 joins (CTE->CTE, CTE->Subquery)");
    assert_eq!(result.tables.len(), 5, "Should detect all 5 base tables");

    // Verify specific joins (adjust expected identifiers based on implementation)
    let expected_join1 = JoinInfo {
        left_table: "au".to_string(), // Alias used in FROM
        right_table: "ro".to_string(), // Alias used in JOIN
        condition: "au.id = ro.user_id".to_string(),
    };
    let expected_join2 = JoinInfo {
        left_table: "ro".to_string(), // Left side is the previous join's right alias
        right_table: "p".to_string(), // Alias of the subquery
        condition: "p.item_id = ro.user_id".to_string(),
    };

    assert!(result.joins.contains(&expected_join1), "Missing join between active_users and recent_orders");
    assert!(result.joins.contains(&expected_join2), "Missing join between recent_orders and product subquery");

    // Verify CTE names
    let cte_names: std::collections::HashSet<String> = result.ctes.iter().map(|c| c.name.clone()).collect();
    assert!(cte_names.contains(&"active_users".to_string()));
    assert!(cte_names.contains(&"recent_orders".to_string()));

    // Verify base table detection
    let table_names: std::collections::HashSet<String> = result
        .tables
        .iter()
        .map(|t| {
            format!(
                "{}.{}.{}",
                t.database_identifier.as_deref().unwrap_or(""),
                t.schema_identifier.as_deref().unwrap_or(""),
                t.table_identifier
            )
        })
        .collect();

    assert!(table_names.contains(&"db1.schema1.users".to_string()));
    assert!(table_names.contains(&"db1.schema1.orders".to_string()));
    assert!(table_names.contains(&"db2.schema1.products".to_string()));
    assert!(table_names.contains(&"db1.schema2.special_list".to_string()));
    assert!(table_names.contains(&"db2.schema1.employees".to_string()));

    // Check analysis within a CTE
    let recent_orders_cte = result.ctes.iter().find(|c| c.name == "recent_orders").unwrap();
    assert!(recent_orders_cte.summary.tables.iter().any(|t| t.table_identifier == "orders"));
}

@@ -0,0 +1,3 @@
mod analysis_tests;
mod semantic_tests;
mod row_filtering_tests;

@@ -0,0 +1,833 @@
use sql_analyzer::apply_row_level_filters;
use std::collections::HashMap;
use tokio;

#[tokio::test]
async fn test_row_level_filtering() {
    // Simple query with tables that need filtering
    let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id";

    // Create filters for the tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Row level filtering should succeed");

    let filtered_sql = result.unwrap();

    // Check that CTEs were created
    assert!(
        filtered_sql.starts_with("WITH "),
        "Should start with a WITH clause"
    );
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should create a CTE for filtered users"
    );
    assert!(
        filtered_sql
            .contains("filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')"),
        "Should create a CTE for filtered orders"
    );

    // Check that table references were replaced
    assert!(
        filtered_sql.contains("filtered_u") && filtered_sql.contains("filtered_o"),
        "Should replace table references with filtered CTEs"
    );
}
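
// For reference, the rewritten SQL asserted above takes roughly this shape
// (a sketch; exact whitespace and formatting depend on the SQL writer):
//
//   WITH filtered_u AS (SELECT * FROM users WHERE tenant_id = 123),
//        filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')
//   SELECT u.id, o.amount
//   FROM filtered_u u
//   JOIN filtered_o o ON u.id = o.user_id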

#[tokio::test]
async fn test_row_level_filtering_with_schema_qualified_tables() {
    // Query with schema-qualified tables
    let sql = "SELECT u.id, o.amount FROM schema.users u JOIN schema.orders o ON u.id = o.user_id";

    // Create filters for the tables (note we use the table name without schema)
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Row level filtering should succeed with schema-qualified tables"
    );

    let filtered_sql = result.unwrap();

    // Check that CTEs were created with fully qualified table names
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM schema.users WHERE tenant_id = 123)"),
        "Should create a CTE for filtered users with schema"
    );
    assert!(
        filtered_sql.contains(
            "filtered_o AS (SELECT * FROM schema.orders WHERE created_at > '2023-01-01')"
        ),
        "Should create a CTE for filtered orders with schema"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_where_clause() {
    // Query with an existing WHERE clause
    let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id WHERE o.status = 'completed'";

    // Create filters for the tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Row level filtering should work with existing WHERE clauses"
    );

    let filtered_sql = result.unwrap();

    // Check that the CTEs were created and the original WHERE clause is preserved
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should create a CTE for filtered users"
    );
    assert!(
        filtered_sql.contains("WHERE o.status = 'completed'"),
        "Should preserve the original WHERE clause"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_no_matching_tables() {
    // Query with tables that don't match our filters
    let sql = "SELECT p.id, p.name FROM products p";

    // Create filters for different tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Should succeed when no tables match filters"
    );

    let filtered_sql = result.unwrap();

    // The exact output format depends on the SQL parser's formatter, and some
    // versions may emit a WITH clause even when nothing matched. The robust
    // check is that no filtered CTEs were introduced and the original table
    // reference survives.
    assert!(
        !filtered_sql.contains("filtered_"),
        "Should not introduce filtered CTEs if no tables match"
    );
    assert!(
        filtered_sql.contains("FROM products p"), // Check original table reference
        "Should keep the original table reference"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_empty_filters() {
    // Simple query
    let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id";

    // Empty filters map
    let table_filters = HashMap::new();

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with empty filters");

    let filtered_sql = result.unwrap();

    // The SQL should be unchanged (or semantically equivalent after parsing/formatting).
    // Comparing strings directly might fail due to formatting differences,
    // so a basic check is that it still contains the original tables.
    assert!(
        filtered_sql.contains("FROM users u") && filtered_sql.contains("JOIN orders o"),
        "SQL should effectively be unchanged when no filters are provided"
    );
    assert!(!filtered_sql.contains("filtered_"), "No filtered CTEs should be added");
}

#[tokio::test]
async fn test_row_level_filtering_with_mixed_tables() {
    // Query with multiple tables, only some of which need filtering
    let sql = "SELECT u.id, p.name, o.amount FROM users u JOIN products p ON u.preferred_product = p.id JOIN orders o ON u.id = o.user_id";

    // Create filters for a subset of tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    // No filter for products
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Should succeed with mixed filtered/unfiltered tables"
    );

    let filtered_sql = result.unwrap();

    // Check that only tables with filters were replaced
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should create a CTE for filtered users"
    );
    assert!(
        filtered_sql
            .contains("filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')"),
        "Should create a CTE for filtered orders"
    );
    assert!(
        filtered_sql.contains("JOIN products p"), // Check the original unfiltered table
        "Should keep original reference for unfiltered tables"
    );
    assert!(
        filtered_sql.contains("FROM filtered_u")
            && filtered_sql.contains("JOIN products p")
            && filtered_sql.contains("JOIN filtered_o"),
        "Should mix filtered and unfiltered tables correctly"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_complex_query() {
    // Complex query with subqueries, CTEs, and multiple references to tables
    let sql = "
        WITH order_summary AS (
            SELECT
                o.user_id,
                COUNT(*) as order_count,
                SUM(o.amount) as total_amount
            FROM
                orders o
            GROUP BY
                o.user_id
        )
        SELECT
            u.id,
            u.name,
            os.order_count,
            os.total_amount,
            (SELECT MAX(o2.amount) FROM orders o2 WHERE o2.user_id = u.id) as max_order
        FROM
            users u
        JOIN
            order_summary os ON u.id = os.user_id
        WHERE
            u.status = 'active'
            AND EXISTS (SELECT 1 FROM products p JOIN order_items oi ON p.id = oi.product_id
                        JOIN orders o3 ON oi.order_id = o3.id WHERE o3.user_id = u.id)
    ";

    // Create filters for the tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );
    // Add a filter for products to ensure it's handled in EXISTS
    table_filters.insert("products".to_string(), "is_active = true".to_string());

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Should succeed with complex query structure"
    );

    let filtered_sql = result.unwrap();
    println!("Complex Query Filtered SQL: {}\n", filtered_sql);

    // Verify all instances of filtered tables were replaced
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should create CTE for filtered users"
    );
    assert!(
        filtered_sql.contains("filtered_o AS (SELECT * FROM orders WHERE created_at > '2023-01-01')"),
        "Should create CTE for filtered orders"
    );
    assert!(
        filtered_sql.contains("filtered_p AS (SELECT * FROM products WHERE is_active = true)"),
        "Should create CTE for filtered products"
    );

    // Verify replacements in different contexts
    assert!(
        filtered_sql.contains("FROM filtered_o o"),
        "Should replace orders in order_summary CTE definition"
    );
    assert!(
        filtered_sql.contains("FROM filtered_o o2"),
        "Should replace orders in MAX subquery (check alias)"
    );
    assert!(
        filtered_sql.contains("FROM filtered_p p"),
        "Should replace products in EXISTS subquery (check alias)"
    );
    assert!(
        filtered_sql.contains("JOIN filtered_o o3"),
        "Should replace orders in EXISTS subquery (check alias)"
    );
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Should replace main users table (check alias)"
    );

    // The original CTE definition should also be preserved (though modified)
    assert!(
        filtered_sql.contains("WITH order_summary AS ("),
        "Should preserve original CTE structure"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_union_query() {
    // Union query
    let sql = "
        SELECT u1.id, o1.amount
        FROM users u1
        JOIN orders o1 ON u1.id = o1.user_id
        WHERE o1.status = 'completed'

        UNION ALL

        SELECT u2.id, o2.amount
        FROM users u2
        JOIN orders o2 ON u2.id = o2.user_id
        WHERE o2.status = 'pending'
    ";

    // Create filters for the tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with UNION queries");

    let filtered_sql = result.unwrap();

    // Verify filters are applied correctly to both sides of UNION
    assert!(
        filtered_sql.contains("filtered_u1"),
        "Should filter users in first query part"
    );
    assert!(
        filtered_sql.contains("filtered_o1"),
        "Should filter orders in first query part"
    );
    assert!(
        filtered_sql.contains("filtered_u2"),
        "Should filter users in second query part"
    );
    assert!(
        filtered_sql.contains("filtered_o2"),
        "Should filter orders in second query part"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_ambiguous_references() {
    // Query with multiple references to the same table using aliases
    let sql = "
        SELECT
            a.id,
            a.name,
            b.id as other_id,
            b.name as other_name
        FROM
            users a
        JOIN
            users b ON a.manager_id = b.id
    ";

    // Create filter for users table
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with aliased self-join");

    let filtered_sql = result.unwrap();

    // Verify that both instances of the users table are filtered correctly via CTEs
    assert!(
        filtered_sql.contains("filtered_a AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should create CTE for alias 'a'"
    );
    assert!(
        filtered_sql.contains("filtered_b AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should create CTE for alias 'b'"
    );
    assert!(
        filtered_sql.contains("FROM filtered_a a"),
        "Should reference filtered CTE for alias 'a'"
    );
    assert!(
        filtered_sql.contains("JOIN filtered_b b"),
        "Should reference filtered CTE for alias 'b'"
    );
}
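
// The self-join above implies the rewrite keys its CTEs off the *alias*, not
// the table name, so each reference to `users` gets its own filtered CTE. A
// hypothetical helper (not part of the crate's public API) illustrating the
// naming scheme these assertions assume:
#[allow(dead_code)]
fn filtered_cte_name(alias: &str) -> String {
    format!("filtered_{}", alias) // e.g. "a" -> "filtered_a", "o2" -> "filtered_o2"
}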

#[tokio::test]
async fn test_row_level_filtering_with_existing_ctes() {
    // Query with existing CTEs
    let sql = "
        WITH order_summary AS (
            SELECT
                user_id,
                COUNT(*) as order_count,
                SUM(amount) as total_amount
            FROM
                orders
            GROUP BY
                user_id
        )
        SELECT
            u.id,
            u.name,
            os.order_count,
            os.total_amount
        FROM
            users u
        JOIN
            order_summary os ON u.id = os.user_id
    ";

    // Create filter for users table only
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    // Add a filter for orders as well to test CTE modification
    table_filters.insert("orders".to_string(), "status = 'paid'".to_string());

    // Test row level filtering with existing CTEs
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with existing CTEs");

    let filtered_sql = result.unwrap();
    println!("Existing CTE Filtered SQL: {}\n", filtered_sql);

    // Verify the new CTEs are added before the original CTE
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM users WHERE tenant_id = 123)"),
        "Should add filtered CTE for users"
    );
    assert!(
        filtered_sql.contains("filtered_o AS (SELECT * FROM orders WHERE status = 'paid')"),
        "Should add filtered CTE for orders"
    );
    assert!(
        filtered_sql.contains(", order_summary AS ("), // Comma indicates it follows other CTEs
        "Original CTE should follow filtered CTEs"
    );

    // Verify the original CTE is modified to use the filtered table
    assert!(
        filtered_sql.contains("FROM filtered_o"),
        "Original CTE should now use filtered orders table"
    );

    // Verify the main query uses the filtered user table
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Main query should use filtered users table"
    );
    assert!(
        filtered_sql.contains("JOIN order_summary os"),
        "Main query should still join with the original (but modified) CTE"
    );
}
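
// Sketch of the ordering asserted above: the injected filter CTEs are
// prepended to the user's WITH clause, and the original CTE body is rewritten
// to read from them (formatting approximate):
//
//   WITH filtered_u AS (SELECT * FROM users WHERE tenant_id = 123),
//        filtered_o AS (SELECT * FROM orders WHERE status = 'paid'),
//        order_summary AS (SELECT user_id, ... FROM filtered_o GROUP BY user_id)
//   SELECT ... FROM filtered_u u JOIN order_summary os ON u.id = os.user_id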

#[tokio::test]
async fn test_row_level_filtering_with_subqueries() {
    // Query with subqueries
    let sql = "
        SELECT
            u.id,
            u.name,
            (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) as order_count
        FROM
            users u
        WHERE
            u.status = 'active'
            AND EXISTS (
                SELECT 1 FROM orders o2
                WHERE o2.user_id = u.id AND o2.status = 'completed'
            )
    ";

    // Create filters for both tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering with subqueries
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with subqueries");

    let filtered_sql = result.unwrap();
    println!("Subquery Filtered SQL: {}\n", filtered_sql);

    // Check CTEs are created
    assert!(filtered_sql.contains("filtered_u AS"));
    assert!(filtered_sql.contains("filtered_o AS"));

    // Check that the main table is filtered
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Should filter the main users table"
    );

    // Check that subqueries are filtered
    assert!(
        filtered_sql.contains("FROM filtered_o o WHERE"),
        "Should filter orders in the scalar subquery"
    );
    assert!(
        filtered_sql.contains("FROM filtered_o o2 WHERE"),
        "Should filter orders in the EXISTS subquery"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_schema_qualified_tables_and_mixed_references() {
    // Query with schema-qualified tables and mixed references
    let sql = "
        SELECT
            u.id,
            u.name,
            o.order_id,
            p.name as product_name -- Changed from schema2.products.name
        FROM
            schema1.users u
        JOIN
            schema1.orders o ON u.id = o.user_id
        JOIN
            schema2.products p ON o.product_id = p.id -- Used alias p here
    ";

    // Create filters for the tables (using just the base table names)
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert("orders".to_string(), "status = 'active'".to_string());
    table_filters.insert("products".to_string(), "company_id = 456".to_string());

    // Test row level filtering with schema-qualified tables
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Should succeed with schema-qualified tables"
    );

    let filtered_sql = result.unwrap();
    println!("Schema Qualified Filtered SQL: {}\n", filtered_sql);

    // Check that all tables are filtered correctly with schema preserved in CTE definition
    assert!(
        filtered_sql.contains("filtered_u AS (SELECT * FROM schema1.users WHERE tenant_id = 123)"),
        "Should include schema in the filtered users CTE"
    );
    assert!(
        filtered_sql.contains("filtered_o AS (SELECT * FROM schema1.orders WHERE status = 'active')"),
        "Should include schema in the filtered orders CTE"
    );
    assert!(
        filtered_sql.contains("filtered_p AS (SELECT * FROM schema2.products WHERE company_id = 456)"),
        "Should include schema in the filtered products CTE"
    );

    // Check that references are updated correctly using the aliases
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Should update users reference using alias u"
    );
    assert!(
        filtered_sql.contains("JOIN filtered_o o"),
        "Should update orders reference using alias o"
    );
    assert!(
        filtered_sql.contains("JOIN filtered_p p"),
        "Should update products reference using alias p"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_nested_subqueries() {
    // Query with nested subqueries
    let sql = "
        SELECT
            u.id,
            u.name,
            (
                SELECT COUNT(*)
                FROM orders o
                WHERE o.user_id = u.id AND o.status IN (
                    SELECT status_code -- Changed from status
                    FROM order_statuses os -- Added alias os
                    WHERE os.is_complete = true -- Used alias os
                )
            ) as completed_orders
        FROM
            users u
    ";

    // Create filters for tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );
    table_filters.insert("order_statuses".to_string(), "company_id = 456".to_string());

    // Test row level filtering with nested subqueries
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with nested subqueries");

    let filtered_sql = result.unwrap();
    println!("Nested Subquery Filtered SQL: {}\n", filtered_sql);

    // Check all tables are filtered using their aliases
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Should filter main users table"
    );
    assert!(
        filtered_sql.contains("FROM filtered_o o"),
        "Should filter orders in subquery"
    );
    assert!(
        filtered_sql.contains("FROM filtered_os os"),
        "Should filter order_statuses in nested subquery"
    );
}

#[tokio::test]
async fn test_row_level_filtering_preserves_comments() {
    // Query with comments
    let sql = "
        -- Main query to get user data
        SELECT
            u.id, -- User ID
            u.name, -- User name
            o.amount /* Order amount */
        FROM
            users u -- Users table
        JOIN
            orders o ON u.id = o.user_id -- Join with orders
        WHERE
            u.status = 'active' -- Only active users
    ";

    // Create filters for tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering with comments
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with comments");

    let filtered_sql = result.unwrap();
    println!("Comments Filtered SQL: {}\n", filtered_sql);

    // Check filters are applied
    assert!(
        filtered_sql.contains("filtered_u AS"),
        "Should add filtered users CTE"
    );
    assert!(
        filtered_sql.contains("filtered_o AS"),
        "Should add filtered orders CTE"
    );
    assert!(
        filtered_sql.contains("tenant_id = 123"),
        "Should apply users filter"
    );
    assert!(
        filtered_sql.contains("created_at > '2023-01-01'"),
        "Should apply orders filter"
    );
    // Comment preservation depends heavily on the parser; basic check:
    assert!(filtered_sql.contains("--") || filtered_sql.contains("/*"), "Should attempt to preserve some comments");
}

#[tokio::test]
async fn test_row_level_filtering_with_limit_offset() {
    // Query with LIMIT and OFFSET
    let sql = "
        SELECT
            u.id,
            u.name
        FROM
            users u
        ORDER BY
            u.created_at DESC
        LIMIT 10
        OFFSET 20
    ";

    // Create filter for users table
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());

    // Test row level filtering with LIMIT and OFFSET
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with LIMIT and OFFSET");

    let filtered_sql = result.unwrap();
    println!("Limit/Offset Filtered SQL: {}\n", filtered_sql);

    // Check that filter is applied
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Should filter users table"
    );

    // Check that LIMIT and OFFSET are preserved
    // Note: sqlparser might move these clauses, check for their presence anywhere
    assert!(
        filtered_sql.contains("LIMIT 10"),
        "Should preserve LIMIT clause"
    );
    assert!(
        filtered_sql.contains("OFFSET 20"),
        "Should preserve OFFSET clause"
    );
}

#[tokio::test]
async fn test_row_level_filtering_with_multiple_filters_per_table() {
    // Simple query with two tables
    let sql = "SELECT u.id, o.amount FROM users u JOIN orders o ON u.id = o.user_id";

    // Define multiple filter conditions for each table
    let user_filter = "tenant_id = 123 AND status = 'active'";
    let order_filter = "created_at > '2023-01-01' AND amount > 0";

    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), user_filter.to_string());
    table_filters.insert("orders".to_string(), order_filter.to_string());

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(
        result.is_ok(),
        "Should succeed with multiple filters per table"
    );

    let filtered_sql = result.unwrap();
    println!("Multi-Filter SQL: {}\n", filtered_sql);

    // Check that the filter conditions are correctly applied within the CTEs
    assert!(
        filtered_sql.contains(&format!("SELECT * FROM users WHERE {}", user_filter)),
        "Should apply multiple conditions for users in CTE"
    );
    assert!(
        filtered_sql.contains(&format!("SELECT * FROM orders WHERE {}", order_filter)),
        "Should apply multiple conditions for orders in CTE"
    );
    assert!(filtered_sql.contains("FROM filtered_u u"));
    assert!(filtered_sql.contains("JOIN filtered_o o"));
}

#[tokio::test]
async fn test_row_level_filtering_with_complex_expressions() {
    // Query with complex expressions in join conditions, select list, and where clause
    let sql = "
        SELECT
            u.id,
            CASE WHEN o.amount > 100 THEN 'High Value' ELSE 'Standard' END as order_type,
            (SELECT COUNT(*) FROM orders o2 WHERE o2.user_id = u.id) as order_count
        FROM
            users u
        LEFT JOIN
            orders o ON u.id = o.user_id AND o.created_at BETWEEN CURRENT_DATE - INTERVAL '30' DAY AND CURRENT_DATE
        WHERE
            u.created_at > CURRENT_DATE - INTERVAL '1' YEAR
            AND (
                u.status = 'active'
                OR EXISTS (SELECT 1 FROM orders o3 WHERE o3.user_id = u.id AND o3.amount > 1000)
            )
    ";

    // Create filters for the tables
    let mut table_filters = HashMap::new();
    table_filters.insert("users".to_string(), "tenant_id = 123".to_string());
    table_filters.insert(
        "orders".to_string(),
        "created_at > '2023-01-01'".to_string(),
    );

    // Test row level filtering
    let result = apply_row_level_filters(sql.to_string(), table_filters).await;
    assert!(result.is_ok(), "Should succeed with complex expressions");

    let filtered_sql = result.unwrap();
    println!("Complex Expr Filtered SQL: {}\n", filtered_sql);

    // Verify that all table references are filtered correctly using aliases
    assert!(
        filtered_sql.contains("FROM filtered_u u"),
        "Should filter main users reference"
    );
    assert!(
        filtered_sql.contains("LEFT JOIN filtered_o o ON"),
        "Should filter main orders reference in LEFT JOIN"
    );
    assert!(
        filtered_sql.contains("FROM filtered_o o2 WHERE"),
        "Should filter orders in subquery"
    );
    assert!(
        filtered_sql.contains("FROM filtered_o o3 WHERE"),
        "Should filter orders in EXISTS subquery"
    );
}
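
// A sketch of how the two entry points exercised in this suite might compose
// in application code: analyze first to reject vague or unparsable SQL, then
// rewrite with per-table filters. `analyze_query` is assumed to be exported
// by the same sql_analyzer crate; the tenant filter is illustrative.
#[allow(dead_code)]
async fn scope_query_for_tenant(sql: String, tenant_id: i64) -> Option<String> {
    use sql_analyzer::analyze_query;

    // Bail out early on SQL the analyzer rejects.
    analyze_query(sql.clone()).await.ok()?;

    let mut filters = HashMap::new();
    filters.insert("users".to_string(), format!("tenant_id = {}", tenant_id));
    apply_row_level_filters(sql, filters).await.ok()
}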