Enhance SQL query safety checks with dialect-aware filtering and improve data source type retrieval in query engine

This commit is contained in:
dal 2025-08-05 11:00:34 -06:00
parent cb98f5dd9a
commit 9be23fac0e
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 158 additions and 7 deletions

View File

@ -19,11 +19,18 @@ use crate::{
use database::types::data_metadata::{ColumnMetaData, ColumnType, DataMetadata, SimpleType};
use database::vault::read_secret;
use database::{
enums::DataSourceType,
pool::get_pg_pool,
schema::data_sources,
};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use super::{
bigquery_query::bigquery_query, databricks_query::databricks_query, mysql_query::mysql_query,
postgres_query::postgres_query, redshift_query::redshift_query,
security_utils::query_safety_filter, snowflake_query::{snowflake_query, ProcessingResult},
security_utils::query_safety_filter_with_dialect, snowflake_query::{snowflake_query, ProcessingResult},
sql_server_query::sql_server_query,
};
@ -41,9 +48,25 @@ pub async fn query_engine(
) -> Result<QueryResult> {
let corrected_sql = sql.to_owned();
// Fetch the data source type from the database
let mut conn = get_pg_pool().get().await
.map_err(|e| anyhow!("Failed to get database connection: {}", e))?;
let data_source_type = data_sources::table
.filter(data_sources::id.eq(data_source_id))
.select(data_sources::type_)
.first::<DataSourceType>(&mut conn)
.await
.map_err(|e| anyhow!("Failed to fetch data source type: {}", e))?;
let data_source_dialect = data_source_type.to_str();
let secure_sql = corrected_sql.clone();
if let Some(warning) = query_safety_filter(secure_sql.clone()).await { return Err(anyhow!(warning)) };
// Use the dialect-aware security filter
if let Some(warning) = query_safety_filter_with_dialect(secure_sql.clone(), data_source_dialect).await {
return Err(anyhow!(warning))
};
let results = match route_to_query(data_source_id, &secure_sql, limit).await {
Ok(results) => results,

View File

@ -1,15 +1,42 @@
use sqlparser::dialect::GenericDialect;
use sqlparser::dialect::{
GenericDialect, SnowflakeDialect, PostgreSqlDialect, MySqlDialect,
BigQueryDialect, MsSqlDialect, DatabricksDialect, SQLiteDialect,
AnsiDialect, Dialect
};
use sqlparser::parser::Parser;
use sqlparser::ast::{Statement, SetExpr, Query};
/// Maps a data source type name (matched case-insensitively) to the
/// sqlparser dialect used to parse its SQL.
///
/// Any name that is not explicitly recognized falls back to
/// `GenericDialect`, so the caller never has to handle an error here.
fn get_dialect(data_source_type: &str) -> Box<dyn Dialect> {
    let normalized = data_source_type.to_lowercase();
    match normalized.as_str() {
        "snowflake" => Box::new(SnowflakeDialect {}),
        // Redshift and Supabase speak the PostgreSQL dialect.
        "postgres" | "postgresql" | "redshift" | "supabase" => Box::new(PostgreSqlDialect {}),
        // MariaDB is wire- and syntax-compatible with MySQL.
        "mysql" | "mariadb" => Box::new(MySqlDialect {}),
        "bigquery" => Box::new(BigQueryDialect {}),
        "databricks" => Box::new(DatabricksDialect {}),
        "sqlserver" | "mssql" => Box::new(MsSqlDialect {}),
        "sqlite" => Box::new(SQLiteDialect {}),
        "ansi" => Box::new(AnsiDialect {}),
        // Unknown sources: parse with the permissive generic dialect.
        _ => Box::new(GenericDialect {}),
    }
}
/// Validates that a SQL query is safe to execute (i.e. contains only
/// SELECT statements), parsing it with the generic SQL dialect.
///
/// Returns `None` when the query is allowed, or `Some(message)` with a
/// human-readable reason when it is rejected.
pub async fn query_safety_filter(sql: String) -> Option<String> {
    // Thin wrapper: delegate to the dialect-aware variant, selecting the
    // generic fallback dialect for callers that do not know their source type.
    query_safety_filter_with_dialect(sql, "generic").await
}
/// Checks if a SQL query is safe to execute by parsing it with the appropriate dialect
/// and ensuring it only contains SELECT statements.
///
/// Returns None if the query is safe, or Some(error_message) if it's not allowed.
pub async fn query_safety_filter_with_dialect(sql: String, data_source_type: &str) -> Option<String> {
// Parse the SQL query
let dialect = GenericDialect {}; // Generic SQL dialect
let ast = match Parser::parse_sql(&dialect, &sql) {
let dialect = get_dialect(data_source_type);
let ast = match Parser::parse_sql(dialect.as_ref(), &sql) {
Ok(ast) => ast,
Err(e) => {
return Some(format!("Failed to parse SQL query: {}", e));
@ -211,4 +238,105 @@ mod tests {
let result = query_safety_filter(query.to_string()).await;
assert!(result.is_none(), "Safe UNION query was rejected: {:?}", result);
}
// Regression test documenting a known limitation: the default (generic
// dialect) safety filter cannot parse Snowflake-specific syntax such as
// `:` path access, `::` casts, `like any (...)`, and `group by all`.
// The assertions below pin the CURRENT rejecting behavior; the
// dialect-aware variant is expected to accept this same query (see
// `test_snowflake_query_with_dialect_parameter`).
#[tokio::test]
async fn test_snowflake_complex_case_expression() {
// This is the exact query that fails in production
let query = r#"select
date_trunc('month', r.createdat) as month,
count(distinct rtd.tracking_number) as return_labels
from staging.mongodb.stg_returns r
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
join dbt.general.teams t on r.team = t.team_id
where r.status = 'complete'
and case
when coalesce(
r.shipment:_shipment:is_return::boolean,
r.shipment:_shipment:tracker:is_return::boolean,
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
false
)
then r.shipment:_shipment:to_address:country::text
else r.shipment:_shipment:from_address:country::text
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
group by all
order by month desc"#;
// `query_safety_filter` parses with the generic dialect internally.
let result = query_safety_filter(query.to_string()).await;
// Observed parse error:
// "Failed to parse SQL query: sql parser error: Expected: end of statement, found: when at Line: 9, Column: 9"
assert!(result.is_some(), "Expected parsing error for Snowflake-specific syntax");
assert!(result.unwrap().contains("Failed to parse SQL query"), "Should fail with parsing error");
}
// Verifies that the dialect-aware filter accepts the production Snowflake
// query that the generic-dialect path cannot parse.
#[tokio::test]
async fn test_snowflake_query_with_dialect_parameter() {
    // Same production query exercised by the generic-dialect test above.
    let snowflake_sql = r#"select
date_trunc('month', r.createdat) as month,
count(distinct rtd.tracking_number) as return_labels
from staging.mongodb.stg_returns r
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
join dbt.general.teams t on r.team = t.team_id
where r.status = 'complete'
and case
when coalesce(
r.shipment:_shipment:is_return::boolean,
r.shipment:_shipment:tracker:is_return::boolean,
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
false
)
then r.shipment:_shipment:to_address:country::text
else r.shipment:_shipment:from_address:country::text
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
group by all
order by month desc"#;
    // Run the safety check with the Snowflake dialect selected explicitly.
    let verdict = query_safety_filter_with_dialect(snowflake_sql.to_string(), "snowflake").await;
    // `None` means the query parsed successfully and was classified as safe.
    assert!(verdict.is_none(), "Snowflake query should be accepted with Snowflake dialect: {:?}", verdict);
}
// Directly checks that `SnowflakeDialect` can parse the production query,
// independent of the safety-filter wrapper. The sibling test
// `test_snowflake_query_with_dialect_parameter` shows the dialect-aware
// filter accepts this query, so a parse failure here would be a regression.
//
// Fix: the original version only `println!`-ed the outcome in both match
// arms, so the test could never fail. It now asserts the expected success.
#[tokio::test]
async fn test_snowflake_query_with_snowflake_dialect() {
    // Same Snowflake query used by the other tests in this module.
    let query = r#"select
date_trunc('month', r.createdat) as month,
count(distinct rtd.tracking_number) as return_labels
from staging.mongodb.stg_returns r
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
join dbt.general.teams t on r.team = t.team_id
where r.status = 'complete'
and case
when coalesce(
r.shipment:_shipment:is_return::boolean,
r.shipment:_shipment:tracker:is_return::boolean,
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
false
)
then r.shipment:_shipment:to_address:country::text
else r.shipment:_shipment:from_address:country::text
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
group by all
order by month desc"#;
    // Parse with SnowflakeDialect directly, bypassing the safety filter.
    let dialect = SnowflakeDialect {};
    let parse_result = Parser::parse_sql(&dialect, query);
    assert!(
        parse_result.is_ok(),
        "SnowflakeDialect failed to parse production query: {:?}",
        parse_result.err()
    );
}
}

View File

@ -8,7 +8,7 @@ use sqlparser::ast::{
};
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, DuckDbDialect,
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect,
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect,
};
use sqlparser::parser::Parser;
use std::collections::{HashMap, HashSet};
@ -48,7 +48,7 @@ pub fn get_dialect(data_source_dialect: &str) -> &'static dyn Dialect {
"mariadb" => &MySqlDialect {}, // MariaDB uses MySQL dialect
"postgres" => &PostgreSqlDialect {},
"redshift" => &PostgreSqlDialect {}, // Redshift uses PostgreSQL dialect
"snowflake" => &GenericDialect {}, // SnowflakeDialect has limitations with some syntax, use GenericDialect
"snowflake" => &SnowflakeDialect{}, // SnowflakeDialect has limitations with some syntax, use GenericDialect
"sqlserver" => &MsSqlDialect {}, // SQL Server uses MS SQL dialect
"supabase" => &PostgreSqlDialect {}, // Supabase uses PostgreSQL dialect
"generic" => &GenericDialect {},