diff --git a/apps/api/libs/sql_analyzer/src/analysis.rs b/apps/api/libs/sql_analyzer/src/analysis.rs index 8fbd439b2..6ed4f4dd2 100644 --- a/apps/api/libs/sql_analyzer/src/analysis.rs +++ b/apps/api/libs/sql_analyzer/src/analysis.rs @@ -178,11 +178,11 @@ impl QueryAnalyzer { self.parent_scope_aliases = parent_aliases.clone(); // Process WITH clause (CTEs) if present - let is_with_query = self.process_with_clause(query); + let is_with_query = self.process_with_clause(query)?; // Process the main query body match query.body.as_ref() { - SetExpr::Select(select) => self.process_select_query(select), + SetExpr::Select(select) => self.process_select_query(select)?, SetExpr::Query(inner_query) => { self.process_nested_query(inner_query)?; } @@ -201,7 +201,7 @@ impl QueryAnalyzer { } // Process WITH clause and return whether it was processed - fn process_with_clause(&mut self, query: &Query) -> bool { + fn process_with_clause(&mut self, query: &Query) -> Result { if let Some(with) = &query.with { if !with.cte_tables.is_empty() { // Create a new scope for CTE definitions @@ -223,18 +223,16 @@ impl QueryAnalyzer { .chain(self.parent_scope_aliases.iter()) .map(|(k, v)| (k.clone(), v.clone())) .collect(); - if let Err(e) = self.process_cte(cte, &combined_aliases_for_cte) { - eprintln!("Error processing CTE: {}", e); - } + self.process_cte(cte, &combined_aliases_for_cte)?; } - return true; + return Ok(true); } } - false + Ok(false) } // Process a SELECT query - fn process_select_query(&mut self, select: &sqlparser::ast::Select) { + fn process_select_query(&mut self, select: &sqlparser::ast::Select) -> Result<(), SqlAnalyzerError> { self.current_scope_aliases.clear(); self.current_select_list_aliases.clear(); self.current_from_relation_identifier = None; @@ -293,8 +291,9 @@ impl QueryAnalyzer { // Process SELECT list for item in &select.projection { - self.process_select_item(item, &combined_aliases_for_visit); + self.process_select_item(item, &combined_aliases_for_visit)?; } + Ok(()) } // Process join data and collect conditions for later processing @@ -518,6 +517,7 @@ impl QueryAnalyzer { Err(e @ SqlAnalyzerError::VagueReferences(_)) => Err( SqlAnalyzerError::VagueReferences(format!("In CTE '{}': {}", cte_name, e)), ), + Err(e @ SqlAnalyzerError::BlockedWildcardUsage(_)) => Err(e), Err(e) => Err(SqlAnalyzerError::Internal(anyhow::anyhow!( "Internal error summarizing CTE '{}': {}", cte_name, @@ -528,6 +528,7 @@ impl QueryAnalyzer { Err(e @ SqlAnalyzerError::VagueReferences(_)) => Err( SqlAnalyzerError::VagueReferences(format!("In CTE '{}': {}", cte_name, e)), ), + Err(e @ SqlAnalyzerError::BlockedWildcardUsage(_)) => Err(e), Err(e) => Err(SqlAnalyzerError::Internal(anyhow::anyhow!( "Error processing CTE '{}': {}", cte_name, @@ -914,7 +915,7 @@ impl QueryAnalyzer { &mut self, select_item: &SelectItem, parent_aliases: &HashMap, - ) { + ) -> Result<(), SqlAnalyzerError> { match select_item { SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => { self.visit_expr_with_parent_scope(expr, parent_aliases); @@ -926,6 +927,8 @@ impl QueryAnalyzer { .map(|i| i.value.clone()) .unwrap_or_default(); if !qualifier.is_empty() { + self.validate_qualified_wildcard(&qualifier)?; + if !self.current_scope_aliases.contains_key(&qualifier) && !parent_aliases.contains_key(&qualifier) && !self.tables.contains_key(&qualifier) @@ -936,9 +939,10 @@ impl QueryAnalyzer { } } SelectItem::Wildcard(_) => { - // Unqualified wildcard - we don't explicitly add columns for unqualified wildcard + self.validate_wildcard_on_tables()?; } } + Ok(()) } fn into_summary(mut self) -> Result { @@ -1145,6 +1149,53 @@ impl QueryAnalyzer { in_definitions || in_ctes } + fn validate_wildcard_on_tables(&self) -> Result<(), SqlAnalyzerError> { + // Only validate tables that are actually in the FROM clause + if let Some(from_table) = &self.current_from_relation_identifier { + if let Some(table_info) = self.tables.get(from_table) { + if table_info.kind == TableKind::Base { + return Err(SqlAnalyzerError::BlockedWildcardUsage(format!( + "table '{}'", table_info.table_identifier + ))); + } + } + } + + // Also check any tables that might be in current scope aliases that are physical tables + for alias in self.current_scope_aliases.keys() { + if let Some(from_table) = &self.current_from_relation_identifier { + if alias == from_table { + continue; + } + } + + if let Some(table_info) = self.tables.get(alias) { + if table_info.kind == TableKind::Base { + return Err(SqlAnalyzerError::BlockedWildcardUsage(format!( + "table '{}'", table_info.table_identifier + ))); + } + } + } + Ok(()) + } + + fn validate_qualified_wildcard(&self, qualifier: &str) -> Result<(), SqlAnalyzerError> { + let resolved_table = self.current_scope_aliases.get(qualifier) + .or_else(|| self.parent_scope_aliases.get(qualifier)) + .map(|s| s.as_str()) + .unwrap_or(qualifier); + + if let Some(table_info) = self.tables.get(resolved_table) { + if table_info.kind == TableKind::Base { + return Err(SqlAnalyzerError::BlockedWildcardUsage(format!( + "table '{}'", table_info.table_identifier + ))); + } + } + Ok(()) + } + fn add_column_reference( &mut self, qualifier_opt: Option<&str>, diff --git a/apps/api/libs/sql_analyzer/src/errors.rs b/apps/api/libs/sql_analyzer/src/errors.rs index 0db563506..ffda1d0b4 100644 --- a/apps/api/libs/sql_analyzer/src/errors.rs +++ b/apps/api/libs/sql_analyzer/src/errors.rs @@ -35,6 +35,9 @@ pub enum SqlAnalyzerError { #[error("Unsupported statement type found: {0}")] UnsupportedStatement(String), + #[error("Wildcard usage on physical tables is not allowed: {0}")] + BlockedWildcardUsage(String), + #[error("Internal error: {0}")] Internal(#[from] anyhow::Error), } @@ -43,4 +46,4 @@ impl From for SqlAnalyzerError { fn from(err: sqlparser::parser::ParserError) -> Self { SqlAnalyzerError::ParseError(err.to_string()) } -} \ No newline at end of file +} diff --git a/apps/api/libs/sql_analyzer/tests/analysis_tests.rs b/apps/api/libs/sql_analyzer/tests/analysis_tests.rs index 8f3e1a4bc..2c0bcde10 100644 --- a/apps/api/libs/sql_analyzer/tests/analysis_tests.rs +++ b/apps/api/libs/sql_analyzer/tests/analysis_tests.rs @@ -413,6 +413,90 @@ async fn test_multiple_chained_ctes() { assert_eq!(result.joins.len(), 0, "Main query should have no direct joins"); } +#[tokio::test] +async fn test_wildcard_blocked_on_physical_table() { + let sql = "SELECT * FROM schema.users"; + let result = analyze_query(sql.to_string(), "postgres").await; + + assert!(result.is_err()); + if let Err(SqlAnalyzerError::BlockedWildcardUsage(msg)) = result { + assert!(msg.contains("users")); + } else { + panic!("Expected BlockedWildcardUsage error, got: {:?}", result); + } +} + +#[tokio::test] +async fn test_qualified_wildcard_blocked_on_physical_table() { + let sql = "SELECT u.* FROM schema.users u"; + let result = analyze_query(sql.to_string(), "postgres").await; + + assert!(result.is_err()); + if let Err(SqlAnalyzerError::BlockedWildcardUsage(msg)) = result { + assert!(msg.contains("users")); + } else { + panic!("Expected BlockedWildcardUsage error, got: {:?}", result); + } +} + +#[tokio::test] +async fn test_wildcard_allowed_on_cte() { + let sql = "WITH user_cte AS (SELECT u.id, u.name FROM schema.users u) SELECT * FROM user_cte"; + let result = analyze_query(sql.to_string(), "postgres").await; + + match result { + Ok(_) => { + } + Err(e) => { + eprintln!("DEBUG: Unexpected error in test_wildcard_allowed_on_cte: {:?}", e); + panic!("Wildcard on CTE should be allowed, but got error: {:?}", e); + } + } +} + +#[tokio::test] +async fn test_qualified_wildcard_allowed_on_cte() { + let sql = "WITH user_cte AS (SELECT u.id, u.name FROM schema.users u) SELECT uc.* FROM user_cte uc"; + let result = analyze_query(sql.to_string(), "postgres").await; + + assert!(result.is_ok(), "Qualified wildcard on CTE should be allowed"); +} + +#[tokio::test] +async fn test_wildcard_blocked_when_cte_uses_wildcard_on_physical_table() { + let sql = "WITH user_cte AS (SELECT * FROM schema.users) SELECT * FROM user_cte"; + let result = analyze_query(sql.to_string(), "postgres").await; + + assert!(result.is_err()); + if let Err(SqlAnalyzerError::BlockedWildcardUsage(msg)) = result { + assert!(msg.contains("users")); + } else { + panic!("Expected BlockedWildcardUsage error for CTE using wildcard on physical table, got: {:?}", result); + } +} + +#[tokio::test] +async fn test_wildcard_allowed_when_cte_uses_explicit_columns() { + let sql = "WITH user_cte AS (SELECT u.id, u.name FROM schema.users u) SELECT * FROM user_cte"; + let result = analyze_query(sql.to_string(), "postgres").await; + + assert!(result.is_ok(), "Wildcard should be allowed when CTE uses explicit columns"); +} + +#[tokio::test] +async fn test_mixed_wildcard_scenarios() { + let sql = "WITH orders_cte AS (SELECT o.order_id FROM schema.orders o) + SELECT oc.*, u.* FROM orders_cte oc JOIN schema.users u ON oc.order_id = u.id"; + let result = analyze_query(sql.to_string(), "postgres").await; + + assert!(result.is_err()); + if let Err(SqlAnalyzerError::BlockedWildcardUsage(msg)) = result { + assert!(msg.contains("users")); + } else { + panic!("Expected BlockedWildcardUsage error for wildcard on physical table, got: {:?}", result); + } +} + #[tokio::test] async fn test_complex_where_clause() { let sql = r#" @@ -2098,7 +2182,7 @@ async fn test_databricks_pivot() { #[tokio::test] async fn test_databricks_qualified_wildcard() { - // Test Databricks qualified wildcards + // Test Databricks qualified wildcards - should now be blocked due to security enhancement let sql = r#" SELECT u.user_id, @@ -2111,25 +2195,14 @@ async fn test_databricks_qualified_wildcard() { WHERE u.status = 'active' AND p.amount > 100 "#; - let result = analyze_query(sql.to_string(), "databricks").await.unwrap(); + let result = analyze_query(sql.to_string(), "databricks").await; - // Check base tables - let base_tables: Vec<_> = result.tables.iter() - .filter(|t| t.kind == TableKind::Base) - .map(|t| t.table_identifier.clone()) - .collect(); - - assert!(base_tables.contains(&"users".to_string()), "Should detect users table"); - assert!(base_tables.contains(&"purchases".to_string()), "Should detect purchases table"); - - // Check columns - let users_table = result.tables.iter().find(|t| t.table_identifier == "users").unwrap(); - assert!(users_table.columns.contains("user_id"), "Should detect user_id column"); - assert!(users_table.columns.contains("name"), "Should detect name column"); - assert!(users_table.columns.contains("status"), "Should detect status column"); - - // Check joins - assert!(!result.joins.is_empty(), "Should detect JOIN"); + assert!(result.is_err()); + if let Err(SqlAnalyzerError::BlockedWildcardUsage(msg)) = result { + assert!(msg.contains("users") || msg.contains("purchases")); + } else { + panic!("Expected BlockedWildcardUsage error for wildcards on physical tables, got: {:?}", result); + } } #[tokio::test] diff --git a/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts b/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts index 9ae605b5c..b36c9c02f 100644 --- a/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts +++ b/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts @@ -371,6 +371,28 @@ export function validateWildcardUsage(sql: string, dataSourceSyntax?: string): W } } + const tableList = parser.tableList(sql, { database: dialect }); + const tableAliasMap = new Map(); // alias -> table name + + if (Array.isArray(tableList)) { + for (const tableRef of tableList) { + if (typeof tableRef === 'string') { + // Simple table name + tableAliasMap.set(tableRef.toLowerCase(), tableRef); + } else if (tableRef && typeof tableRef === 'object') { + const tableRefObj = tableRef as any; // Type assertion to handle dynamic properties + const tableName = tableRefObj.table || tableRefObj.name; + const alias = tableRefObj.as || tableRefObj.alias; + if (tableName) { + if (alias) { + tableAliasMap.set(alias.toLowerCase(), tableName); + } + tableAliasMap.set(tableName.toLowerCase(), tableName); + } + } + } + } + // Check each statement for wildcard usage const blockedTables: string[] = []; @@ -405,6 +427,43 @@ export function validateWildcardUsage(sql: string, dataSourceSyntax?: string): W function findWildcardUsageOnPhysicalTables(selectStatement: any, cteNames: Set): string[] { const blockedTables: string[] = []; + // Build alias mapping for this statement + const aliasToTableMap = new Map(); + if (selectStatement.from && Array.isArray(selectStatement.from)) { + for (const fromItem of selectStatement.from) { + if (fromItem.table && fromItem.as) { + let tableName: string; + if (typeof fromItem.table === 'string') { + tableName = fromItem.table; + } else if (fromItem.table && typeof fromItem.table === 'object') { + const tableObj = fromItem.table as any; + tableName = tableObj.table || tableObj.name || tableObj.value || String(fromItem.table); + } else { + continue; + } + aliasToTableMap.set(fromItem.as.toLowerCase(), tableName.toLowerCase()); + } + + // Handle JOINs + if (fromItem.join && Array.isArray(fromItem.join)) { + for (const joinItem of fromItem.join) { + if (joinItem.table && joinItem.as) { + let tableName: string; + if (typeof joinItem.table === 'string') { + tableName = joinItem.table; + } else if (joinItem.table && typeof joinItem.table === 'object') { + const tableObj = joinItem.table as any; + tableName = tableObj.table || tableObj.name || tableObj.value || String(joinItem.table); + } else { + continue; + } + aliasToTableMap.set(joinItem.as.toLowerCase(), tableName.toLowerCase()); + } + } + } + } + } + if (selectStatement.columns && Array.isArray(selectStatement.columns)) { for (const column of selectStatement.columns) { if (column.expr && column.expr.type === 'column_ref') { @@ -416,8 +475,24 @@ function findWildcardUsageOnPhysicalTables(selectStatement: any, cteNames: Set): st } for (const fromItem of fromClause) { + // Extract table name from fromItem if (fromItem.table) { - const tableName = typeof fromItem.table === 'string' - ? fromItem.table - : fromItem.table.table || fromItem.table; + let tableName: string; + if (typeof fromItem.table === 'string') { + tableName = fromItem.table; + } else if (fromItem.table && typeof fromItem.table === 'object') { + const tableObj = fromItem.table as any; + tableName = tableObj.table || tableObj.name || tableObj.value || String(fromItem.table); + } else { + continue; + } if (tableName && !cteNames.has(tableName.toLowerCase())) { - tables.push(tableName); + const aliasName = fromItem.as || tableName; + tables.push(aliasName); } } @@ -471,12 +555,19 @@ function getPhysicalTablesFromFrom(fromClause: any[], cteNames: Set): st if (fromItem.join && Array.isArray(fromItem.join)) { for (const joinItem of fromItem.join) { if (joinItem.table) { - const tableName = typeof joinItem.table === 'string' - ? joinItem.table - : joinItem.table.table || joinItem.table; + let tableName: string; + if (typeof joinItem.table === 'string') { + tableName = joinItem.table; + } else if (joinItem.table && typeof joinItem.table === 'object') { + const tableObj = joinItem.table as any; + tableName = tableObj.table || tableObj.name || tableObj.value || String(joinItem.table); + } else { + continue; + } if (tableName && !cteNames.has(tableName.toLowerCase())) { - tables.push(tableName); + const aliasName = joinItem.as || tableName; + tables.push(aliasName); } } }