mirror of https://github.com/buster-so/buster.git
Enhance SQL query safety checks with dialect-aware filtering and improve data source type retrieval in query engine
This commit is contained in:
parent
cb98f5dd9a
commit
9be23fac0e
|
@ -19,11 +19,18 @@ use crate::{
|
|||
|
||||
use database::types::data_metadata::{ColumnMetaData, ColumnType, DataMetadata, SimpleType};
|
||||
use database::vault::read_secret;
|
||||
use database::{
|
||||
enums::DataSourceType,
|
||||
pool::get_pg_pool,
|
||||
schema::data_sources,
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
|
||||
use super::{
|
||||
bigquery_query::bigquery_query, databricks_query::databricks_query, mysql_query::mysql_query,
|
||||
postgres_query::postgres_query, redshift_query::redshift_query,
|
||||
security_utils::query_safety_filter, snowflake_query::{snowflake_query, ProcessingResult},
|
||||
security_utils::query_safety_filter_with_dialect, snowflake_query::{snowflake_query, ProcessingResult},
|
||||
sql_server_query::sql_server_query,
|
||||
};
|
||||
|
||||
|
@ -41,9 +48,25 @@ pub async fn query_engine(
|
|||
) -> Result<QueryResult> {
|
||||
let corrected_sql = sql.to_owned();
|
||||
|
||||
// Fetch the data source type from the database
|
||||
let mut conn = get_pg_pool().get().await
|
||||
.map_err(|e| anyhow!("Failed to get database connection: {}", e))?;
|
||||
|
||||
let data_source_type = data_sources::table
|
||||
.filter(data_sources::id.eq(data_source_id))
|
||||
.select(data_sources::type_)
|
||||
.first::<DataSourceType>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to fetch data source type: {}", e))?;
|
||||
|
||||
let data_source_dialect = data_source_type.to_str();
|
||||
|
||||
let secure_sql = corrected_sql.clone();
|
||||
|
||||
if let Some(warning) = query_safety_filter(secure_sql.clone()).await { return Err(anyhow!(warning)) };
|
||||
// Use the dialect-aware security filter
|
||||
if let Some(warning) = query_safety_filter_with_dialect(secure_sql.clone(), data_source_dialect).await {
|
||||
return Err(anyhow!(warning))
|
||||
};
|
||||
|
||||
let results = match route_to_query(data_source_id, &secure_sql, limit).await {
|
||||
Ok(results) => results,
|
||||
|
|
|
@ -1,15 +1,42 @@
|
|||
use sqlparser::dialect::GenericDialect;
|
||||
use sqlparser::dialect::{
|
||||
GenericDialect, SnowflakeDialect, PostgreSqlDialect, MySqlDialect,
|
||||
BigQueryDialect, MsSqlDialect, DatabricksDialect, SQLiteDialect,
|
||||
AnsiDialect, Dialect
|
||||
};
|
||||
use sqlparser::parser::Parser;
|
||||
use sqlparser::ast::{Statement, SetExpr, Query};
|
||||
|
||||
/// Helper function to get the appropriate SQL dialect based on data source type
|
||||
fn get_dialect(data_source_type: &str) -> Box<dyn Dialect> {
|
||||
match data_source_type.to_lowercase().as_str() {
|
||||
"bigquery" => Box::new(BigQueryDialect {}),
|
||||
"databricks" => Box::new(DatabricksDialect {}),
|
||||
"mysql" | "mariadb" => Box::new(MySqlDialect {}),
|
||||
"postgres" | "postgresql" | "redshift" | "supabase" => Box::new(PostgreSqlDialect {}),
|
||||
"snowflake" => Box::new(SnowflakeDialect {}),
|
||||
"sqlserver" | "mssql" => Box::new(MsSqlDialect {}),
|
||||
"sqlite" => Box::new(SQLiteDialect {}),
|
||||
"ansi" => Box::new(AnsiDialect {}),
|
||||
_ => Box::new(GenericDialect {}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a SQL query is safe to execute by parsing it and ensuring it only contains
|
||||
/// SELECT statements.
|
||||
///
|
||||
/// Returns None if the query is safe, or Some(error_message) if it's not allowed.
|
||||
pub async fn query_safety_filter(sql: String) -> Option<String> {
|
||||
query_safety_filter_with_dialect(sql, "generic").await
|
||||
}
|
||||
|
||||
/// Checks if a SQL query is safe to execute by parsing it with the appropriate dialect
|
||||
/// and ensuring it only contains SELECT statements.
|
||||
///
|
||||
/// Returns None if the query is safe, or Some(error_message) if it's not allowed.
|
||||
pub async fn query_safety_filter_with_dialect(sql: String, data_source_type: &str) -> Option<String> {
|
||||
// Parse the SQL query
|
||||
let dialect = GenericDialect {}; // Generic SQL dialect
|
||||
let ast = match Parser::parse_sql(&dialect, &sql) {
|
||||
let dialect = get_dialect(data_source_type);
|
||||
let ast = match Parser::parse_sql(dialect.as_ref(), &sql) {
|
||||
Ok(ast) => ast,
|
||||
Err(e) => {
|
||||
return Some(format!("Failed to parse SQL query: {}", e));
|
||||
|
@ -211,4 +238,105 @@ mod tests {
|
|||
let result = query_safety_filter(query.to_string()).await;
|
||||
assert!(result.is_none(), "Safe UNION query was rejected: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_snowflake_complex_case_expression() {
|
||||
// This is the exact query that fails in production
|
||||
let query = r#"select
|
||||
date_trunc('month', r.createdat) as month,
|
||||
count(distinct rtd.tracking_number) as return_labels
|
||||
from staging.mongodb.stg_returns r
|
||||
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
|
||||
join dbt.general.teams t on r.team = t.team_id
|
||||
where r.status = 'complete'
|
||||
and case
|
||||
when coalesce(
|
||||
r.shipment:_shipment:is_return::boolean,
|
||||
r.shipment:_shipment:tracker:is_return::boolean,
|
||||
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
|
||||
false
|
||||
)
|
||||
then r.shipment:_shipment:to_address:country::text
|
||||
else r.shipment:_shipment:from_address:country::text
|
||||
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
|
||||
group by all
|
||||
order by month desc"#;
|
||||
|
||||
let result = query_safety_filter(query.to_string()).await;
|
||||
|
||||
// This test currently fails with the error:
|
||||
// "Failed to parse SQL query: sql parser error: Expected: end of statement, found: when at Line: 9, Column: 9"
|
||||
assert!(result.is_some(), "Expected parsing error for Snowflake-specific syntax");
|
||||
assert!(result.unwrap().contains("Failed to parse SQL query"), "Should fail with parsing error");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_snowflake_query_with_dialect_parameter() {
|
||||
// Test the same query using the new dialect-aware function
|
||||
let query = r#"select
|
||||
date_trunc('month', r.createdat) as month,
|
||||
count(distinct rtd.tracking_number) as return_labels
|
||||
from staging.mongodb.stg_returns r
|
||||
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
|
||||
join dbt.general.teams t on r.team = t.team_id
|
||||
where r.status = 'complete'
|
||||
and case
|
||||
when coalesce(
|
||||
r.shipment:_shipment:is_return::boolean,
|
||||
r.shipment:_shipment:tracker:is_return::boolean,
|
||||
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
|
||||
false
|
||||
)
|
||||
then r.shipment:_shipment:to_address:country::text
|
||||
else r.shipment:_shipment:from_address:country::text
|
||||
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
|
||||
group by all
|
||||
order by month desc"#;
|
||||
|
||||
// Try with the new dialect-aware function
|
||||
let result = query_safety_filter_with_dialect(query.to_string(), "snowflake").await;
|
||||
|
||||
// Should pass with Snowflake dialect
|
||||
assert!(result.is_none(), "Snowflake query should be accepted with Snowflake dialect: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_snowflake_query_with_snowflake_dialect() {
|
||||
// Test the same query but with SnowflakeDialect directly
|
||||
let query = r#"select
|
||||
date_trunc('month', r.createdat) as month,
|
||||
count(distinct rtd.tracking_number) as return_labels
|
||||
from staging.mongodb.stg_returns r
|
||||
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
|
||||
join dbt.general.teams t on r.team = t.team_id
|
||||
where r.status = 'complete'
|
||||
and case
|
||||
when coalesce(
|
||||
r.shipment:_shipment:is_return::boolean,
|
||||
r.shipment:_shipment:tracker:is_return::boolean,
|
||||
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
|
||||
false
|
||||
)
|
||||
then r.shipment:_shipment:to_address:country::text
|
||||
else r.shipment:_shipment:from_address:country::text
|
||||
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
|
||||
group by all
|
||||
order by month desc"#;
|
||||
|
||||
// Try parsing with SnowflakeDialect
|
||||
let dialect = SnowflakeDialect {};
|
||||
let parse_result = Parser::parse_sql(&dialect, query);
|
||||
|
||||
// Check if SnowflakeDialect can parse this query
|
||||
match parse_result {
|
||||
Ok(_) => {
|
||||
println!("SnowflakeDialect successfully parsed the query!");
|
||||
// If it parses, it would still be rejected as a SELECT query by our filter
|
||||
},
|
||||
Err(e) => {
|
||||
println!("SnowflakeDialect also failed to parse: {}", e);
|
||||
// Even SnowflakeDialect might have issues with this syntax
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ use sqlparser::ast::{
|
|||
};
|
||||
use sqlparser::dialect::{
|
||||
AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, DuckDbDialect,
|
||||
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect,
|
||||
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect,
|
||||
};
|
||||
use sqlparser::parser::Parser;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
@ -48,7 +48,7 @@ pub fn get_dialect(data_source_dialect: &str) -> &'static dyn Dialect {
|
|||
"mariadb" => &MySqlDialect {}, // MariaDB uses MySQL dialect
|
||||
"postgres" => &PostgreSqlDialect {},
|
||||
"redshift" => &PostgreSqlDialect {}, // Redshift uses PostgreSQL dialect
|
||||
"snowflake" => &GenericDialect {}, // SnowflakeDialect has limitations with some syntax, use GenericDialect
|
||||
"snowflake" => &SnowflakeDialect{}, // Snowflake queries are parsed with the dedicated SnowflakeDialect
|
||||
"sqlserver" => &MsSqlDialect {}, // SQL Server uses MS SQL dialect
|
||||
"supabase" => &PostgreSqlDialect {}, // Supabase uses PostgreSQL dialect
|
||||
"generic" => &GenericDialect {},
|
||||
|
|
Loading…
Reference in New Issue