mirror of https://github.com/buster-so/buster.git
3253 lines
121 KiB
Rust
3253 lines
121 KiB
Rust
use sql_analyzer::{analyze_query, SqlAnalyzerError, JoinInfo};
|
|
use sql_analyzer::types::TableKind;
|
|
use tokio;
|
|
use std::collections::HashSet;
|
|
|
|
#[tokio::test]
|
|
async fn test_simple_query() {
|
|
let sql = "SELECT u.id, u.name FROM schema.users u";
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 1);
|
|
assert_eq!(result.joins.len(), 0);
|
|
assert_eq!(result.ctes.len(), 0);
|
|
|
|
let table = &result.tables[0];
|
|
assert_eq!(table.database_identifier, None);
|
|
assert_eq!(table.schema_identifier, Some("schema".to_string()));
|
|
assert_eq!(table.table_identifier, "users");
|
|
assert_eq!(table.alias, Some("u".to_string()));
|
|
|
|
let columns_vec: Vec<_> = table.columns.iter().collect();
|
|
assert!(
|
|
columns_vec.len() == 2,
|
|
"Expected 2 columns, got {}",
|
|
columns_vec.len()
|
|
);
|
|
assert!(table.columns.contains("id"), "Missing 'id' column");
|
|
assert!(table.columns.contains("name"), "Missing 'name' column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_joins() {
|
|
let sql =
|
|
"SELECT u.id, o.order_id FROM schema.users u JOIN schema.orders o ON u.id = o.user_id";
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 2);
|
|
assert!(result.joins.len() > 0, "Should detect at least one join");
|
|
|
|
let table_names: Vec<String> = result
|
|
.tables
|
|
.iter()
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
assert!(table_names.contains(&"users".to_string()));
|
|
assert!(table_names.contains(&"orders".to_string()));
|
|
|
|
let join_exists = result.joins.iter().any(|join| {
|
|
(join.left_table == "users" && join.right_table == "orders")
|
|
|| (join.left_table == "orders" && join.right_table == "users")
|
|
});
|
|
assert!(
|
|
join_exists,
|
|
"Expected to find a join between tables users and orders"
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_cte_query() {
|
|
let sql = "WITH user_orders AS (
|
|
SELECT u.id, o.order_id
|
|
FROM schema.users u
|
|
JOIN schema.orders o ON u.id = o.user_id
|
|
)
|
|
SELECT uo.id, uo.order_id FROM user_orders uo";
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Result: {:?}", result);
|
|
|
|
assert_eq!(result.ctes.len(), 1);
|
|
let cte = &result.ctes[0];
|
|
assert_eq!(cte.name, "user_orders");
|
|
assert_eq!(cte.summary.tables.len(), 2);
|
|
assert_eq!(cte.summary.joins.len(), 1);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_vague_references() {
|
|
// First test: Using a table without schema/db
|
|
let sql = "SELECT u.id FROM users u";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
// Validate that any attempt to use a table without schema results in error
|
|
assert!(
|
|
result.is_err(),
|
|
"Using 'users' without schema/db identifier should fail"
|
|
);
|
|
|
|
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
|
|
println!("Error message for users test: {}", msg);
|
|
assert!(
|
|
msg.contains("users"),
|
|
"Error should mention 'users' table: {}",
|
|
msg
|
|
);
|
|
} else {
|
|
panic!("Expected VagueReferences error, got: {:?}", result);
|
|
}
|
|
|
|
// Second test: Using unqualified column
|
|
let sql = "SELECT id FROM schema.users";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
// Validate that unqualified column references result in error
|
|
assert!(
|
|
result.is_err(),
|
|
"Using unqualified 'id' column should fail"
|
|
);
|
|
|
|
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
|
|
println!("Error message for id test: {}", msg);
|
|
assert!(
|
|
msg.contains("id"),
|
|
"Error should mention 'id' column: {}",
|
|
msg
|
|
);
|
|
} else {
|
|
panic!("Expected VagueReferences error, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_fully_qualified_query() {
|
|
let sql = "SELECT u.id, u.name FROM database.schema.users u";
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 1);
|
|
let table = &result.tables[0];
|
|
assert_eq!(table.database_identifier, Some("database".to_string()));
|
|
assert_eq!(table.schema_identifier, Some("schema".to_string()));
|
|
assert_eq!(table.table_identifier, "users");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_complex_cte_lineage() {
|
|
let sql = "WITH
|
|
users_cte AS (
|
|
SELECT u.id, u.name FROM schema.users u
|
|
)
|
|
SELECT uc.id, uc.name FROM users_cte uc";
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.ctes.len(), 1);
|
|
let cte = &result.ctes[0];
|
|
assert_eq!(cte.name, "users_cte");
|
|
assert_eq!(cte.summary.tables.len(), 1);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_invalid_sql() {
|
|
let sql = "SELECT * FRM users";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
|
|
assert!(result.is_err());
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected") || msg.contains("syntax error"));
|
|
} else {
|
|
panic!("Expected ParseError, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_analysis_nested_subqueries_as_join() {
|
|
let sql = r#"
|
|
WITH main_data AS (
|
|
SELECT
|
|
t1.col1,
|
|
t2.col2,
|
|
t1.id as t1_id,
|
|
c.id as c_id
|
|
FROM db1.schema1.tableA t1
|
|
JOIN db1.schema1.tableB t2 ON t1.id = t2.a_id
|
|
LEFT JOIN db1.schema2.tableC c ON c.id = t1.id
|
|
WHERE t1.status = 'active'
|
|
)
|
|
SELECT
|
|
md.col1,
|
|
COUNT(md.c_id) as sub_count
|
|
FROM
|
|
main_data md
|
|
WHERE md.col1 > 100
|
|
GROUP BY md.col1;
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string())
|
|
.await
|
|
.expect("Analysis failed for nested query rewritten as JOIN in CTE");
|
|
|
|
println!("Result: {:?}", result);
|
|
|
|
assert_eq!(result.ctes.len(), 1, "Should detect 1 CTE");
|
|
let main_cte = &result.ctes[0];
|
|
assert_eq!(main_cte.name, "main_data");
|
|
|
|
assert_eq!(main_cte.summary.joins.len(), 2, "Should detect 2 joins inside the CTE summary");
|
|
|
|
let join1_exists = main_cte.summary.joins.iter().any(|j|
|
|
(j.left_table == "tableA" && j.right_table == "tableB") || (j.left_table == "tableB" && j.right_table == "tableA")
|
|
);
|
|
let join2_exists = main_cte.summary.joins.iter().any(|j|
|
|
(j.left_table == "tableB" && j.right_table == "tableC") || (j.left_table == "tableC" && j.right_table == "tableB")
|
|
);
|
|
assert!(join1_exists, "Join between tableA and tableB not found in CTE summary");
|
|
assert!(join2_exists, "Join between tableB and tableC not found in CTE summary");
|
|
|
|
assert_eq!(result.joins.len(), 0, "Overall query should have no direct joins");
|
|
|
|
assert_eq!(result.tables.len(), 4, "Should detect all 3 base tables (A, B, C) and the CTE");
|
|
|
|
let table_names: std::collections::HashSet<String> = result
|
|
.tables
|
|
.iter()
|
|
.map(|t| format!("{}.{}.{}", t.database_identifier.as_deref().unwrap_or(""), t.schema_identifier.as_deref().unwrap_or(""), t.table_identifier))
|
|
.collect();
|
|
|
|
assert!(table_names.contains(&"db1.schema1.tableA".to_string()), "Missing tableA");
|
|
assert!(table_names.contains(&"db1.schema1.tableB".to_string()), "Missing tableB");
|
|
assert!(table_names.contains(&"db1.schema2.tableC".to_string()), "Missing tableC");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_analysis_union_all() {
|
|
let sql = r#"
|
|
SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active'
|
|
UNION ALL
|
|
SELECT e.user_id, e.username FROM db2.schema1.employees e WHERE e.role = 'manager'
|
|
UNION ALL
|
|
SELECT c.pk, c.full_name FROM db1.schema2.contractors c WHERE c.end_date IS NULL;
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string())
|
|
.await
|
|
.expect("Analysis failed for UNION ALL test");
|
|
|
|
assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
|
|
assert_eq!(result.joins.len(), 0, "Should be no joins");
|
|
assert_eq!(result.tables.len(), 3, "Should detect all 3 tables across UNIONs");
|
|
|
|
let table_names: std::collections::HashSet<String> = result
|
|
.tables
|
|
.iter()
|
|
.map(|t| {
|
|
format!(
|
|
"{}.{}.{}",
|
|
t.database_identifier.as_deref().unwrap_or(""),
|
|
t.schema_identifier.as_deref().unwrap_or(""),
|
|
t.table_identifier
|
|
)
|
|
})
|
|
.collect();
|
|
|
|
assert!(
|
|
table_names.contains(&"db1.schema1.users".to_string()),
|
|
"Missing users table"
|
|
);
|
|
assert!(
|
|
table_names.contains(&"db2.schema1.employees".to_string()),
|
|
"Missing employees table"
|
|
);
|
|
assert!(
|
|
table_names.contains(&"db1.schema2.contractors".to_string()),
|
|
"Missing contractors table"
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_analysis_combined_complexity() {
|
|
let sql = r#"
|
|
WITH active_users AS (
|
|
SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active'
|
|
),
|
|
recent_orders AS (
|
|
SELECT ro.user_id, MAX(ro.order_date) as last_order_date
|
|
FROM db1.schema1.orders ro
|
|
GROUP BY ro.user_id
|
|
)
|
|
SELECT au.name, ro.last_order_date
|
|
FROM active_users au
|
|
JOIN recent_orders ro ON au.id = ro.user_id
|
|
JOIN (
|
|
SELECT p_sub.item_id, p_sub.category FROM db2.schema1.products p_sub WHERE p_sub.is_available = true
|
|
) p ON p.item_id = ro.user_id
|
|
WHERE au.id IN (SELECT sl.user_id FROM db1.schema2.special_list sl)
|
|
UNION ALL
|
|
SELECT e.name, e.hire_date
|
|
FROM db2.schema1.employees e
|
|
WHERE e.department = 'Sales';
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Result: {:?}", result);
|
|
|
|
// We'll check that we have at least the 2 explicit CTEs
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| name == "active_users" || name == "recent_orders")
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 2, "Should detect the 'active_users' and 'recent_orders' CTEs");
|
|
assert_eq!(result.joins.len(), 2, "Should detect 2 joins in the main query");
|
|
}
|
|
|
|
// --- New Tests Start Here ---
|
|
|
|
#[tokio::test]
|
|
async fn test_multiple_chained_ctes() {
|
|
let sql = r#"
|
|
WITH
|
|
cte1 AS (
|
|
SELECT p.id, p.category
|
|
FROM db1.schema1.products p
|
|
),
|
|
cte2 AS (
|
|
SELECT c1.id, c1.category, o.order_date
|
|
FROM cte1 c1
|
|
JOIN db1.schema1.orders o ON c1.id = o.product_id
|
|
WHERE o.status = 'completed'
|
|
)
|
|
SELECT c2.category, COUNT(c2.id) as product_count
|
|
FROM cte2 c2
|
|
GROUP BY c2.category;
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Result CTEs: {:?}", result.ctes);
|
|
println!("Result tables: {:?}", result.tables);
|
|
|
|
// Count the named CTEs only (excluding subquery CTEs)
|
|
let named_ctes: Vec<_> = result.ctes.iter()
|
|
.filter(|c| c.name == "cte1" || c.name == "cte2")
|
|
.collect();
|
|
|
|
assert_eq!(named_ctes.len(), 2, "Should detect both cte1 and cte2");
|
|
|
|
// The tables should include at least products, orders, and cte2
|
|
assert!(result.tables.len() >= 3, "Should detect at least products, orders, and cte2");
|
|
|
|
// Check that expected tables are present
|
|
let table_ids: HashSet<_> = result.tables.iter().map(|t| t.table_identifier.as_str()).collect();
|
|
assert!(table_ids.contains("products"), "Should find products table");
|
|
assert!(table_ids.contains("orders"), "Should find orders table");
|
|
assert!(table_ids.contains("cte2"), "Should find cte2 as a referenced table");
|
|
|
|
// Find the cte2 in the ctes list
|
|
let cte2_opt = result.ctes.iter().find(|c| c.name == "cte2");
|
|
assert!(cte2_opt.is_some(), "Should find cte2 in CTEs list");
|
|
|
|
// Main query has no direct joins
|
|
assert_eq!(result.joins.len(), 0, "Main query should have no direct joins");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_complex_where_clause() {
|
|
let sql = r#"
|
|
SELECT
|
|
u.name, o.order_total
|
|
FROM
|
|
db1.schema1.users u
|
|
JOIN
|
|
db1.schema1.orders o ON u.id = o.user_id
|
|
WHERE
|
|
(u.signup_date > '2023-01-01' AND u.status = 'active')
|
|
OR (o.order_total > 1000 AND lower(u.country) = 'ca');
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 2);
|
|
assert_eq!(result.joins.len(), 1);
|
|
|
|
// Check if columns used in WHERE are captured (basic check)
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert!(users_table.columns.contains("id"));
|
|
assert!(users_table.columns.contains("signup_date"));
|
|
assert!(users_table.columns.contains("status"));
|
|
assert!(users_table.columns.contains("country")); // Used in lower(u.country)
|
|
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("user_id"));
|
|
assert!(orders_table.columns.contains("order_total"));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_window_function() {
|
|
// Note: The analyzer primarily tracks table/column usage, not the specifics of window function logic.
|
|
let sql = r#"
|
|
SELECT
|
|
product_id,
|
|
order_date,
|
|
ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY order_date DESC) as rn
|
|
FROM
|
|
db1.schema2.order_items oi
|
|
WHERE oi.quantity > 0;
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 1);
|
|
assert_eq!(result.joins.len(), 0);
|
|
assert_eq!(result.ctes.len(), 0);
|
|
|
|
let table = &result.tables[0];
|
|
assert_eq!(table.table_identifier, "order_items");
|
|
assert_eq!(table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(table.schema_identifier, Some("schema2".to_string()));
|
|
|
|
// Verify columns used in SELECT, WHERE, PARTITION BY, ORDER BY are captured
|
|
assert!(table.columns.contains("product_id"));
|
|
assert!(table.columns.contains("order_date"));
|
|
assert!(table.columns.contains("customer_id")); // From PARTITION BY
|
|
assert!(table.columns.contains("quantity")); // From WHERE
|
|
}
|
|
|
|
// ----- New Complex Test Cases -----
|
|
|
|
#[tokio::test]
|
|
async fn test_complex_nested_ctes_with_multilevel_references() {
|
|
let sql = r#"
|
|
WITH
|
|
level1 AS (
|
|
SELECT e.id, e.name, e.dept_id FROM db1.schema1.employees e
|
|
),
|
|
level2 AS (
|
|
SELECT l1.id, l1.name, d.dept_name
|
|
FROM level1 l1
|
|
JOIN db1.schema1.departments d ON l1.dept_id = d.id
|
|
),
|
|
level3 AS (
|
|
SELECT
|
|
l2.id,
|
|
l2.name,
|
|
l2.dept_name,
|
|
(SELECT COUNT(*) FROM db1.schema1.projects p WHERE p.dept_id = l1.dept_id) as project_count
|
|
FROM level2 l2
|
|
JOIN level1 l1 ON l2.id = l1.id
|
|
)
|
|
SELECT
|
|
l3.id,
|
|
l3.name,
|
|
l3.dept_name,
|
|
l3.project_count,
|
|
s.salary_amount
|
|
FROM level3 l3
|
|
LEFT JOIN db1.schema1.salaries s ON l3.id = s.employee_id
|
|
WHERE l3.project_count > 0
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Complex nested CTE result: {:?}", result);
|
|
|
|
// Check that all CTEs are detected
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| name == "level1" || name == "level2" || name == "level3")
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 3, "Should detect all three CTEs");
|
|
|
|
// Check base tables (employees, departments, projects, salaries)
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"employees".to_string()), "Should detect employees table");
|
|
assert!(base_tables.contains(&"departments".to_string()), "Should detect departments table");
|
|
assert!(base_tables.contains(&"projects".to_string()), "Should detect projects table");
|
|
assert!(base_tables.contains(&"salaries".to_string()), "Should detect salaries table");
|
|
|
|
// Check joins
|
|
assert!(!result.joins.is_empty(), "Should detect at least one join");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_complex_subqueries_in_different_clauses() {
|
|
// Simplified version with fewer deeply nested subqueries
|
|
let sql = r#"
|
|
-- Use CTEs instead of deeply nested subqueries
|
|
WITH user_orders AS (
|
|
SELECT o.id, o.user_id, o.order_date FROM db1.schema1.orders o
|
|
),
|
|
user_items AS (
|
|
SELECT oi.order_id, oi.item_id FROM db1.schema1.order_items oi
|
|
),
|
|
verified_users AS (
|
|
SELECT um.user_id FROM db1.schema1.user_metadata um WHERE um.is_verified = true
|
|
)
|
|
SELECT
|
|
u.id,
|
|
u.name,
|
|
(SELECT MAX(uo.order_date) FROM user_orders uo WHERE uo.user_id = u.id) as last_order,
|
|
(SELECT SUM(i.amount) FROM db1.schema1.items i JOIN user_items ui ON i.item_id = ui.item_id
|
|
WHERE ui.order_id IN (SELECT uo2.id FROM user_orders uo2 WHERE uo2.user_id = u.id)
|
|
) as total_amount
|
|
FROM db1.schema1.users u
|
|
WHERE
|
|
u.status = 'active'
|
|
AND EXISTS (SELECT 1 FROM db1.schema1.payments p WHERE p.user_id = u.id)
|
|
AND u.id IN (SELECT vu.user_id FROM verified_users vu)
|
|
ORDER BY
|
|
(SELECT COUNT(*) FROM user_orders uo3 WHERE uo3.user_id = u.id) DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Complex subqueries result: {:?}", result);
|
|
|
|
// We should detect several CTEs - both explicit ones and implicit subquery CTEs
|
|
assert!(result.ctes.len() >= 3, "Should detect both explicit CTEs and subquery CTEs");
|
|
|
|
// We should detect all base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
assert!(base_tables.contains(&"items".to_string()), "Should detect items table");
|
|
assert!(base_tables.contains(&"order_items".to_string()), "Should detect order_items table");
|
|
assert!(base_tables.contains(&"payments".to_string()), "Should detect payments table");
|
|
assert!(base_tables.contains(&"user_metadata".to_string()), "Should detect user_metadata table");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_recursive_cte() {
|
|
// Testing with a recursive CTE for hierarchical data
|
|
// Note: Some SQL dialects use RECURSIVE keyword, others don't
|
|
let sql = r#"
|
|
WITH employee_hierarchy AS (
|
|
-- Base case: start with CEO (employee with no manager)
|
|
SELECT e.id, e.name, NULL as manager_id, 0 as level
|
|
FROM db1.schema1.employees e
|
|
WHERE e.manager_id IS NULL
|
|
|
|
UNION ALL
|
|
|
|
-- Recursive case: get all employees who report to someone in the hierarchy
|
|
SELECT e.id, e.name, e.manager_id, eh.level + 1
|
|
FROM db1.schema1.employees e
|
|
JOIN employee_hierarchy eh ON e.manager_id = eh.id
|
|
)
|
|
SELECT
|
|
eh.id,
|
|
eh.name,
|
|
eh.level,
|
|
d.dept_name
|
|
FROM employee_hierarchy eh
|
|
JOIN db1.schema1.departments d ON eh.id = d.manager_id
|
|
ORDER BY eh.level, eh.name
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Recursive CTE result: {:?}", result);
|
|
|
|
// Check that the recursive CTE is detected
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.collect();
|
|
|
|
assert!(cte_names.contains(&"employee_hierarchy".to_string()), "Should detect the recursive CTE");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"employees".to_string()), "Should detect employees table");
|
|
assert!(base_tables.contains(&"departments".to_string()), "Should detect departments table");
|
|
|
|
// Check joins in the main query
|
|
assert!(!result.joins.is_empty(), "Should detect at least one join");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_complex_window_functions() {
|
|
let sql = r#"
|
|
WITH monthly_sales AS (
|
|
SELECT
|
|
p.product_id,
|
|
p.category_id,
|
|
DATE_TRUNC('month', s.sale_date) as month,
|
|
SUM(s.quantity * s.price) as monthly_revenue
|
|
FROM db1.schema1.products p
|
|
JOIN db1.schema1.sales s ON p.product_id = s.product_id
|
|
GROUP BY p.product_id, p.category_id, DATE_TRUNC('month', s.sale_date)
|
|
)
|
|
SELECT
|
|
ms.product_id,
|
|
c.category_name,
|
|
ms.month,
|
|
ms.monthly_revenue,
|
|
SUM(ms.monthly_revenue) OVER (
|
|
PARTITION BY ms.product_id
|
|
ORDER BY ms.month
|
|
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
|
|
) as cumulative_revenue,
|
|
RANK() OVER (
|
|
PARTITION BY ms.category_id, ms.month
|
|
ORDER BY ms.monthly_revenue DESC
|
|
) as category_rank,
|
|
LAG(ms.monthly_revenue, 1) OVER (
|
|
PARTITION BY ms.product_id
|
|
ORDER BY ms.month
|
|
) as prev_month_revenue,
|
|
CASE
|
|
WHEN LAG(ms.monthly_revenue, 1) OVER (PARTITION BY ms.product_id ORDER BY ms.month) IS NULL THEN NULL
|
|
ELSE (ms.monthly_revenue - LAG(ms.monthly_revenue, 1) OVER (PARTITION BY ms.product_id ORDER BY ms.month))
|
|
/ LAG(ms.monthly_revenue, 1) OVER (PARTITION BY ms.product_id ORDER BY ms.month) * 100
|
|
END as pct_change
|
|
FROM monthly_sales ms
|
|
JOIN db1.schema1.categories c ON ms.category_id = c.category_id
|
|
ORDER BY ms.product_id, ms.month
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Complex window functions result: {:?}", result);
|
|
|
|
// Check that the CTE is detected
|
|
let cte_exists = result.ctes.iter()
|
|
.any(|cte| cte.name == "monthly_sales");
|
|
|
|
assert!(cte_exists, "Should detect the monthly_sales CTE");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
assert!(base_tables.contains(&"sales".to_string()), "Should detect sales table");
|
|
assert!(base_tables.contains(&"categories".to_string()), "Should detect categories table");
|
|
|
|
// Check columns for window functions
|
|
let monthly_sales_table = result.tables.iter()
|
|
.find(|t| t.table_identifier == "monthly_sales");
|
|
|
|
assert!(monthly_sales_table.is_some(), "Should find monthly_sales as a table");
|
|
if let Some(ms_table) = monthly_sales_table {
|
|
assert!(ms_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(ms_table.columns.contains("category_id"), "Should detect category_id column");
|
|
assert!(ms_table.columns.contains("month"), "Should detect month column");
|
|
assert!(ms_table.columns.contains("monthly_revenue"), "Should detect monthly_revenue column");
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_pivot_query() {
|
|
// This test simulates a pivot query structure
|
|
let sql = r#"
|
|
WITH sales_data AS (
|
|
SELECT
|
|
s.product_id,
|
|
DATE_TRUNC('month', s.sale_date) as month,
|
|
SUM(s.quantity) as total_sold
|
|
FROM db1.schema1.sales s
|
|
GROUP BY s.product_id, DATE_TRUNC('month', s.sale_date)
|
|
)
|
|
SELECT
|
|
p.product_name,
|
|
SUM(CASE WHEN sd.month = '2023-01-01' THEN sd.total_sold ELSE 0 END) as jan_sales,
|
|
SUM(CASE WHEN sd.month = '2023-02-01' THEN sd.total_sold ELSE 0 END) as feb_sales,
|
|
SUM(CASE WHEN sd.month = '2023-03-01' THEN sd.total_sold ELSE 0 END) as mar_sales,
|
|
SUM(CASE WHEN sd.month = '2023-04-01' THEN sd.total_sold ELSE 0 END) as apr_sales,
|
|
SUM(CASE WHEN sd.month = '2023-05-01' THEN sd.total_sold ELSE 0 END) as may_sales,
|
|
SUM(CASE WHEN sd.month = '2023-06-01' THEN sd.total_sold ELSE 0 END) as jun_sales,
|
|
SUM(sd.total_sold) as total_sales
|
|
FROM sales_data sd
|
|
JOIN db1.schema1.products p ON sd.product_id = p.product_id
|
|
GROUP BY p.product_name
|
|
HAVING SUM(sd.total_sold) > 100
|
|
ORDER BY total_sales DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Pivot query result: {:?}", result);
|
|
|
|
// Check that the CTE is detected
|
|
let cte_exists = result.ctes.iter()
|
|
.any(|cte| cte.name == "sales_data");
|
|
|
|
assert!(cte_exists, "Should detect the sales_data CTE");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"sales".to_string()), "Should detect sales table");
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
|
|
// Check columns
|
|
let sales_data_table = result.tables.iter()
|
|
.find(|t| t.table_identifier == "sales_data");
|
|
|
|
assert!(sales_data_table.is_some(), "Should find sales_data as a table");
|
|
if let Some(sd_table) = sales_data_table {
|
|
assert!(sd_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(sd_table.columns.contains("month"), "Should detect month column");
|
|
assert!(sd_table.columns.contains("total_sold"), "Should detect total_sold column");
|
|
}
|
|
|
|
let products_table = result.tables.iter()
|
|
.find(|t| t.table_identifier == "products");
|
|
|
|
if let Some(p_table) = products_table {
|
|
assert!(p_table.columns.contains("product_name"), "Should detect product_name column");
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_set_operations() {
|
|
// Simplified test for set operations - focusing on UNION ALL, which is better supported
|
|
let sql = r#"
|
|
WITH active_users AS (
|
|
SELECT u.id, u.name, u.email FROM db1.schema1.users u WHERE u.status = 'active'
|
|
),
|
|
premium_users AS (
|
|
SELECT s.id, s.name, s.email FROM db1.schema1.subscriptions s
|
|
WHERE s.plan_type = 'premium' AND s.end_date > CURRENT_DATE
|
|
),
|
|
churned_users AS (
|
|
SELECT s.id, s.name, s.email FROM db1.schema1.subscriptions s
|
|
WHERE s.end_date < CURRENT_DATE
|
|
)
|
|
|
|
-- Simplified to use direct UNION ALLs instead of nested EXCEPT/INTERSECT
|
|
SELECT
|
|
u.id,
|
|
u.name,
|
|
u.email,
|
|
'active' as user_type
|
|
FROM active_users u
|
|
|
|
UNION ALL
|
|
|
|
SELECT
|
|
p.id,
|
|
p.name,
|
|
p.email,
|
|
'premium' as user_type
|
|
FROM premium_users p
|
|
|
|
UNION ALL
|
|
|
|
SELECT
|
|
c.id,
|
|
c.name,
|
|
c.email,
|
|
'churned' as user_type
|
|
FROM churned_users c
|
|
|
|
ORDER BY user_type, name
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Set operations result: {:?}", result);
|
|
|
|
// Check that all CTEs are detected
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| ["active_users", "premium_users", "churned_users"].contains(&name.as_str()))
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 3, "Should detect all three CTEs");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
assert!(base_tables.contains(&"subscriptions".to_string()), "Should detect subscriptions table");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_self_joins_with_correlated_subqueries() {
|
|
let sql = r#"
|
|
WITH employee_managers AS (
|
|
SELECT
|
|
e.id as employee_id,
|
|
e.name as employee_name,
|
|
e.manager_id,
|
|
m.name as manager_name,
|
|
m.department_id as manager_dept_id,
|
|
(SELECT COUNT(*) FROM db1.schema1.employees e2 WHERE e2.manager_id = e.id) as direct_reports
|
|
FROM db1.schema1.employees e
|
|
LEFT JOIN db1.schema1.employees m ON e.manager_id = m.id
|
|
),
|
|
dept_stats AS (
|
|
SELECT
|
|
d.id as department_id,
|
|
d.name as department_name,
|
|
COUNT(e.id) as employee_count,
|
|
AVG(e.salary) as avg_salary,
|
|
(
|
|
SELECT STRING_AGG(em.employee_name, ', ')
|
|
FROM employee_managers em
|
|
WHERE em.manager_dept_id = d.id AND em.direct_reports > 0
|
|
) as managers_list
|
|
FROM db1.schema1.departments d
|
|
LEFT JOIN db1.schema1.employees e ON d.id = e.department_id
|
|
GROUP BY d.id, d.name
|
|
)
|
|
SELECT
|
|
em.employee_id,
|
|
em.employee_name,
|
|
em.manager_name,
|
|
ds.department_name,
|
|
em.direct_reports,
|
|
ds.employee_count,
|
|
ds.avg_salary,
|
|
CASE
|
|
WHEN em.direct_reports > 0 THEN true
|
|
ELSE false
|
|
END as is_manager,
|
|
(
|
|
SELECT MAX(p.budget)
|
|
FROM db1.schema1.projects p
|
|
WHERE p.department_id = em.manager_dept_id
|
|
) as max_project_budget
|
|
FROM employee_managers em
|
|
JOIN dept_stats ds ON em.manager_dept_id = ds.department_id
|
|
WHERE em.direct_reports > 0
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Self joins with correlated subqueries result: {:?}", result);
|
|
|
|
// Check that all CTEs are detected
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| ["employee_managers", "dept_stats"].contains(&name.as_str()))
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 2, "Should detect both CTEs");
|
|
|
|
// Check self-join by verifying the employees table appears with multiple roles
|
|
let employee_roles = result.tables.iter()
|
|
.filter(|t| t.table_identifier == "employees")
|
|
.count();
|
|
|
|
assert!(employee_roles >= 1, "Should detect employees table at least once");
|
|
|
|
// Check other base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"departments".to_string()), "Should detect departments table");
|
|
assert!(base_tables.contains(&"projects".to_string()), "Should detect projects table");
|
|
|
|
// Check that we detect joins
|
|
assert!(!result.joins.is_empty(), "Should detect joins");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_lateral_joins() {
|
|
// Test LATERAL joins functionality
|
|
let sql = r#"
|
|
WITH users_with_orders AS (
|
|
SELECT u.id, u.name, u.registered_date
|
|
FROM db1.schema1.users u
|
|
WHERE EXISTS (SELECT 1 FROM db1.schema1.orders o WHERE o.user_id = u.id)
|
|
)
|
|
SELECT
|
|
u.id as user_id,
|
|
u.name as user_name,
|
|
recent_orders.order_id,
|
|
recent_orders.order_date,
|
|
recent_orders.amount
|
|
FROM users_with_orders u
|
|
CROSS JOIN LATERAL (
|
|
SELECT o.id as order_id, o.order_date, o.total_amount as amount
|
|
FROM db1.schema1.orders o
|
|
WHERE o.user_id = u.id
|
|
ORDER BY o.order_date DESC
|
|
LIMIT 3
|
|
) recent_orders
|
|
ORDER BY u.id, recent_orders.order_date DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Lateral joins result: {:?}", result);
|
|
|
|
// Check that the CTE is detected
|
|
let cte_exists = result.ctes.iter()
|
|
.any(|cte| cte.name == "users_with_orders");
|
|
|
|
assert!(cte_exists, "Should detect the users_with_orders CTE");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check for derived table from LATERAL join
|
|
let derived_tables = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Derived)
|
|
.count();
|
|
|
|
assert!(derived_tables >= 1, "Should detect at least one derived table from LATERAL join");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_deeply_nested_derived_tables() {
|
|
// Simplified test with fewer levels of nesting and more explicit aliases
|
|
let sql = r#"
|
|
WITH
|
|
active_customers AS (
|
|
SELECT c.id, c.name, c.status, c.region
|
|
FROM db1.schema1.customers c
|
|
WHERE c.status = 'active'
|
|
),
|
|
customer_orders AS (
|
|
SELECT
|
|
o.customer_id,
|
|
o.id as order_id,
|
|
o.total_amount as order_amount,
|
|
o.status
|
|
FROM db1.schema1.orders o
|
|
WHERE o.order_date > (CURRENT_DATE - INTERVAL '1 year')
|
|
)
|
|
SELECT
|
|
summary.customer_id,
|
|
summary.region,
|
|
summary.total_spent,
|
|
summary.order_count
|
|
FROM (
|
|
-- Only one level of derived table now
|
|
SELECT
|
|
ac.id as customer_id,
|
|
ac.region,
|
|
SUM(co.order_amount) as total_spent,
|
|
COUNT(DISTINCT co.order_id) as order_count
|
|
FROM active_customers ac
|
|
JOIN customer_orders co ON co.customer_id = ac.id
|
|
WHERE co.status = 'completed'
|
|
GROUP BY ac.id, ac.region
|
|
HAVING COUNT(DISTINCT co.order_id) >= 3
|
|
) summary
|
|
WHERE summary.total_spent > 1000
|
|
ORDER BY summary.total_spent DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
println!("Deeply nested derived tables result: {:?}", result);
|
|
|
|
// Check that the CTEs are detected
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| ["active_customers", "customer_orders"].contains(&name.as_str()))
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 2, "Should detect both explicit CTEs");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check for derived tables - we should have at least one
|
|
let derived_tables = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Derived)
|
|
.count();
|
|
|
|
assert!(derived_tables >= 1, "Should detect at least one derived table");
|
|
|
|
// Check that we can find at least one join somewhere (either in main query or in subquery summary)
|
|
let has_join = !result.joins.is_empty() ||
|
|
result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Derived)
|
|
.flat_map(|t| t.subquery_summary.as_ref())
|
|
.any(|summary| !summary.joins.is_empty());
|
|
|
|
assert!(has_join, "Should detect at least one join somewhere in the query");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_calculations_in_select() {
|
|
let sql = r#"
|
|
SELECT
|
|
p.name,
|
|
p.price * (1 - p.discount_percent) AS final_price,
|
|
p.stock_level - 5 AS adjusted_stock
|
|
FROM
|
|
db2.warehouse.products p
|
|
WHERE p.category = 'electronics';
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 1);
|
|
assert_eq!(result.joins.len(), 0);
|
|
|
|
let table = &result.tables[0];
|
|
assert_eq!(table.table_identifier, "products");
|
|
assert!(table.columns.contains("name"));
|
|
assert!(table.columns.contains("price"));
|
|
assert!(table.columns.contains("discount_percent"));
|
|
assert!(table.columns.contains("stock_level"));
|
|
assert!(table.columns.contains("category")); // From WHERE
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_date_function_usage() {
|
|
// Using DATE_TRUNC style common in PG/Snowflake
|
|
let sql = r#"
|
|
SELECT
|
|
event_id, user_id
|
|
FROM
|
|
db_logs.public.user_events ue
|
|
WHERE
|
|
DATE_TRUNC('day', ue.event_timestamp) = CURRENT_DATE;
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
assert_eq!(result.tables.len(), 1);
|
|
let table = &result.tables[0];
|
|
assert_eq!(table.table_identifier, "user_events");
|
|
|
|
// Ensure the column used within the date function is captured
|
|
assert!(table.columns.contains("event_timestamp"));
|
|
assert!(table.columns.contains("event_id"));
|
|
assert!(table.columns.contains("user_id"));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_table_valued_functions() {
|
|
// Test handling of table-valued functions
|
|
let sql = r#"
|
|
SELECT e.employee_id, f.product_name, f.sales_amount
|
|
FROM db1.schema1.employees e
|
|
CROSS JOIN db1.schema1.get_employee_sales(e.employee_id, '2023-01-01') f
|
|
WHERE e.department = 'Sales'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// We should detect the base table
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"employees".to_string()), "Should detect employees table");
|
|
|
|
// Check if columns are detected
|
|
let employees_table = result.tables.iter().find(|t| t.table_identifier == "employees").unwrap();
|
|
assert!(employees_table.columns.contains("employee_id"), "Should detect employee_id column");
|
|
assert!(employees_table.columns.contains("department"), "Should detect department column");
|
|
|
|
// Check for at least one join (the CROSS JOIN)
|
|
assert!(!result.joins.is_empty(), "Should detect the CROSS JOIN");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_nulls_first_last_ordering() {
|
|
// Test SQL with NULLS FIRST/LAST ordering specs
|
|
let sql = r#"
|
|
SELECT c.customer_id, c.name, o.order_date
|
|
FROM db1.schema1.customers c
|
|
LEFT JOIN db1.schema1.orders o ON c.customer_id = o.customer_id
|
|
ORDER BY o.order_date DESC NULLS LAST, c.name ASC NULLS FIRST
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// We should detect both tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check if columns are detected, including those used in ORDER BY
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(customers_table.columns.contains("name"), "Should detect name column");
|
|
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
|
|
// Check for the join
|
|
assert_eq!(result.joins.len(), 1, "Should detect the LEFT JOIN");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_window_function_with_complex_frame() {
|
|
// Test window function with frame specification
|
|
let sql = r#"
|
|
SELECT
|
|
p.product_id,
|
|
p.product_name,
|
|
s.date,
|
|
s.quantity,
|
|
SUM(s.quantity) OVER (
|
|
PARTITION BY p.product_id
|
|
ORDER BY s.date
|
|
RANGE BETWEEN INTERVAL '30' DAY PRECEDING AND CURRENT ROW
|
|
) AS rolling_30_day_sales
|
|
FROM db1.schema1.products p
|
|
JOIN db1.schema1.sales s ON p.product_id = s.product_id
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// We should detect both tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
assert!(base_tables.contains(&"sales".to_string()), "Should detect sales table");
|
|
|
|
// Check if columns are detected, including those used in window function
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(products_table.columns.contains("product_name"), "Should detect product_name column");
|
|
|
|
let sales_table = result.tables.iter().find(|t| t.table_identifier == "sales").unwrap();
|
|
assert!(sales_table.columns.contains("date"), "Should detect date column");
|
|
assert!(sales_table.columns.contains("quantity"), "Should detect quantity column");
|
|
|
|
// Check for the join
|
|
assert_eq!(result.joins.len(), 1, "Should detect the JOIN");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_grouping_sets() {
|
|
// Test GROUPING SETS functionality
|
|
let sql = r#"
|
|
SELECT
|
|
COALESCE(p.category, 'All Categories') AS category,
|
|
COALESCE(c.region, 'All Regions') AS region,
|
|
COALESCE(TO_CHAR(s.sale_date, 'YYYY-MM'), 'All Periods') AS period,
|
|
SUM(s.amount) AS total_sales
|
|
FROM db1.schema1.sales s
|
|
JOIN db1.schema1.products p ON s.product_id = p.product_id
|
|
JOIN db1.schema1.customers c ON s.customer_id = c.customer_id
|
|
GROUP BY GROUPING SETS (
|
|
(p.category, c.region, TO_CHAR(s.sale_date, 'YYYY-MM')),
|
|
(p.category, c.region),
|
|
(p.category, TO_CHAR(s.sale_date, 'YYYY-MM')),
|
|
(c.region, TO_CHAR(s.sale_date, 'YYYY-MM')),
|
|
(p.category),
|
|
(c.region),
|
|
(TO_CHAR(s.sale_date, 'YYYY-MM')),
|
|
()
|
|
)
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// We should detect all three base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"sales".to_string()), "Should detect sales table");
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
|
|
// Check if columns are detected, including those used in GROUPING SETS
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert!(products_table.columns.contains("category"), "Should detect category column");
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert!(customers_table.columns.contains("region"), "Should detect region column");
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
|
|
let sales_table = result.tables.iter().find(|t| t.table_identifier == "sales").unwrap();
|
|
assert!(sales_table.columns.contains("sale_date"), "Should detect sale_date column");
|
|
assert!(sales_table.columns.contains("amount"), "Should detect amount column");
|
|
|
|
// Check for the joins
|
|
assert_eq!(result.joins.len(), 2, "Should detect two JOINs");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_json_path_extraction() {
|
|
// Test JSON path extraction queries
|
|
let sql = r#"
|
|
SELECT
|
|
u.user_id,
|
|
u.name,
|
|
JSON_EXTRACT_PATH_TEXT(u.preferences, 'notifications', 'email') AS email_pref,
|
|
JSON_EXTRACT_PATH_TEXT(u.preferences, 'notifications', 'sms') AS sms_pref,
|
|
(
|
|
SELECT COUNT(*)
|
|
FROM db1.schema1.orders o
|
|
WHERE o.user_id = u.user_id AND o.metadata->>'payment_method' = 'credit_card'
|
|
) AS cc_order_count
|
|
FROM db1.schema1.users u
|
|
WHERE u.preferences->>'theme' = 'dark'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check that we detect both tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table in subquery");
|
|
|
|
// Check columns used (including JSON paths)
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("name"), "Should detect name column");
|
|
assert!(users_table.columns.contains("preferences"), "Should detect preferences column");
|
|
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("user_id"), "Should detect user_id column in orders");
|
|
assert!(orders_table.columns.contains("metadata"), "Should detect metadata column in orders");
|
|
|
|
// Check that we detect at least one subquery as a CTE
|
|
assert!(!result.ctes.is_empty(), "Should detect at least one CTE for the subquery");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_self_referencing_hierarchical_query() {
|
|
// Test hierarchical query with recursive CTE
|
|
let sql = r#"
|
|
WITH RECURSIVE org_hierarchy AS (
|
|
-- Base case: Top level employees (no manager)
|
|
SELECT e.id, e.name, e.manager_id, e.department_id, 1 AS level
|
|
FROM db1.schema1.employees e
|
|
WHERE e.manager_id IS NULL
|
|
|
|
UNION ALL
|
|
|
|
-- Recursive case: Employees with managers
|
|
SELECT
|
|
e.id,
|
|
e.name,
|
|
e.manager_id,
|
|
e.department_id,
|
|
oh.level + 1
|
|
FROM db1.schema1.employees e
|
|
JOIN org_hierarchy oh ON e.manager_id = oh.id
|
|
)
|
|
SELECT
|
|
oh.id,
|
|
oh.name,
|
|
oh.level,
|
|
d.name AS department,
|
|
m.name AS manager_name
|
|
FROM org_hierarchy oh
|
|
JOIN db1.schema1.departments d ON oh.department_id = d.id
|
|
LEFT JOIN db1.schema1.employees m ON oh.manager_id = m.id
|
|
ORDER BY oh.level, d.name, oh.name
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check for CTE detection
|
|
let cte_exists = result.ctes.iter().any(|cte| cte.name == "org_hierarchy");
|
|
assert!(cte_exists, "Should detect the recursive org_hierarchy CTE");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"employees".to_string()), "Should detect employees table");
|
|
assert!(base_tables.contains(&"departments".to_string()), "Should detect departments table");
|
|
|
|
// Count employees table references (should appear multiple times in different roles)
|
|
let employees_tables = result.tables.iter()
|
|
.filter(|t| t.table_identifier == "employees")
|
|
.count();
|
|
|
|
assert!(employees_tables >= 1, "Should detect employees table at least once");
|
|
|
|
// Check that we have some joins
|
|
assert!(!result.joins.is_empty(), "Should detect joins");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_multiple_joins_with_using() {
|
|
// Test different join types with USING syntax - avoid USING and NATURAL JOIN which cause vague reference errors
|
|
let sql = r#"
|
|
SELECT
|
|
o.order_id,
|
|
o.order_date,
|
|
c.name AS customer_name,
|
|
p.product_name,
|
|
oi.quantity,
|
|
oi.price AS unit_price
|
|
FROM db1.schema1.orders o
|
|
JOIN db1.schema1.customers c ON c.customer_id = o.customer_id
|
|
JOIN db1.schema1.order_items oi ON oi.order_id = o.order_id
|
|
JOIN db1.schema1.products p ON p.product_id = oi.product_id
|
|
WHERE o.order_date > '2023-01-01'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check for all base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"order_items".to_string()), "Should detect order_items table");
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
|
|
// Check that join columns are registered
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
|
|
let order_items_table = result.tables.iter().find(|t| t.table_identifier == "order_items").unwrap();
|
|
assert!(order_items_table.columns.contains("quantity"), "Should detect quantity column");
|
|
assert!(order_items_table.columns.contains("price"), "Should detect price column");
|
|
|
|
// Check joins
|
|
assert!(result.joins.len() >= 3, "Should detect at least 3 joins");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_common_subexpression_factor_out() {
|
|
// Test factoring out common subexpressions with CTEs - avoid unqualified columns
|
|
let sql = r#"
|
|
WITH
|
|
customer_stats AS (
|
|
SELECT
|
|
c.id,
|
|
c.name,
|
|
c.email,
|
|
COUNT(o.id) AS order_count,
|
|
SUM(o.total_amount) AS total_spent,
|
|
MAX(o.order_date) AS last_order_date
|
|
FROM db1.schema1.customers c
|
|
LEFT JOIN db1.schema1.orders o ON c.id = o.customer_id
|
|
GROUP BY c.id, c.name, c.email
|
|
),
|
|
customer_segments AS (
|
|
SELECT
|
|
cs.id,
|
|
cs.name,
|
|
cs.email,
|
|
CASE
|
|
WHEN cs.order_count = 0 THEN 'Never Purchased'
|
|
WHEN cs.last_order_date < CURRENT_DATE - INTERVAL '180 days' THEN 'Inactive'
|
|
WHEN cs.order_count = 1 THEN 'New Customer'
|
|
WHEN cs.total_spent > 1000 THEN 'VIP'
|
|
ELSE 'Regular'
|
|
END AS segment
|
|
FROM customer_stats cs
|
|
)
|
|
SELECT
|
|
cs.segment,
|
|
COUNT(*) AS customer_count,
|
|
SUM(CASE WHEN cs.segment = 'VIP' THEN 1 ELSE 0 END) OVER() AS total_vips,
|
|
AVG(CASE WHEN cs.segment = 'Regular' THEN 1.0 ELSE 0.0 END) OVER() AS regular_ratio
|
|
FROM customer_segments cs
|
|
GROUP BY cs.segment
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check for CTE detection
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| ["customer_stats", "customer_segments"].contains(&name.as_str()))
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 2, "Should detect both CTEs");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check chained CTE references (customer_segments depends on customer_stats)
|
|
let customer_segments_cte = result.ctes.iter().find(|cte| cte.name == "customer_segments");
|
|
assert!(customer_segments_cte.is_some(), "Should find customer_segments CTE");
|
|
|
|
// Check for window functions (OVER clauses)
|
|
let has_customer_segments_table = result.tables.iter().any(|t| t.table_identifier == "customer_segments");
|
|
assert!(has_customer_segments_table, "Should find customer_segments as a referenced table");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_lateral_joins_with_limit() {
|
|
// Test LATERAL join with LIMIT - use WITH to define fake data first
|
|
let sql = r#"
|
|
WITH
|
|
customers_data AS (
|
|
SELECT c.id AS customer_id, c.name, c.email, c.status
|
|
FROM db1.schema1.customers c
|
|
WHERE c.status = 'active'
|
|
),
|
|
orders_data AS (
|
|
SELECT o.id, o.customer_id, o.order_date, o.total_amount
|
|
FROM db1.schema1.orders o
|
|
)
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
c.email,
|
|
ro.order_id,
|
|
ro.order_date,
|
|
ro.total_amount
|
|
FROM customers_data c
|
|
CROSS JOIN LATERAL (
|
|
SELECT od.id AS order_id, od.order_date, od.total_amount
|
|
FROM orders_data od
|
|
WHERE od.customer_id = c.customer_id
|
|
ORDER BY od.order_date DESC
|
|
LIMIT 3
|
|
) ro
|
|
ORDER BY c.customer_id, ro.order_date DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// First, print the result for debuggging
|
|
println!("Lateral test result: {:?}", result);
|
|
|
|
// Check CTEs
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| ["customers_data", "orders_data"].contains(&name.as_str()))
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 2, "Should detect both CTEs");
|
|
|
|
// Check base tables inside CTE summaries
|
|
let has_customers = result.ctes.iter()
|
|
.filter(|cte| cte.name == "customers_data")
|
|
.flat_map(|cte| cte.summary.tables.iter())
|
|
.any(|t| t.table_identifier == "customers");
|
|
|
|
let has_orders = result.ctes.iter()
|
|
.filter(|cte| cte.name == "orders_data")
|
|
.flat_map(|cte| cte.summary.tables.iter())
|
|
.any(|t| t.table_identifier == "orders");
|
|
|
|
assert!(has_customers, "Should detect customers table in CTE");
|
|
assert!(has_orders, "Should detect orders table in CTE");
|
|
|
|
// Check for references to CTEs
|
|
let customers_data_ref = result.tables.iter().any(|t| t.table_identifier == "customers_data");
|
|
|
|
assert!(customers_data_ref, "Should reference customers_data CTE");
|
|
|
|
// The orders_data CTE might not appear directly in the derived table's summary
|
|
// because of how the analyzer processes subqueries.
|
|
// We can instead check that we have the orders_data CTE defined somewhere
|
|
let orders_data_defined = result.ctes.iter().any(|cte| cte.name == "orders_data");
|
|
assert!(orders_data_defined, "Should define the orders_data CTE");
|
|
|
|
// Check derived table from LATERAL join
|
|
let derived_tables = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Derived)
|
|
.count();
|
|
|
|
assert!(derived_tables >= 1, "Should detect at least one derived table from LATERAL join");
|
|
|
|
// Check join detection
|
|
assert!(!result.joins.is_empty(), "Should detect at least one join");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_parameterized_subqueries_with_different_types() {
|
|
// Test different types of subqueries
|
|
let sql = r#"
|
|
SELECT
|
|
p.id,
|
|
p.name,
|
|
p.price,
|
|
(
|
|
SELECT ARRAY_AGG(c.name ORDER BY c.name)
|
|
FROM db1.schema1.categories c
|
|
JOIN db1.schema1.product_categories pc ON c.id = pc.category_id
|
|
WHERE pc.product_id = p.id
|
|
) AS categories,
|
|
EXISTS (
|
|
SELECT 1
|
|
FROM db1.schema1.inventory i
|
|
WHERE i.product_id = p.id AND i.quantity > 0
|
|
) AS in_stock,
|
|
(
|
|
SELECT SUM(oi.quantity)
|
|
FROM db1.schema1.order_items oi
|
|
JOIN db1.schema1.orders o ON oi.order_id = o.id
|
|
WHERE oi.product_id = p.id AND o.order_date > CURRENT_DATE - INTERVAL '30 days'
|
|
) AS units_sold_last_30_days
|
|
FROM db1.schema1.products p
|
|
WHERE p.active = true
|
|
ORDER BY units_sold_last_30_days DESC NULLS LAST
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// We should detect many tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
assert!(base_tables.contains(&"categories".to_string()), "Should detect categories table");
|
|
assert!(base_tables.contains(&"product_categories".to_string()), "Should detect product_categories table");
|
|
assert!(base_tables.contains(&"inventory".to_string()), "Should detect inventory table");
|
|
assert!(base_tables.contains(&"order_items".to_string()), "Should detect order_items table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// We should detect several CTEs for subqueries
|
|
assert!(result.ctes.len() >= 3, "Should detect multiple CTEs for subqueries");
|
|
|
|
// Check that columns are properly detected
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert!(products_table.columns.contains("id"), "Should detect id column");
|
|
assert!(products_table.columns.contains("name"), "Should detect name column");
|
|
assert!(products_table.columns.contains("price"), "Should detect price column");
|
|
assert!(products_table.columns.contains("active"), "Should detect active column");
|
|
}
|
|
|
|
// Tests for non-read-only statements - they should all be rejected
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_insert_statement() {
|
|
let sql = "INSERT INTO db1.schema1.users (name, email) VALUES ('John Doe', 'john@example.com')";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject INSERT statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected SELECT"), "Error message should indicate this is not a SELECT statement");
|
|
} else {
|
|
panic!("Expected ParseError for INSERT, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_update_statement() {
|
|
let sql = "UPDATE db1.schema1.users SET status = 'inactive' WHERE last_login < CURRENT_DATE - INTERVAL '90 days'";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject UPDATE statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected SELECT"), "Error message should indicate this is not a SELECT statement");
|
|
} else {
|
|
panic!("Expected ParseError for UPDATE, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_delete_statement() {
|
|
let sql = "DELETE FROM db1.schema1.users WHERE status = 'deleted'";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject DELETE statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected SELECT"), "Error message should indicate this is not a SELECT statement");
|
|
} else {
|
|
panic!("Expected ParseError for DELETE, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_merge_statement() {
|
|
let sql = r#"
|
|
MERGE INTO db1.schema1.customers c
|
|
USING (SELECT * FROM db1.schema1.new_customers) nc
|
|
ON (c.customer_id = nc.customer_id)
|
|
WHEN MATCHED THEN
|
|
UPDATE SET c.name = nc.name, c.email = nc.email, c.updated_at = CURRENT_TIMESTAMP
|
|
WHEN NOT MATCHED THEN
|
|
INSERT (customer_id, name, email, created_at, updated_at)
|
|
VALUES (nc.customer_id, nc.name, nc.email, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject MERGE statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected"), "Error message should indicate parsing failure");
|
|
} else {
|
|
panic!("Expected ParseError for MERGE, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_create_table_statement() {
|
|
let sql = r#"
|
|
CREATE TABLE db1.schema1.new_users (
|
|
id SERIAL PRIMARY KEY,
|
|
name VARCHAR(255) NOT NULL,
|
|
email VARCHAR(255) UNIQUE NOT NULL,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject CREATE TABLE statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected SELECT"), "Error message should indicate this is not a SELECT statement");
|
|
} else {
|
|
panic!("Expected ParseError for CREATE TABLE, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_stored_procedure_call() {
|
|
let sql = "CALL db1.schema1.process_orders(123, 'PENDING', true)";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject CALL statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected"), "Error message should indicate parsing failure");
|
|
} else {
|
|
panic!("Expected ParseError for CALL, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_reject_dynamic_sql() {
|
|
let sql = "EXECUTE IMMEDIATE 'SELECT * FROM ' || table_name || ' WHERE id = ' || id";
|
|
let result = analyze_query(sql.to_string()).await;
|
|
|
|
assert!(result.is_err(), "Should reject EXECUTE IMMEDIATE statement");
|
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
|
assert!(msg.contains("Expected"), "Error message should indicate parsing failure");
|
|
} else {
|
|
panic!("Expected ParseError for EXECUTE IMMEDIATE, got: {:?}", result);
|
|
}
|
|
}
|
|
|
|
// ======================================================
|
|
// SNOWFLAKE-SPECIFIC DIALECT TESTS
|
|
// ======================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_semi_structured_json() {
|
|
// Test Snowflake semi-structured data handling with JSON paths
|
|
let sql = r#"
|
|
SELECT
|
|
metadata:user.id::INTEGER as user_id,
|
|
metadata:user.profile.name::STRING as user_name,
|
|
metadata:product.id::INTEGER as product_id,
|
|
metadata:location.coordinates[0]::FLOAT as longitude,
|
|
metadata:location.coordinates[1]::FLOAT as latitude
|
|
FROM db1.schema1.events e
|
|
WHERE metadata:event.type::STRING = 'purchase'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check that the base table is properly detected
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check that the JSON path column is detected
|
|
assert!(events_table.columns.contains("metadata"), "Should detect metadata JSON column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_lateral_flatten() {
|
|
// Test Snowflake's FLATTEN table function with LATERAL
|
|
let sql = r#"
|
|
SELECT
|
|
o.order_id,
|
|
o.customer_id,
|
|
i.value:product_id::INTEGER as product_id,
|
|
i.value:quantity::INTEGER as quantity,
|
|
i.value:price::FLOAT as price
|
|
FROM db1.schema1.orders o,
|
|
LATERAL FLATTEN(input => o.items) i
|
|
WHERE o.status = 'completed'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check for joins (LATERAL is represented as a join)
|
|
assert!(!result.joins.is_empty(), "Should detect LATERAL join");
|
|
|
|
// Check columns
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("items"), "Should detect items column");
|
|
assert!(orders_table.columns.contains("status"), "Should detect status column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_table_sample() {
|
|
// Test Snowflake's table sampling
|
|
let sql = r#"
|
|
SELECT
|
|
u.user_id,
|
|
u.name,
|
|
u.email
|
|
FROM db1.schema1.users u TABLESAMPLE (10)
|
|
WHERE u.status = 'active'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert_eq!(users_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(users_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check columns
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("name"), "Should detect name column");
|
|
assert!(users_table.columns.contains("email"), "Should detect email column");
|
|
assert!(users_table.columns.contains("status"), "Should detect status column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_variant_array_access() {
|
|
// Test Snowflake array access on VARIANT columns
|
|
let sql = r#"
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
c.contact_info[0]:phone::STRING as primary_phone,
|
|
c.contact_info[1]:phone::STRING as secondary_phone,
|
|
c.addresses[0]:street::STRING as street_address,
|
|
c.addresses[0]:city::STRING as city,
|
|
c.addresses[0]:state::STRING as state
|
|
FROM db1.schema1.customers c
|
|
WHERE c.contact_info[0]:is_primary::BOOLEAN = true
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert_eq!(customers_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(customers_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check columns
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(customers_table.columns.contains("name"), "Should detect name column");
|
|
assert!(customers_table.columns.contains("contact_info"), "Should detect contact_info column");
|
|
assert!(customers_table.columns.contains("addresses"), "Should detect addresses column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_time_travel() {
|
|
// Test Snowflake time travel feature
|
|
let sql = r#"
|
|
SELECT
|
|
o.order_id,
|
|
o.customer_id,
|
|
o.order_date,
|
|
o.status
|
|
FROM db1.schema1.orders o AT(TIMESTAMP => '2023-01-01 12:00:00'::TIMESTAMP)
|
|
WHERE o.status = 'shipped'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("status"), "Should detect status column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_array_construct() {
|
|
// Test Snowflake array construction and unnesting
|
|
let sql = r#"
|
|
SELECT
|
|
p.product_id,
|
|
p.name,
|
|
t.value as tag
|
|
FROM db1.schema1.products p,
|
|
LATERAL FLATTEN(input => ARRAY_CONSTRUCT('electronics', 'gadget', 'tech')) t
|
|
WHERE p.category = 'Electronics'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert_eq!(products_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(products_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check columns
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(products_table.columns.contains("name"), "Should detect name column");
|
|
assert!(products_table.columns.contains("category"), "Should detect category column");
|
|
|
|
// Check for joins/lateral
|
|
assert!(!result.joins.is_empty(), "Should detect LATERAL join");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_current_schemas() {
|
|
// Test Snowflake's current_schemas() function
|
|
let sql = r#"
|
|
SELECT
|
|
u.user_id,
|
|
u.name,
|
|
u.email
|
|
FROM db1.schema1.users u
|
|
WHERE u.schema_access IN (SELECT value FROM TABLE(RESULT_SCAN(LAST_QUERY_ID())))
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert_eq!(users_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(users_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check columns
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("name"), "Should detect name column");
|
|
assert!(users_table.columns.contains("email"), "Should detect email column");
|
|
assert!(users_table.columns.contains("schema_access"), "Should detect schema_access column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_object_construct() {
|
|
// Test Snowflake object construction
|
|
let sql = r#"
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
OBJECT_CONSTRUCT(
|
|
'contact', OBJECT_CONSTRUCT('phone', c.phone, 'email', c.email),
|
|
'address', OBJECT_CONSTRUCT('city', c.city, 'state', c.state)
|
|
) as customer_info
|
|
FROM db1.schema1.customers c
|
|
WHERE c.status = 'active'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert_eq!(customers_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(customers_table.schema_identifier, Some("schema1".to_string()));
|
|
|
|
// Check columns
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(customers_table.columns.contains("name"), "Should detect name column");
|
|
assert!(customers_table.columns.contains("phone"), "Should detect phone column");
|
|
assert!(customers_table.columns.contains("email"), "Should detect email column");
|
|
assert!(customers_table.columns.contains("city"), "Should detect city column");
|
|
assert!(customers_table.columns.contains("state"), "Should detect state column");
|
|
assert!(customers_table.columns.contains("status"), "Should detect status column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_snowflake_merge_with_cte() {
|
|
// Test snowflake with CTE for analytics
|
|
let sql = r#"
|
|
WITH monthly_purchases AS (
|
|
SELECT
|
|
customer_id,
|
|
DATE_TRUNC('MONTH', order_date) as month,
|
|
SUM(amount) as total_spent,
|
|
COUNT(*) as order_count
|
|
FROM db1.schema1.orders
|
|
GROUP BY customer_id, DATE_TRUNC('MONTH', order_date)
|
|
),
|
|
customer_averages AS (
|
|
SELECT
|
|
customer_id,
|
|
AVG(total_spent) as avg_monthly_spend,
|
|
AVG(order_count) as avg_monthly_orders
|
|
FROM monthly_purchases
|
|
GROUP BY customer_id
|
|
)
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
c.email,
|
|
COALESCE(ca.avg_monthly_spend, 0) as avg_spend,
|
|
COALESCE(ca.avg_monthly_orders, 0) as avg_orders,
|
|
IFF(ca.avg_monthly_spend > 500, 'High Value', 'Standard') as customer_segment
|
|
FROM db1.schema1.customers c
|
|
LEFT JOIN customer_averages ca ON c.customer_id = ca.customer_id
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check CTEs
|
|
let cte_names: Vec<_> = result.ctes.iter()
|
|
.map(|cte| cte.name.clone())
|
|
.filter(|name| ["monthly_purchases", "customer_averages"].contains(&name.as_str()))
|
|
.collect();
|
|
|
|
assert_eq!(cte_names.len(), 2, "Should detect both CTEs");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check joins
|
|
assert!(!result.joins.is_empty(), "Should detect joins");
|
|
}
|
|
|
|
// ======================================================
|
|
// BIGQUERY-SPECIFIC DIALECT TESTS
|
|
// ======================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_nested_repeated_fields() {
|
|
// Test BigQuery nested and repeated fields
|
|
let sql = r#"
|
|
SELECT
|
|
event_id,
|
|
event_name,
|
|
user.user_id,
|
|
user.device.type AS device_type,
|
|
user.device.os_version AS os_version,
|
|
(SELECT COUNT(*) FROM UNNEST(event_params) WHERE key = 'page') AS page_param_count
|
|
FROM `project.dataset.events`
|
|
WHERE user.device.mobile = TRUE
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(events_table.columns.contains("event_id"), "Should detect event_id column");
|
|
assert!(events_table.columns.contains("event_name"), "Should detect event_name column");
|
|
assert!(events_table.columns.contains("user"), "Should detect user column");
|
|
assert!(events_table.columns.contains("event_params"), "Should detect event_params column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_array_functions() {
|
|
// Test BigQuery array functions
|
|
let sql = r#"
|
|
SELECT
|
|
product_id,
|
|
product_name,
|
|
ARRAY_LENGTH(categories) AS category_count,
|
|
ARRAY_AGG(DISTINCT o.order_id) AS order_ids
|
|
FROM `project.dataset.products`,
|
|
UNNEST(categories) AS category
|
|
LEFT JOIN `project.dataset.order_items` o
|
|
ON o.product_id = product_id
|
|
WHERE 'electronics' IN UNNEST(categories)
|
|
GROUP BY product_id, product_name, categories
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
assert!(base_tables.contains(&"order_items".to_string()), "Should detect order_items table");
|
|
|
|
// Check columns
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(products_table.columns.contains("product_name"), "Should detect product_name column");
|
|
assert!(products_table.columns.contains("categories"), "Should detect categories column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_struct_fields() {
|
|
// Test BigQuery struct fields
|
|
let sql = r#"
|
|
SELECT
|
|
user_id,
|
|
address.city,
|
|
address.state,
|
|
address.zip,
|
|
(SELECT COUNT(*) FROM UNNEST(orders) WHERE status = 'completed') AS completed_orders
|
|
FROM `project.dataset.users`
|
|
WHERE address.country = 'USA'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert_eq!(users_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(users_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("address"), "Should detect address column");
|
|
assert!(users_table.columns.contains("orders"), "Should detect orders column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_partition_by_date() {
|
|
// Test BigQuery partition pruning
|
|
let sql = r#"
|
|
SELECT
|
|
event_date,
|
|
COUNT(*) as event_count,
|
|
COUNT(DISTINCT user_id) as user_count
|
|
FROM `project.dataset.events`
|
|
WHERE event_date BETWEEN '2023-01-01' AND '2023-01-31'
|
|
GROUP BY event_date
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(events_table.columns.contains("event_date"), "Should detect event_date column");
|
|
assert!(events_table.columns.contains("user_id"), "Should detect user_id column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_window_functions() {
|
|
// Test BigQuery window functions
|
|
let sql = r#"
|
|
SELECT
|
|
date,
|
|
product_id,
|
|
revenue,
|
|
SUM(revenue) OVER(PARTITION BY product_id ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cumulative_revenue,
|
|
LEAD(revenue, 1) OVER(PARTITION BY product_id ORDER BY date) AS next_day_revenue,
|
|
PERCENTILE_CONT(revenue, 0.5) OVER(PARTITION BY product_id) AS median_revenue
|
|
FROM `project.dataset.daily_sales`
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let sales_table = result.tables.iter().find(|t| t.table_identifier == "daily_sales").unwrap();
|
|
assert_eq!(sales_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(sales_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(sales_table.columns.contains("date"), "Should detect date column");
|
|
assert!(sales_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(sales_table.columns.contains("revenue"), "Should detect revenue column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_geography_functions() {
|
|
// Test BigQuery geography functions
|
|
let sql = r#"
|
|
SELECT
|
|
store_id,
|
|
store_name,
|
|
ST_DISTANCE(ST_GEOGPOINT(longitude, latitude), ST_GEOGPOINT(-122.4194, 37.7749)) AS distance_to_sf
|
|
FROM `project.dataset.stores`
|
|
WHERE ST_DWITHIN(ST_GEOGPOINT(longitude, latitude), ST_GEOGPOINT(-122.4194, 37.7749), 50000)
|
|
ORDER BY distance_to_sf
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let stores_table = result.tables.iter().find(|t| t.table_identifier == "stores").unwrap();
|
|
assert_eq!(stores_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(stores_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(stores_table.columns.contains("store_id"), "Should detect store_id column");
|
|
assert!(stores_table.columns.contains("store_name"), "Should detect store_name column");
|
|
assert!(stores_table.columns.contains("longitude"), "Should detect longitude column");
|
|
assert!(stores_table.columns.contains("latitude"), "Should detect latitude column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_wildcard_tables() {
|
|
// Test BigQuery wildcard table references
|
|
let sql = r#"
|
|
SELECT
|
|
_TABLE_SUFFIX AS date_suffix,
|
|
COUNT(*) AS row_count
|
|
FROM `project.dataset.events_*`
|
|
WHERE _TABLE_SUFFIX BETWEEN '20230101' AND '20230131'
|
|
GROUP BY _TABLE_SUFFIX
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table (may interpret as events_* or might have special handling)
|
|
let has_events_table = result.tables.iter().any(|t|
|
|
t.table_identifier.contains("events") &&
|
|
t.database_identifier == Some("project".to_string()) &&
|
|
t.schema_identifier == Some("dataset".to_string())
|
|
);
|
|
|
|
assert!(has_events_table, "Should detect events_* table pattern");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_json_functions() {
|
|
// Test BigQuery JSON functions
|
|
let sql = r#"
|
|
SELECT
|
|
user_id,
|
|
JSON_EXTRACT(properties, '$.device.type') AS device_type,
|
|
JSON_EXTRACT_SCALAR(properties, '$.location.city') AS city,
|
|
JSON_VALUE(properties, '$.browser') AS browser
|
|
FROM `project.dataset.user_events`
|
|
WHERE JSON_VALUE(properties, '$.country') = 'USA'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "user_events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(events_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(events_table.columns.contains("properties"), "Should detect properties column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_ml_predict() {
|
|
// Test BigQuery ML.PREDICT function
|
|
let sql = r#"
|
|
SELECT
|
|
u.user_id,
|
|
u.age,
|
|
u.gender,
|
|
ML.PREDICT(MODEL `project.dataset.purchase_model`,
|
|
(
|
|
SELECT AS STRUCT
|
|
u.age,
|
|
u.gender,
|
|
u.country,
|
|
COUNT(p.product_id) AS product_view_count
|
|
FROM `project.dataset.product_views` p
|
|
WHERE p.user_id = u.user_id
|
|
GROUP BY u.user_id
|
|
)
|
|
) AS purchase_probability
|
|
FROM `project.dataset.users` u
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
assert!(base_tables.contains(&"product_views".to_string()), "Should detect product_views table");
|
|
|
|
// Check columns
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("age"), "Should detect age column");
|
|
assert!(users_table.columns.contains("gender"), "Should detect gender column");
|
|
assert!(users_table.columns.contains("country"), "Should detect country column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_bigquery_array_agg_struct() {
|
|
// Test BigQuery ARRAY_AGG with structs
|
|
let sql = r#"
|
|
SELECT
|
|
user_id,
|
|
ARRAY_AGG(STRUCT(
|
|
transaction_id,
|
|
product_id,
|
|
amount,
|
|
timestamp
|
|
) ORDER BY timestamp DESC) AS recent_transactions
|
|
FROM `project.dataset.transactions`
|
|
GROUP BY user_id
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let transactions_table = result.tables.iter().find(|t| t.table_identifier == "transactions").unwrap();
|
|
assert_eq!(transactions_table.database_identifier, Some("project".to_string()));
|
|
assert_eq!(transactions_table.schema_identifier, Some("dataset".to_string()));
|
|
|
|
// Check columns
|
|
assert!(transactions_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(transactions_table.columns.contains("transaction_id"), "Should detect transaction_id column");
|
|
assert!(transactions_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(transactions_table.columns.contains("amount"), "Should detect amount column");
|
|
assert!(transactions_table.columns.contains("timestamp"), "Should detect timestamp column");
|
|
}
|
|
|
|
// ======================================================
|
|
// POSTGRESQL-SPECIFIC DIALECT TESTS
|
|
// ======================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_array_operators() {
|
|
// Test PostgreSQL array operators
|
|
let sql = r#"
|
|
SELECT
|
|
p.product_id,
|
|
p.name,
|
|
p.tags,
|
|
c.name AS category_name
|
|
FROM db1.public.products p
|
|
JOIN db1.public.categories c ON c.id = ANY(p.category_ids)
|
|
WHERE 'electronics' = ANY(p.tags)
|
|
AND p.in_stock_quantities[1] > 0
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"products".to_string()), "Should detect products table");
|
|
assert!(base_tables.contains(&"categories".to_string()), "Should detect categories table");
|
|
|
|
// Check columns
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(products_table.columns.contains("name"), "Should detect name column");
|
|
assert!(products_table.columns.contains("tags"), "Should detect tags column");
|
|
assert!(products_table.columns.contains("category_ids"), "Should detect category_ids column");
|
|
assert!(products_table.columns.contains("in_stock_quantities"), "Should detect in_stock_quantities column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_json_functions() {
|
|
// Test PostgreSQL JSON functions
|
|
let sql = r#"
|
|
SELECT
|
|
user_id,
|
|
data->>'name' AS name,
|
|
data->>'email' AS email,
|
|
jsonb_array_elements(data->'addresses') AS address,
|
|
(data->'settings'->>'notifications')::boolean AS notifications_enabled
|
|
FROM db1.public.users
|
|
WHERE (data->>'active')::boolean = true
|
|
AND data @> '{"premium": true}'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert_eq!(users_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(users_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("data"), "Should detect data column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_window_functions() {
|
|
// Test PostgreSQL window functions
|
|
let sql = r#"
|
|
SELECT
|
|
customer_id,
|
|
order_id,
|
|
order_date,
|
|
amount,
|
|
SUM(amount) OVER (PARTITION BY customer_id ORDER BY order_date) AS cumulative_amount,
|
|
ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY order_date DESC) AS order_recency_rank,
|
|
FIRST_VALUE(amount) OVER (PARTITION BY customer_id ORDER BY amount DESC) AS largest_order
|
|
FROM db1.public.orders
|
|
WHERE order_date >= CURRENT_DATE - INTERVAL '1 year'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_range_types() {
|
|
// Test PostgreSQL range types
|
|
let sql = r#"
|
|
SELECT
|
|
reservation_id,
|
|
guest_name,
|
|
room_id,
|
|
daterange(check_in_date, check_out_date, '[]') AS stay_period
|
|
FROM db1.public.reservations
|
|
WHERE daterange(check_in_date, check_out_date, '[]') && daterange('2023-07-01', '2023-07-15', '[]')
|
|
AND room_id = 101
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let reservations_table = result.tables.iter().find(|t| t.table_identifier == "reservations").unwrap();
|
|
assert_eq!(reservations_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(reservations_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(reservations_table.columns.contains("reservation_id"), "Should detect reservation_id column");
|
|
assert!(reservations_table.columns.contains("guest_name"), "Should detect guest_name column");
|
|
assert!(reservations_table.columns.contains("room_id"), "Should detect room_id column");
|
|
assert!(reservations_table.columns.contains("check_in_date"), "Should detect check_in_date column");
|
|
assert!(reservations_table.columns.contains("check_out_date"), "Should detect check_out_date column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_full_text_search() {
|
|
// Test PostgreSQL full-text search
|
|
let sql = r#"
|
|
SELECT
|
|
product_id,
|
|
name,
|
|
description,
|
|
price,
|
|
ts_rank(search_vector, to_tsquery('english', 'wireless & headphones')) AS rank
|
|
FROM db1.public.products
|
|
WHERE search_vector @@ to_tsquery('english', 'wireless & headphones')
|
|
ORDER BY rank DESC
|
|
LIMIT 10
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert_eq!(products_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(products_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(products_table.columns.contains("name"), "Should detect name column");
|
|
assert!(products_table.columns.contains("description"), "Should detect description column");
|
|
assert!(products_table.columns.contains("price"), "Should detect price column");
|
|
assert!(products_table.columns.contains("search_vector"), "Should detect search_vector column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_common_table_expressions() {
|
|
// Test PostgreSQL recursive CTEs
|
|
let sql = r#"
|
|
WITH RECURSIVE comment_tree AS (
|
|
-- Base case: top-level comments
|
|
SELECT id, content, parent_id, author_id, 0 AS depth
|
|
FROM db1.public.comments
|
|
WHERE parent_id IS NULL AND post_id = 42
|
|
|
|
UNION ALL
|
|
|
|
-- Recursive case: replies to comments
|
|
SELECT c.id, c.content, c.parent_id, c.author_id, ct.depth + 1
|
|
FROM db1.public.comments c
|
|
JOIN comment_tree ct ON c.parent_id = ct.id
|
|
)
|
|
SELECT
|
|
ct.id,
|
|
ct.content,
|
|
ct.depth,
|
|
u.username AS author
|
|
FROM comment_tree ct
|
|
JOIN db1.public.users u ON ct.author_id = u.id
|
|
ORDER BY ct.depth, ct.id
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check for CTE
|
|
let has_cte = result.ctes.iter().any(|cte| cte.name == "comment_tree");
|
|
assert!(has_cte, "Should detect comment_tree CTE");
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"comments".to_string()), "Should detect comments table");
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_lateral_join() {
|
|
// Test PostgreSQL LATERAL joins
|
|
let sql = r#"
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
o.order_id,
|
|
o.order_date,
|
|
o.amount
|
|
FROM db1.public.customers c
|
|
LEFT JOIN LATERAL (
|
|
SELECT order_id, order_date, amount
|
|
FROM db1.public.orders
|
|
WHERE customer_id = c.customer_id
|
|
ORDER BY order_date DESC
|
|
LIMIT 3
|
|
) o ON true
|
|
WHERE c.region = 'West'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check columns
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(customers_table.columns.contains("name"), "Should detect name column");
|
|
assert!(customers_table.columns.contains("region"), "Should detect region column");
|
|
|
|
// Check for lateral join
|
|
assert!(!result.joins.is_empty(), "Should detect LATERAL join");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_date_functions() {
|
|
// Test PostgreSQL date/time functions
|
|
let sql = r#"
|
|
SELECT
|
|
date_trunc('month', order_date) AS month,
|
|
COUNT(*) AS order_count,
|
|
SUM(amount) AS total_amount,
|
|
AVG(amount) AS avg_order_value
|
|
FROM db1.public.orders
|
|
WHERE order_date BETWEEN CURRENT_DATE - INTERVAL '1 year' AND CURRENT_DATE
|
|
AND extract(hour from order_time) BETWEEN 9 AND 17
|
|
GROUP BY date_trunc('month', order_date)
|
|
ORDER BY month
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("order_time"), "Should detect order_time column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_generate_series() {
|
|
// Test PostgreSQL generate_series function
|
|
let sql = r#"
|
|
SELECT
|
|
d.date,
|
|
COALESCE(COUNT(o.order_id), 0) AS order_count,
|
|
COALESCE(SUM(o.amount), 0) AS total_sales
|
|
FROM generate_series(
|
|
CURRENT_DATE - INTERVAL '30 days',
|
|
CURRENT_DATE,
|
|
'1 day'::interval
|
|
) AS d(date)
|
|
LEFT JOIN db1.public.orders o ON date_trunc('day', o.order_date) = d.date
|
|
GROUP BY d.date
|
|
ORDER BY d.date
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check column
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_postgres_geometric_types() {
|
|
// Test PostgreSQL geometric types and operators
|
|
let sql = r#"
|
|
SELECT
|
|
store_id,
|
|
name,
|
|
location::text AS coordinates,
|
|
city,
|
|
state
|
|
FROM db1.public.stores
|
|
WHERE location <-> point(37.7749, -122.4194) < 50
|
|
ORDER BY location <-> point(37.7749, -122.4194)
|
|
LIMIT 10
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let stores_table = result.tables.iter().find(|t| t.table_identifier == "stores").unwrap();
|
|
assert_eq!(stores_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(stores_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(stores_table.columns.contains("store_id"), "Should detect store_id column");
|
|
assert!(stores_table.columns.contains("name"), "Should detect name column");
|
|
assert!(stores_table.columns.contains("location"), "Should detect location column");
|
|
assert!(stores_table.columns.contains("city"), "Should detect city column");
|
|
assert!(stores_table.columns.contains("state"), "Should detect state column");
|
|
}
|
|
|
|
// ======================================================
|
|
// REDSHIFT-SPECIFIC DIALECT TESTS
|
|
// ======================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_distribution_key() {
|
|
// Test Redshift's DISTKEY usage
|
|
let sql = r#"
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
c.email,
|
|
SUM(o.amount) AS total_spent
|
|
FROM db1.public.customers c
|
|
JOIN db1.public.orders o ON c.customer_id = o.customer_id
|
|
WHERE c.region = 'West'
|
|
GROUP BY 1, 2, 3
|
|
ORDER BY total_spent DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check columns
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(customers_table.columns.contains("name"), "Should detect name column");
|
|
assert!(customers_table.columns.contains("email"), "Should detect email column");
|
|
assert!(customers_table.columns.contains("region"), "Should detect region column");
|
|
|
|
// Check joins
|
|
assert!(!result.joins.is_empty(), "Should detect JOIN");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_time_functions() {
|
|
// Test Redshift time functions
|
|
let sql = r#"
|
|
SELECT
|
|
GETDATE() AS current_time,
|
|
DATEADD(day, -30, GETDATE()) AS thirty_days_ago,
|
|
DATE_PART(hour, o.created_at) AS hour_of_day,
|
|
DATEDIFF(day, o.created_at, o.shipped_at) AS days_to_ship
|
|
FROM db1.public.orders o
|
|
WHERE DATE_PART(year, o.created_at) = 2023
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("created_at"), "Should detect created_at column");
|
|
assert!(orders_table.columns.contains("shipped_at"), "Should detect shipped_at column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_sortkey() {
|
|
// Test Redshift sorting operations
|
|
let sql = r#"
|
|
SELECT
|
|
DATE_TRUNC('month', o.order_date) AS month,
|
|
c.region,
|
|
COUNT(*) AS order_count,
|
|
SUM(o.amount) AS total_amount
|
|
FROM db1.public.orders o
|
|
JOIN db1.public.customers c ON o.customer_id = c.customer_id
|
|
WHERE o.order_date BETWEEN '2023-01-01' AND '2023-12-31'
|
|
GROUP BY 1, 2
|
|
ORDER BY 1, 2
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
|
|
// Check columns
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
assert!(orders_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_window_functions() {
|
|
// Test Redshift window functions
|
|
let sql = r#"
|
|
SELECT
|
|
customer_id,
|
|
order_date,
|
|
amount,
|
|
SUM(amount) OVER (PARTITION BY customer_id ORDER BY order_date ROWS UNBOUNDED PRECEDING) AS running_total,
|
|
RANK() OVER (PARTITION BY customer_id ORDER BY amount DESC) AS amount_rank,
|
|
LAG(amount, 1) OVER (PARTITION BY customer_id ORDER BY order_date) AS prev_amount
|
|
FROM db1.public.orders
|
|
WHERE order_date >= '2023-01-01'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_unload() {
|
|
// Test Redshift UNLOAD (readonly analysis still)
|
|
let sql = r#"
|
|
SELECT
|
|
c.customer_id,
|
|
c.name,
|
|
c.email,
|
|
o.order_date,
|
|
o.amount
|
|
FROM db1.public.customers c
|
|
JOIN db1.public.orders o ON c.customer_id = o.customer_id
|
|
WHERE c.region = 'West' AND o.order_date >= '2023-01-01'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"customers".to_string()), "Should detect customers table");
|
|
assert!(base_tables.contains(&"orders".to_string()), "Should detect orders table");
|
|
|
|
// Check joins
|
|
assert!(!result.joins.is_empty(), "Should detect JOIN");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_spectrum() {
|
|
// Test Redshift Spectrum (external tables)
|
|
let sql = r#"
|
|
SELECT
|
|
year,
|
|
month,
|
|
day,
|
|
COUNT(*) AS event_count,
|
|
COUNT(DISTINCT user_id) AS unique_users
|
|
FROM db1.external.clickstream_events
|
|
WHERE year = 2023 AND month = 7
|
|
GROUP BY 1, 2, 3
|
|
ORDER BY 1, 2, 3
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "clickstream_events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("external".to_string()));
|
|
|
|
// Check columns
|
|
assert!(events_table.columns.contains("year"), "Should detect year column");
|
|
assert!(events_table.columns.contains("month"), "Should detect month column");
|
|
assert!(events_table.columns.contains("day"), "Should detect day column");
|
|
assert!(events_table.columns.contains("user_id"), "Should detect user_id column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_system_tables() {
|
|
// Test Redshift system table query
|
|
let sql = r#"
|
|
SELECT
|
|
t.database,
|
|
t.schema,
|
|
t.table,
|
|
t.encoded,
|
|
t.rows,
|
|
t.size
|
|
FROM db1.public.tables t
|
|
JOIN db1.public.schemas s ON t.schema = s.schema
|
|
WHERE t.schema = 'public' AND t.size > 1000000
|
|
ORDER BY t.size DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"tables".to_string()), "Should detect tables table");
|
|
assert!(base_tables.contains(&"schemas".to_string()), "Should detect schemas table");
|
|
|
|
// Check columns
|
|
let tables_table = result.tables.iter().find(|t| t.table_identifier == "tables").unwrap();
|
|
assert!(tables_table.columns.contains("database"), "Should detect database column");
|
|
assert!(tables_table.columns.contains("schema"), "Should detect schema column");
|
|
assert!(tables_table.columns.contains("table"), "Should detect table column");
|
|
assert!(tables_table.columns.contains("encoded"), "Should detect encoded column");
|
|
assert!(tables_table.columns.contains("rows"), "Should detect rows column");
|
|
assert!(tables_table.columns.contains("size"), "Should detect size column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_json_extract() {
|
|
// Test Redshift JSON functions
|
|
let sql = r#"
|
|
SELECT
|
|
event_id,
|
|
timestamp,
|
|
JSON_EXTRACT_PATH_TEXT(data, 'user', 'id') AS user_id,
|
|
JSON_EXTRACT_PATH_TEXT(data, 'device', 'type') AS device_type,
|
|
JSON_EXTRACT_PATH_TEXT(data, 'event', 'name') AS event_name
|
|
FROM db1.public.events
|
|
WHERE JSON_EXTRACT_PATH_TEXT(data, 'event', 'category') = 'purchase'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(events_table.columns.contains("event_id"), "Should detect event_id column");
|
|
assert!(events_table.columns.contains("timestamp"), "Should detect timestamp column");
|
|
assert!(events_table.columns.contains("data"), "Should detect data column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_semi_structured() {
|
|
// Test Redshift semi-structured data (SUPER type)
|
|
let sql = r#"
|
|
SELECT
|
|
order_id,
|
|
items[0].product_id AS first_product_id,
|
|
items[0].quantity AS first_quantity,
|
|
items[0].price AS first_price,
|
|
(SELECT COUNT(*) FROM items_arr i AT i.idx) AS item_count
|
|
FROM db1.public.orders
|
|
WHERE items[0].product_id = 'ABC123'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("items"), "Should detect items column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_redshift_materialized_view() {
|
|
// Test Redshift materialized view
|
|
let sql = r#"
|
|
SELECT
|
|
product_id,
|
|
month,
|
|
sales_count,
|
|
sales_amount,
|
|
avg_price
|
|
FROM db1.public.monthly_product_sales
|
|
WHERE month BETWEEN '2023-01-01' AND '2023-12-31'
|
|
AND sales_amount > 10000
|
|
ORDER BY sales_amount DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table (materialized view is treated as a regular table)
|
|
let sales_table = result.tables.iter().find(|t| t.table_identifier == "monthly_product_sales").unwrap();
|
|
assert_eq!(sales_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(sales_table.schema_identifier, Some("public".to_string()));
|
|
|
|
// Check columns
|
|
assert!(sales_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(sales_table.columns.contains("month"), "Should detect month column");
|
|
assert!(sales_table.columns.contains("sales_count"), "Should detect sales_count column");
|
|
assert!(sales_table.columns.contains("sales_amount"), "Should detect sales_amount column");
|
|
assert!(sales_table.columns.contains("avg_price"), "Should detect avg_price column");
|
|
}
|
|
|
|
// ======================================================
|
|
// DATABRICKS-SPECIFIC DIALECT TESTS
|
|
// ======================================================
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_delta_time_travel() {
|
|
// Test Databricks Delta time travel
|
|
let sql = r#"
|
|
SELECT
|
|
customer_id,
|
|
name,
|
|
email,
|
|
address
|
|
FROM db1.default.customers VERSION AS OF 25
|
|
WHERE region = 'West'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert_eq!(customers_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(customers_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(customers_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(customers_table.columns.contains("name"), "Should detect name column");
|
|
assert!(customers_table.columns.contains("email"), "Should detect email column");
|
|
assert!(customers_table.columns.contains("address"), "Should detect address column");
|
|
assert!(customers_table.columns.contains("region"), "Should detect region column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_complex_types() {
|
|
// Test Databricks complex types (arrays, maps, structs)
|
|
let sql = r#"
|
|
SELECT
|
|
user_id,
|
|
profile.name,
|
|
profile.location.city,
|
|
profile.location.state,
|
|
EXPLODE(profile.interests) AS interest,
|
|
activity_history['login'] AS last_login,
|
|
activity_history['purchase'] AS last_purchase
|
|
FROM db1.default.users
|
|
WHERE SIZE(profile.interests) > 2
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert_eq!(users_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(users_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("profile"), "Should detect profile column");
|
|
assert!(users_table.columns.contains("activity_history"), "Should detect activity_history column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_higher_order_functions() {
|
|
// Test Databricks higher-order functions
|
|
let sql = r#"
|
|
SELECT
|
|
product_id,
|
|
name,
|
|
categories,
|
|
TRANSFORM(tags, t -> UPPER(t)) AS uppercase_tags,
|
|
FILTER(categories, c -> c LIKE '%electronics%') AS electronics_categories,
|
|
AGGREGATE(price_history, 0, (acc, price) -> acc + price, acc -> acc / SIZE(price_history)) AS avg_price
|
|
FROM db1.default.products
|
|
WHERE EXISTS(FILTER(tags, t -> t = 'premium'))
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let products_table = result.tables.iter().find(|t| t.table_identifier == "products").unwrap();
|
|
assert_eq!(products_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(products_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(products_table.columns.contains("product_id"), "Should detect product_id column");
|
|
assert!(products_table.columns.contains("name"), "Should detect name column");
|
|
assert!(products_table.columns.contains("categories"), "Should detect categories column");
|
|
assert!(products_table.columns.contains("tags"), "Should detect tags column");
|
|
assert!(products_table.columns.contains("price_history"), "Should detect price_history column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_date_functions() {
|
|
// Test Databricks date functions
|
|
let sql = r#"
|
|
SELECT
|
|
DATE_FORMAT(order_date, 'yyyy-MM') AS month,
|
|
COUNT(*) AS order_count,
|
|
SUM(amount) AS total_sales,
|
|
DATE_ADD(MAX(order_date), 30) AS next_30_days,
|
|
MONTH(order_date) AS month_num,
|
|
YEAR(order_date) AS year_num
|
|
FROM db1.default.orders
|
|
WHERE order_date BETWEEN DATE_SUB(CURRENT_DATE(), 365) AND CURRENT_DATE()
|
|
GROUP BY DATE_FORMAT(order_date, 'yyyy-MM')
|
|
ORDER BY month
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_window_functions() {
|
|
// Test Databricks window functions
|
|
let sql = r#"
|
|
SELECT
|
|
customer_id,
|
|
order_date,
|
|
amount,
|
|
SUM(amount) OVER (PARTITION BY customer_id ORDER BY order_date) AS running_total,
|
|
DENSE_RANK() OVER (PARTITION BY customer_id ORDER BY amount DESC) AS amount_rank,
|
|
PERCENT_RANK() OVER (PARTITION BY customer_id ORDER BY amount) AS amount_percentile
|
|
FROM db1.default.orders
|
|
WHERE YEAR(order_date) = 2023
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("customer_id"), "Should detect customer_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_json_functions() {
|
|
// Test Databricks JSON functions
|
|
let sql = r#"
|
|
SELECT
|
|
event_id,
|
|
FROM_JSON(payload, 'STRUCT<
|
|
user_id: STRING,
|
|
event_type: STRING,
|
|
properties: MAP<STRING, STRING>
|
|
>') AS parsed_event
|
|
FROM db1.default.events
|
|
WHERE GET_JSON_OBJECT(payload, '$.event_type') = 'purchase'
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let events_table = result.tables.iter().find(|t| t.table_identifier == "events").unwrap();
|
|
assert_eq!(events_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(events_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(events_table.columns.contains("event_id"), "Should detect event_id column");
|
|
assert!(events_table.columns.contains("payload"), "Should detect payload column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_pivot() {
|
|
// Test Databricks PIVOT
|
|
let sql = r#"
|
|
SELECT * FROM (
|
|
SELECT
|
|
DATE_FORMAT(order_date, 'yyyy-MM') AS month,
|
|
product_category,
|
|
amount
|
|
FROM db1.default.orders
|
|
WHERE YEAR(order_date) = 2023
|
|
) PIVOT (
|
|
SUM(amount) AS sales
|
|
FOR product_category IN ('Electronics', 'Clothing', 'Home', 'Books')
|
|
)
|
|
ORDER BY month
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("product_category"), "Should detect product_category column");
|
|
assert!(orders_table.columns.contains("amount"), "Should detect amount column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_array_operations() {
|
|
// Test Databricks array operations
|
|
let sql = r#"
|
|
SELECT
|
|
user_id,
|
|
name,
|
|
ARRAY_CONTAINS(interests, 'travel') AS likes_travel,
|
|
ARRAY_DISTINCT(tags) AS unique_tags,
|
|
ARRAY_INTERSECT(interests, searched_terms) AS matched_interests,
|
|
ARRAYS_ZIP(interests, tags) AS interests_with_tags
|
|
FROM db1.default.users
|
|
WHERE ARRAY_CONTAINS(interests, 'sports') AND ARRAY_SIZE(tags) > 2
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert_eq!(users_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(users_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("name"), "Should detect name column");
|
|
assert!(users_table.columns.contains("interests"), "Should detect interests column");
|
|
assert!(users_table.columns.contains("tags"), "Should detect tags column");
|
|
assert!(users_table.columns.contains("searched_terms"), "Should detect searched_terms column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_qualified_wildcard() {
|
|
// Test Databricks qualified wildcards
|
|
let sql = r#"
|
|
SELECT
|
|
u.user_id,
|
|
u.name,
|
|
u.*,
|
|
p.*
|
|
FROM db1.default.users u
|
|
JOIN db1.default.purchases p
|
|
ON u.user_id = p.user_id
|
|
WHERE u.status = 'active' AND p.amount > 100
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base tables
|
|
let base_tables: Vec<_> = result.tables.iter()
|
|
.filter(|t| t.kind == TableKind::Base)
|
|
.map(|t| t.table_identifier.clone())
|
|
.collect();
|
|
|
|
assert!(base_tables.contains(&"users".to_string()), "Should detect users table");
|
|
assert!(base_tables.contains(&"purchases".to_string()), "Should detect purchases table");
|
|
|
|
// Check columns
|
|
let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap();
|
|
assert!(users_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(users_table.columns.contains("name"), "Should detect name column");
|
|
assert!(users_table.columns.contains("status"), "Should detect status column");
|
|
|
|
// Check joins
|
|
assert!(!result.joins.is_empty(), "Should detect JOIN");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_databricks_dynamic_views() {
|
|
// Test Databricks dynamic views
|
|
let sql = r#"
|
|
SELECT
|
|
order_id,
|
|
user_id,
|
|
order_date,
|
|
total_amount,
|
|
status
|
|
FROM db1.default.orders_by_region
|
|
WHERE region = 'West' AND YEAR(order_date) = 2023
|
|
ORDER BY order_date DESC
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
|
|
// Check base table (view is treated as a regular table)
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders_by_region").unwrap();
|
|
assert_eq!(orders_table.database_identifier, Some("db1".to_string()));
|
|
assert_eq!(orders_table.schema_identifier, Some("default".to_string()));
|
|
|
|
// Check columns
|
|
assert!(orders_table.columns.contains("order_id"), "Should detect order_id column");
|
|
assert!(orders_table.columns.contains("user_id"), "Should detect user_id column");
|
|
assert!(orders_table.columns.contains("order_date"), "Should detect order_date column");
|
|
assert!(orders_table.columns.contains("total_amount"), "Should detect total_amount column");
|
|
assert!(orders_table.columns.contains("status"), "Should detect status column");
|
|
assert!(orders_table.columns.contains("region"), "Should detect region column");
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_scalar_subquery_in_select() {
|
|
let sql = r#"
|
|
SELECT
|
|
c.customer_name,
|
|
(SELECT MAX(o.order_date) FROM db1.schema1.orders o WHERE o.customer_id = c.id) as last_order_date
|
|
FROM
|
|
db1.schema1.customers c
|
|
WHERE
|
|
c.is_active = true;
|
|
"#;
|
|
|
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
|
println!("Scalar Subquery Result: {:?}", result);
|
|
|
|
// The analyzer should detect both tables (customers from main query, orders from subquery)
|
|
// We now represent subqueries as CTEs for better analysis
|
|
assert_eq!(result.tables.len(), 2, "Should detect customers and orders tables");
|
|
assert_eq!(result.joins.len(), 0, "Should be no explicit joins");
|
|
assert!(result.ctes.len() >= 1, "Should detect at least one CTE for the subquery");
|
|
|
|
let table_names: HashSet<_> = result.tables.iter().map(|t| t.table_identifier.as_str()).collect();
|
|
assert!(table_names.contains("customers"));
|
|
assert!(table_names.contains("orders"));
|
|
|
|
// Check columns used
|
|
let customers_table = result.tables.iter().find(|t| t.table_identifier == "customers").unwrap();
|
|
assert!(customers_table.columns.contains("customer_name"));
|
|
|
|
// 'id' is now part of the CTE state rather than the main query
|
|
let id_in_customers = customers_table.columns.contains("id");
|
|
let id_in_cte = result.ctes.iter()
|
|
.filter_map(|cte| cte.summary.tables.iter()
|
|
.find(|t| t.table_identifier == "customers")
|
|
.map(|t| t.columns.contains("id")))
|
|
.any(|contains| contains);
|
|
|
|
assert!(id_in_customers || id_in_cte, "id should be tracked somewhere in customers (either main or within CTE)");
|
|
assert!(customers_table.columns.contains("is_active")); // Used in WHERE
|
|
|
|
let orders_table = result.tables.iter().find(|t| t.table_identifier == "orders").unwrap();
|
|
assert!(orders_table.columns.contains("order_date")); // Used in MAX()
|
|
assert!(orders_table.columns.contains("customer_id")); // Used in subquery WHERE
|
|
} |