mirror of https://github.com/buster-so/buster.git
362 lines
12 KiB
Rust
362 lines
12 KiB
Rust
|
use sql_analyzer::{analyze_query, SqlAnalyzerError, JoinInfo};
|
||
|
use tokio;
|
||
|
|
||
|
// Original tests for basic query analysis
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_simple_query() {
|
||
|
let sql = "SELECT u.id, u.name FROM schema.users u";
|
||
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
||
|
|
||
|
assert_eq!(result.tables.len(), 1);
|
||
|
assert_eq!(result.joins.len(), 0);
|
||
|
assert_eq!(result.ctes.len(), 0);
|
||
|
|
||
|
let table = &result.tables[0];
|
||
|
assert_eq!(table.database_identifier, None);
|
||
|
assert_eq!(table.schema_identifier, Some("schema".to_string()));
|
||
|
assert_eq!(table.table_identifier, "users");
|
||
|
assert_eq!(table.alias, Some("u".to_string()));
|
||
|
|
||
|
let columns_vec: Vec<_> = table.columns.iter().collect();
|
||
|
assert!(
|
||
|
columns_vec.len() == 2,
|
||
|
"Expected 2 columns, got {}",
|
||
|
columns_vec.len()
|
||
|
);
|
||
|
assert!(table.columns.contains("id"), "Missing 'id' column");
|
||
|
assert!(table.columns.contains("name"), "Missing 'name' column");
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_joins() {
|
||
|
let sql =
|
||
|
"SELECT u.id, o.order_id FROM schema.users u JOIN schema.orders o ON u.id = o.user_id";
|
||
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
||
|
|
||
|
assert_eq!(result.tables.len(), 2);
|
||
|
assert!(result.joins.len() > 0);
|
||
|
|
||
|
// Verify tables
|
||
|
let table_names: Vec<String> = result
|
||
|
.tables
|
||
|
.iter()
|
||
|
.map(|t| t.table_identifier.clone())
|
||
|
.collect();
|
||
|
assert!(table_names.contains(&"users".to_string()));
|
||
|
assert!(table_names.contains(&"orders".to_string()));
|
||
|
|
||
|
// Verify a join exists
|
||
|
let joins_exist = result.joins.iter().any(|join| {
|
||
|
(join.left_table == "users" && join.right_table == "orders")
|
||
|
|| (join.left_table == "orders" && join.right_table == "users")
|
||
|
});
|
||
|
assert!(
|
||
|
joins_exist,
|
||
|
"Expected to find a join between users and orders"
|
||
|
);
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_cte_query() {
|
||
|
let sql = "WITH user_orders AS (
|
||
|
SELECT u.id, o.order_id
|
||
|
FROM schema.users u
|
||
|
JOIN schema.orders o ON u.id = o.user_id
|
||
|
)
|
||
|
SELECT uo.id, uo.order_id FROM user_orders uo";
|
||
|
|
||
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
||
|
|
||
|
// Verify CTE
|
||
|
assert_eq!(result.ctes.len(), 1);
|
||
|
let cte = &result.ctes[0];
|
||
|
assert_eq!(cte.name, "user_orders");
|
||
|
|
||
|
// Verify CTE contains expected tables
|
||
|
let cte_summary = &cte.summary;
|
||
|
assert_eq!(cte_summary.tables.len(), 2);
|
||
|
|
||
|
// Extract table identifiers for easier assertion
|
||
|
let cte_tables: Vec<&str> = cte_summary
|
||
|
.tables
|
||
|
.iter()
|
||
|
.map(|t| t.table_identifier.as_str())
|
||
|
.collect();
|
||
|
|
||
|
assert!(cte_tables.contains(&"users"));
|
||
|
assert!(cte_tables.contains(&"orders"));
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_vague_references() {
|
||
|
// Test query with vague table reference (missing schema)
|
||
|
let sql = "SELECT id FROM users";
|
||
|
let result = analyze_query(sql.to_string()).await;
|
||
|
|
||
|
assert!(result.is_err());
|
||
|
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
|
||
|
assert!(msg.contains("Vague tables"));
|
||
|
} else {
|
||
|
panic!("Expected VagueReferences error, got: {:?}", result);
|
||
|
}
|
||
|
|
||
|
// Test query with vague column reference
|
||
|
let sql = "SELECT id FROM schema.users";
|
||
|
let result = analyze_query(sql.to_string()).await;
|
||
|
|
||
|
assert!(result.is_err());
|
||
|
if let Err(SqlAnalyzerError::VagueReferences(msg)) = result {
|
||
|
assert!(msg.contains("Vague columns"));
|
||
|
} else {
|
||
|
panic!("Expected VagueReferences error, got: {:?}", result);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_fully_qualified_query() {
|
||
|
let sql = "SELECT u.id, u.name FROM database.schema.users u";
|
||
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
||
|
|
||
|
assert_eq!(result.tables.len(), 1);
|
||
|
let table = &result.tables[0];
|
||
|
assert_eq!(table.database_identifier, Some("database".to_string()));
|
||
|
assert_eq!(table.schema_identifier, Some("schema".to_string()));
|
||
|
assert_eq!(table.table_identifier, "users");
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_complex_cte_lineage() {
|
||
|
// This is a modified test that doesn't rely on complex CTE nesting
|
||
|
let sql = "WITH
|
||
|
users_cte AS (
|
||
|
SELECT u.id, u.name FROM schema.users u
|
||
|
)
|
||
|
SELECT uc.id, uc.name FROM users_cte uc";
|
||
|
|
||
|
let result = analyze_query(sql.to_string()).await.unwrap();
|
||
|
|
||
|
// Verify we have one CTE
|
||
|
assert_eq!(result.ctes.len(), 1);
|
||
|
let users_cte = &result.ctes[0];
|
||
|
assert_eq!(users_cte.name, "users_cte");
|
||
|
|
||
|
// Verify users_cte contains the users table
|
||
|
assert!(users_cte
|
||
|
.summary
|
||
|
.tables
|
||
|
.iter()
|
||
|
.any(|t| t.table_identifier == "users"));
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_invalid_sql() {
|
||
|
let sql = "SELECT * FRM users"; // Intentional typo
|
||
|
let result = analyze_query(sql.to_string()).await;
|
||
|
|
||
|
assert!(result.is_err());
|
||
|
if let Err(SqlAnalyzerError::ParseError(msg)) = result {
|
||
|
assert!(msg.contains("Expected") || msg.contains("syntax error"));
|
||
|
} else {
|
||
|
panic!("Expected ParseError, got: {:?}", result);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_analysis_nested_subqueries() {
|
||
|
// Test nested subqueries in FROM and SELECT clauses
|
||
|
let sql = r#"
|
||
|
SELECT
|
||
|
main.col1,
|
||
|
(SELECT COUNT(*) FROM db1.schema2.tableC c WHERE c.id = main.col2) as sub_count
|
||
|
FROM
|
||
|
(
|
||
|
SELECT t1.col1, t2.col2
|
||
|
FROM db1.schema1.tableA t1
|
||
|
JOIN db1.schema1.tableB t2 ON t1.id = t2.a_id
|
||
|
WHERE t1.status = 'active'
|
||
|
) AS main
|
||
|
WHERE main.col1 > 100;
|
||
|
"#; // Added semicolon here
|
||
|
|
||
|
let result = analyze_query(sql.to_string())
|
||
|
.await
|
||
|
.expect("Analysis failed for nested subquery test");
|
||
|
|
||
|
assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
|
||
|
assert_eq!(
|
||
|
result.joins.len(),
|
||
|
1,
|
||
|
"Should detect the join inside the FROM subquery"
|
||
|
);
|
||
|
assert_eq!(result.tables.len(), 3, "Should detect all 3 base tables");
|
||
|
|
||
|
// Check if all base tables are correctly identified
|
||
|
let table_names: std::collections::HashSet<String> = result
|
||
|
.tables
|
||
|
.iter()
|
||
|
.map(|t| {
|
||
|
format!(
|
||
|
"{}.{}.{}",
|
||
|
t.database_identifier.as_deref().unwrap_or(""),
|
||
|
t.schema_identifier.as_deref().unwrap_or(""),
|
||
|
t.table_identifier
|
||
|
)
|
||
|
})
|
||
|
.collect();
|
||
|
|
||
|
// Convert &str to String for contains check
|
||
|
assert!(
|
||
|
table_names.contains(&"db1.schema1.tableA".to_string()),
|
||
|
"Missing tableA"
|
||
|
);
|
||
|
assert!(
|
||
|
table_names.contains(&"db1.schema1.tableB".to_string()),
|
||
|
"Missing tableB"
|
||
|
);
|
||
|
assert!(
|
||
|
table_names.contains(&"db1.schema2.tableC".to_string()),
|
||
|
"Missing tableC"
|
||
|
);
|
||
|
|
||
|
// Check the join details (simplified check)
|
||
|
assert!(result
|
||
|
.joins
|
||
|
.iter()
|
||
|
.any(|j| (j.left_table == "tableA" && j.right_table == "tableB")
|
||
|
|| (j.left_table == "tableB" && j.right_table == "tableA")));
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_analysis_union_all() {
|
||
|
// Test UNION ALL combining different tables/schemas
|
||
|
// Qualify all columns with table aliases
|
||
|
let sql = r#"
|
||
|
SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active'
|
||
|
UNION ALL
|
||
|
SELECT e.user_id, e.username FROM db2.schema1.employees e WHERE e.role = 'manager'
|
||
|
UNION ALL
|
||
|
SELECT c.pk, c.full_name FROM db1.schema2.contractors c WHERE c.end_date IS NULL;
|
||
|
"#;
|
||
|
|
||
|
let result = analyze_query(sql.to_string())
|
||
|
.await
|
||
|
.expect("Analysis failed for UNION ALL test");
|
||
|
|
||
|
assert_eq!(result.ctes.len(), 0, "Should be no CTEs");
|
||
|
assert_eq!(result.joins.len(), 0, "Should be no joins");
|
||
|
assert_eq!(result.tables.len(), 3, "Should detect all 3 tables across UNIONs");
|
||
|
|
||
|
let table_names: std::collections::HashSet<String> = result
|
||
|
.tables
|
||
|
.iter()
|
||
|
.map(|t| {
|
||
|
format!(
|
||
|
"{}.{}.{}",
|
||
|
t.database_identifier.as_deref().unwrap_or(""),
|
||
|
t.schema_identifier.as_deref().unwrap_or(""),
|
||
|
t.table_identifier
|
||
|
)
|
||
|
})
|
||
|
.collect();
|
||
|
|
||
|
// Convert &str to String for contains check
|
||
|
assert!(
|
||
|
table_names.contains(&"db1.schema1.users".to_string()),
|
||
|
"Missing users table"
|
||
|
);
|
||
|
assert!(
|
||
|
table_names.contains(&"db2.schema1.employees".to_string()),
|
||
|
"Missing employees table"
|
||
|
);
|
||
|
assert!(
|
||
|
table_names.contains(&"db1.schema2.contractors".to_string()),
|
||
|
"Missing contractors table"
|
||
|
);
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn test_analysis_combined_complexity() {
|
||
|
// Test a query with CTEs, subqueries (including in JOIN), and UNION ALL
|
||
|
// Qualify columns more explicitly
|
||
|
let sql = r#"
|
||
|
WITH active_users AS (
|
||
|
SELECT u.id, u.name FROM db1.schema1.users u WHERE u.status = 'active' -- Qualified here
|
||
|
),
|
||
|
recent_orders AS (
|
||
|
SELECT ro.user_id, MAX(ro.order_date) as last_order_date -- Qualified here
|
||
|
FROM db1.schema1.orders ro
|
||
|
GROUP BY ro.user_id
|
||
|
)
|
||
|
SELECT au.name, ro.last_order_date
|
||
|
FROM active_users au -- Join 1: CTE JOIN CTE
|
||
|
JOIN recent_orders ro ON au.id = ro.user_id
|
||
|
JOIN ( -- Join 2: Subquery JOIN CTE (unusual but for test)
|
||
|
SELECT p_sub.item_id, p_sub.category FROM db2.schema1.products p_sub WHERE p_sub.is_available = true -- Qualified here
|
||
|
) p ON p.item_id = ro.user_id -- Join condition uses CTE 'ro' alias
|
||
|
WHERE au.id IN (SELECT sl.user_id FROM db1.schema2.special_list sl) -- Qualified here
|
||
|
|
||
|
UNION ALL
|
||
|
|
||
|
SELECT e.name, e.hire_date -- Qualified here
|
||
|
FROM db2.schema1.employees e
|
||
|
WHERE e.department = 'Sales';
|
||
|
"#;
|
||
|
|
||
|
let result = analyze_query(sql.to_string())
|
||
|
.await
|
||
|
.expect("Analysis failed for combined complexity test");
|
||
|
|
||
|
println!("Combined Complexity Result Joins: {:?}", result.joins);
|
||
|
|
||
|
assert_eq!(result.ctes.len(), 2, "Should detect 2 CTEs");
|
||
|
// EXPECTED: Should now detect joins involving CTEs and the aliased subquery.
|
||
|
// Join 1: active_users au JOIN recent_orders ro
|
||
|
// Join 2: recent_orders ro JOIN subquery p (or whatever the right side identifier is)
|
||
|
assert_eq!(result.joins.len(), 2, "Should detect 2 joins (CTE->CTE, CTE->Subquery)");
|
||
|
assert_eq!(result.tables.len(), 5, "Should detect all 5 base tables");
|
||
|
|
||
|
// Verify specific joins (adjust expected identifiers based on implementation)
|
||
|
let expected_join1 = JoinInfo {
|
||
|
left_table: "au".to_string(), // Alias used in FROM
|
||
|
right_table: "ro".to_string(), // Alias used in JOIN
|
||
|
condition: "au.id = ro.user_id".to_string(),
|
||
|
};
|
||
|
let expected_join2 = JoinInfo {
|
||
|
left_table: "ro".to_string(), // Left side is the previous join's right alias
|
||
|
right_table: "p".to_string(), // Alias of the subquery
|
||
|
condition: "p.item_id = ro.user_id".to_string(),
|
||
|
};
|
||
|
|
||
|
assert!(result.joins.contains(&expected_join1), "Missing join between active_users and recent_orders");
|
||
|
assert!(result.joins.contains(&expected_join2), "Missing join between recent_orders and product subquery");
|
||
|
|
||
|
// Verify CTE names
|
||
|
let cte_names: std::collections::HashSet<String> = result.ctes.iter().map(|c| c.name.clone()).collect();
|
||
|
assert!(cte_names.contains(&"active_users".to_string()));
|
||
|
assert!(cte_names.contains(&"recent_orders".to_string()));
|
||
|
|
||
|
// Verify base table detection
|
||
|
let table_names: std::collections::HashSet<String> = result
|
||
|
.tables
|
||
|
.iter()
|
||
|
.map(|t| {
|
||
|
format!(
|
||
|
"{}.{}.{}",
|
||
|
t.database_identifier.as_deref().unwrap_or(""),
|
||
|
t.schema_identifier.as_deref().unwrap_or(""),
|
||
|
t.table_identifier
|
||
|
)
|
||
|
})
|
||
|
.collect();
|
||
|
|
||
|
assert!(table_names.contains(&"db1.schema1.users".to_string()));
|
||
|
assert!(table_names.contains(&"db1.schema1.orders".to_string()));
|
||
|
assert!(table_names.contains(&"db2.schema1.products".to_string()));
|
||
|
assert!(table_names.contains(&"db1.schema2.special_list".to_string()));
|
||
|
assert!(table_names.contains(&"db2.schema1.employees".to_string()));
|
||
|
|
||
|
// Check analysis within a CTE
|
||
|
let recent_orders_cte = result.ctes.iter().find(|c| c.name == "recent_orders").unwrap();
|
||
|
assert!(recent_orders_cte.summary.tables.iter().any(|t| t.table_identifier == "orders"));
|
||
|
}
|