Merge branch 'staging' into big-nate-bus-1617-create-report-page-and-file-page

This commit is contained in:
Nate Kelley 2025-08-07 12:51:57 -06:00
commit 12d060cf58
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
10 changed files with 251 additions and 20 deletions

View File

@@ -19,11 +19,18 @@ use crate::{
use database::types::data_metadata::{ColumnMetaData, ColumnType, DataMetadata, SimpleType};
use database::vault::read_secret;
use database::{
enums::DataSourceType,
pool::get_pg_pool,
schema::data_sources,
};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use super::{
bigquery_query::bigquery_query, databricks_query::databricks_query, mysql_query::mysql_query,
postgres_query::postgres_query, redshift_query::redshift_query,
security_utils::query_safety_filter, snowflake_query::{snowflake_query, ProcessingResult},
security_utils::query_safety_filter_with_dialect, snowflake_query::{snowflake_query, ProcessingResult},
sql_server_query::sql_server_query,
};
@@ -41,9 +48,25 @@ pub async fn query_engine(
) -> Result<QueryResult> {
let corrected_sql = sql.to_owned();
// Fetch the data source type from the database
let mut conn = get_pg_pool().get().await
.map_err(|e| anyhow!("Failed to get database connection: {}", e))?;
let data_source_type = data_sources::table
.filter(data_sources::id.eq(data_source_id))
.select(data_sources::type_)
.first::<DataSourceType>(&mut conn)
.await
.map_err(|e| anyhow!("Failed to fetch data source type: {}", e))?;
let data_source_dialect = data_source_type.to_str();
let secure_sql = corrected_sql.clone();
if let Some(warning) = query_safety_filter(secure_sql.clone()).await { return Err(anyhow!(warning)) };
// Use the dialect-aware security filter
if let Some(warning) = query_safety_filter_with_dialect(secure_sql.clone(), data_source_dialect).await {
return Err(anyhow!(warning))
};
let results = match route_to_query(data_source_id, &secure_sql, limit).await {
Ok(results) => results,
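For context, a minimal sketch of the new call path (names as in the diff; `safety_check` is a hypothetical wrapper, not part of the commit):

use anyhow::{anyhow, Result};

// Reject anything that is not a plain SELECT, parsing with the dialect
// that matches the data source (e.g. "snowflake", "postgres").
async fn safety_check(sql: &str, dialect: &str) -> Result<()> {
    if let Some(warning) = query_safety_filter_with_dialect(sql.to_owned(), dialect).await {
        return Err(anyhow!(warning));
    }
    Ok(())
}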

View File

@@ -1,15 +1,42 @@
use sqlparser::dialect::GenericDialect;
use sqlparser::dialect::{
GenericDialect, SnowflakeDialect, PostgreSqlDialect, MySqlDialect,
BigQueryDialect, MsSqlDialect, DatabricksDialect, SQLiteDialect,
AnsiDialect, Dialect
};
use sqlparser::parser::Parser;
use sqlparser::ast::{Statement, SetExpr, Query};
/// Helper function to get the appropriate SQL dialect based on data source type
fn get_dialect(data_source_type: &str) -> Box<dyn Dialect> {
match data_source_type.to_lowercase().as_str() {
"bigquery" => Box::new(BigQueryDialect {}),
"databricks" => Box::new(DatabricksDialect {}),
"mysql" | "mariadb" => Box::new(MySqlDialect {}),
"postgres" | "postgresql" | "redshift" | "supabase" => Box::new(PostgreSqlDialect {}),
"snowflake" => Box::new(SnowflakeDialect {}),
"sqlserver" | "mssql" => Box::new(MsSqlDialect {}),
"sqlite" => Box::new(SQLiteDialect {}),
"ansi" => Box::new(AnsiDialect {}),
_ => Box::new(GenericDialect {}),
}
}
/// Checks if a SQL query is safe to execute by parsing it and ensuring it only contains
/// SELECT statements.
///
/// Returns None if the query is safe, or Some(error_message) if it's not allowed.
pub async fn query_safety_filter(sql: String) -> Option<String> {
query_safety_filter_with_dialect(sql, "generic").await
}
/// Checks if a SQL query is safe to execute by parsing it with the appropriate dialect
/// and ensuring it only contains SELECT statements.
///
/// Returns None if the query is safe, or Some(error_message) if it's not allowed.
pub async fn query_safety_filter_with_dialect(sql: String, data_source_type: &str) -> Option<String> {
// Parse the SQL query
let dialect = GenericDialect {}; // Generic SQL dialect
let ast = match Parser::parse_sql(&dialect, &sql) {
let dialect = get_dialect(data_source_type);
let ast = match Parser::parse_sql(dialect.as_ref(), &sql) {
Ok(ast) => ast,
Err(e) => {
return Some(format!("Failed to parse SQL query: {}", e));
@@ -211,4 +238,105 @@ mod tests {
let result = query_safety_filter(query.to_string()).await;
assert!(result.is_none(), "Safe UNION query was rejected: {:?}", result);
}
#[tokio::test]
async fn test_snowflake_complex_case_expression() {
// This is the exact query that fails in production
let query = r#"select
date_trunc('month', r.createdat) as month,
count(distinct rtd.tracking_number) as return_labels
from staging.mongodb.stg_returns r
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
join dbt.general.teams t on r.team = t.team_id
where r.status = 'complete'
and case
when coalesce(
r.shipment:_shipment:is_return::boolean,
r.shipment:_shipment:tracker:is_return::boolean,
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
false
)
then r.shipment:_shipment:to_address:country::text
else r.shipment:_shipment:from_address:country::text
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
group by all
order by month desc"#;
let result = query_safety_filter(query.to_string()).await;
// With the default generic dialect, parsing fails with:
// "Failed to parse SQL query: sql parser error: Expected: end of statement, found: when at Line: 9, Column: 9"
assert!(result.is_some(), "Expected parsing error for Snowflake-specific syntax");
assert!(result.unwrap().contains("Failed to parse SQL query"), "Should fail with parsing error");
}
#[tokio::test]
async fn test_snowflake_query_with_dialect_parameter() {
// Test the same query using the new dialect-aware function
let query = r#"select
date_trunc('month', r.createdat) as month,
count(distinct rtd.tracking_number) as return_labels
from staging.mongodb.stg_returns r
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
join dbt.general.teams t on r.team = t.team_id
where r.status = 'complete'
and case
when coalesce(
r.shipment:_shipment:is_return::boolean,
r.shipment:_shipment:tracker:is_return::boolean,
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
false
)
then r.shipment:_shipment:to_address:country::text
else r.shipment:_shipment:from_address:country::text
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
group by all
order by month desc"#;
// Try with the new dialect-aware function
let result = query_safety_filter_with_dialect(query.to_string(), "snowflake").await;
// Should pass with Snowflake dialect
assert!(result.is_none(), "Snowflake query should be accepted with Snowflake dialect: {:?}", result);
}
#[tokio::test]
async fn test_snowflake_query_with_snowflake_dialect() {
// Test the same query but with SnowflakeDialect directly
let query = r#"select
date_trunc('month', r.createdat) as month,
count(distinct rtd.tracking_number) as return_labels
from staging.mongodb.stg_returns r
join staging.mongodb.stg_return_tracking_details rtd on r._id = rtd.return_id
join dbt.general.teams t on r.team = t.team_id
where r.status = 'complete'
and case
when coalesce(
r.shipment:_shipment:is_return::boolean,
r.shipment:_shipment:tracker:is_return::boolean,
r.shipment:_shipment:from_address:name like any ('%(REFUND)%', '%(STORE CREDIT)%', '%(EXCHANGE)%'),
false
)
then r.shipment:_shipment:to_address:country::text
else r.shipment:_shipment:from_address:country::text
end in ('GB', 'BE', 'EL', 'LT', 'PT', 'BG', 'ES', 'LU', 'RO', 'CZ', 'FR', 'HU', 'SI', 'DK', 'HR', 'MT', 'SK', 'DE', 'IT', 'NL', 'FI', 'EE', 'CY', 'AT', 'SE', 'IE', 'LV', 'PL')
group by all
order by month desc"#;
// Try parsing with SnowflakeDialect
let dialect = SnowflakeDialect {};
let parse_result = Parser::parse_sql(&dialect, query);
// Check if SnowflakeDialect can parse this query
match parse_result {
Ok(_) => {
println!("SnowflakeDialect successfully parsed the query!");
// If it parses, the filter would then accept it, since it is a plain SELECT
},
Err(e) => {
println!("SnowflakeDialect also failed to parse: {}", e);
// Even SnowflakeDialect might have issues with this syntax
}
}
}
}
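To illustrate the intended behavior difference, a hedged sketch (it assumes, as the tests above demonstrate, that `SnowflakeDialect` accepts the `:` path operator and `group by all`):

#[tokio::test]
async fn sketch_dialect_behavior() {
    // Snowflake-specific syntax is still a plain SELECT, so the
    // dialect-aware filter lets it through.
    let sql = "select v:a::text as country from t group by all".to_string();
    assert!(query_safety_filter_with_dialect(sql, "snowflake").await.is_none());

    // Non-SELECT statements are rejected regardless of dialect.
    let ddl = "drop table t".to_string();
    assert!(query_safety_filter_with_dialect(ddl, "snowflake").await.is_some());
}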

View File

@@ -8,7 +8,7 @@ use sqlparser::ast::{
};
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, DuckDbDialect,
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect,
GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect,
};
use sqlparser::parser::Parser;
use std::collections::{HashMap, HashSet};
@@ -48,7 +48,7 @@ pub fn get_dialect(data_source_dialect: &str) -> &'static dyn Dialect {
"mariadb" => &MySqlDialect {}, // MariaDB uses MySQL dialect
"postgres" => &PostgreSqlDialect {},
"redshift" => &PostgreSqlDialect {}, // Redshift uses PostgreSQL dialect
"snowflake" => &GenericDialect {}, // SnowflakeDialect has limitations with some syntax, use GenericDialect
"snowflake" => &SnowflakeDialect{}, // SnowflakeDialect has limitations with some syntax, use GenericDialect
"sqlserver" => &MsSqlDialect {}, // SQL Server uses MS SQL dialect
"supabase" => &PostgreSqlDialect {}, // Supabase uses PostgreSQL dialect
"generic" => &GenericDialect {},

View File

@@ -89,14 +89,14 @@ const yourStuff = (
route: createBusterRoute({ route: BusterRoutes.APP_COLLECTIONS }),
id: BusterRoutes.APP_COLLECTIONS,
active: isActiveCheck('collection', BusterRoutes.APP_COLLECTIONS)
},
{
label: 'Reports',
icon: <ASSET_ICONS.reports />,
route: createBusterRoute({ route: BusterRoutes.APP_REPORTS }),
id: BusterRoutes.APP_REPORTS,
active: isActiveCheck('report', BusterRoutes.APP_REPORTS)
}
// {
// label: 'Reports',
// icon: <ASSET_ICONS.reports />,
// route: createBusterRoute({ route: BusterRoutes.APP_REPORTS }),
// id: BusterRoutes.APP_REPORTS,
// active: isActiveCheck('report', BusterRoutes.APP_REPORTS)
// }
]
};
};

View File

@@ -39,6 +39,7 @@
"dependencies": {
"@ai-sdk/anthropic": "^1.2.12",
"@ai-sdk/google-vertex": "^2.2.27",
"@ai-sdk/openai": "^1.3.23",
"@ai-sdk/provider": "^1.1.3",
"@buster/access-controls": "workspace:*",
"@buster/data-source": "workspace:*",

View File

@@ -7,23 +7,27 @@ import {
modifyDashboards,
modifyMetrics,
} from '../../tools';
import { GPT5 } from '../../utils';
import { Sonnet4 } from '../../utils/models/sonnet-4';
const DEFAULT_OPTIONS = {
maxSteps: 18,
temperature: 0,
maxTokens: 10000,
temperature: 1,
providerOptions: {
anthropic: {
disableParallelToolCalls: true,
},
openai: {
parallelToolCalls: false,
reasoningEffort: 'minimal',
},
},
};
export const analystAgent = new Agent({
name: 'Analyst Agent',
instructions: '', // We control the system messages in the step at stream instantiation
model: Sonnet4,
model: GPT5,
tools: {
createMetrics,
modifyMetrics,

View File

@@ -6,12 +6,12 @@ import {
sequentialThinking,
submitThoughts,
} from '../../tools';
import { GPT5 } from '../../utils';
import { Sonnet4 } from '../../utils/models/sonnet-4';
const DEFAULT_OPTIONS = {
maxSteps: 18,
temperature: 0,
maxTokens: 10000,
temperature: 1,
providerOptions: {
anthropic: {
disableParallelToolCalls: true,
@@ -22,7 +22,7 @@ const DEFAULT_OPTIONS = {
export const thinkAndPrepAgent = new Agent({
name: 'Think and Prep Agent',
instructions: '', // We control the system messages in the step at stream instantiation
model: Sonnet4,
model: GPT5,
tools: {
sequentialThinking,
executeSql,

View File

@@ -14,8 +14,10 @@ export * from './models/ai-fallback';
export * from './models/providers/anthropic';
export * from './models/anthropic-cached';
export * from './models/providers/vertex';
export * from './models/providers/openai';
export * from './models/sonnet-4';
export * from './models/haiku-3-5';
export * from './models/gpt-5';
// Streaming utilities
export * from './streaming';

View File

@@ -0,0 +1,62 @@
import type { LanguageModelV1 } from '@ai-sdk/provider';
import { createFallback } from './ai-fallback';
import { openaiModel } from './providers/openai';
// Lazy initialization to allow mocking in tests
let _gpt5Instance: ReturnType<typeof createFallback> | null = null;
function initializeGPT5() {
if (_gpt5Instance) {
return _gpt5Instance;
}
// Build models array based on available credentials
const models: LanguageModelV1[] = [];
// Only include OpenAI if API key is available
if (process.env.OPENAI_API_KEY) {
try {
models.push(openaiModel('gpt-5-2025-08-07'));
console.info('GPT5: OpenAI model added to fallback chain');
} catch (error) {
console.warn('GPT5: Failed to initialize OpenAI model:', error);
}
}
// Ensure we have at least one model
if (models.length === 0) {
throw new Error('No AI models available. Please set OPENAI_API_KEY environment variable.');
}
console.info(`GPT5: Initialized with ${models.length} model(s) in fallback chain`);
_gpt5Instance = createFallback({
models,
modelResetInterval: 60000,
retryAfterOutput: true,
onError: (err) => console.error(`GPT5 fallback triggered: ${err}`),
});
return _gpt5Instance;
}
// Export a proxy that initializes on first use
export const GPT5 = new Proxy({} as ReturnType<typeof createFallback>, {
get(_target, prop) {
const instance = initializeGPT5();
// Direct property access without receiver to avoid proxy conflicts
return instance[prop as keyof typeof instance];
},
has(_target, prop) {
const instance = initializeGPT5();
return prop in instance;
},
ownKeys(_target) {
const instance = initializeGPT5();
return Reflect.ownKeys(instance);
},
getOwnPropertyDescriptor(_target, prop) {
const instance = initializeGPT5();
return Reflect.getOwnPropertyDescriptor(instance, prop);
},
});

View File

@@ -0,0 +1,11 @@
import { createOpenAI } from '@ai-sdk/openai';
import { wrapAISDKModel } from 'braintrust';
export const openaiModel = (modelId: string) => {
const openai = createOpenAI({
apiKey: process.env.OPENAI_API_KEY,
});
// Wrap the model with Braintrust tracing and return it
return wrapAISDKModel(openai(modelId));
};