last few tweaks

This commit is contained in:
dal 2025-05-08 16:27:47 -06:00
parent 9cc01639c5
commit 854ec0b0b5
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
4 changed files with 127 additions and 15 deletions

View File

@ -349,6 +349,12 @@ To conclude your workflow, you use the `finish_and_respond` tool to send a final
- **SQL Requirements**:
- Use database-qualified schema-qualified table names (`<DATABASE_NAME>.<SCHEMA_NAME>.<TABLE_NAME>`).
- Use fully qualified column names with table aliases (e.g., `<table_alias>.<column>`).
- **MANDATORY SQL NAMING CONVENTIONS**:
- **All Table References**: MUST be fully qualified: `DATABASE_NAME.SCHEMA_NAME.TABLE_NAME`.
- **All Column References**: MUST be qualified with their table alias (e.g., `alias.column_name`) or CTE name (e.g., `cte_alias.column_name_from_cte`).
- **Inside CTE Definitions**: When defining a CTE (e.g., `WITH my_cte AS (SELECT t.column1 FROM DATABASE.SCHEMA.TABLE1 t ...)`), all columns selected from underlying database tables MUST use their table alias (e.g., `t.column1`, not just `column1`). This applies even if the CTE is simple and selects from only one table.
- **Selecting From CTEs**: When selecting from a defined CTE, use the CTE's alias for its columns (e.g., `SELECT mc.column1 FROM my_cte mc ...`).
- **Universal Application**: These naming conventions are strict requirements and apply universally to all parts of the SQL query, including every CTE definition and every subsequent SELECT statement. Non-compliance will lead to errors.
- **Context Adherence**: Strictly use only columns that are present in the data context provided by search results. Never invent or assume columns.
- Select specific columns (avoid `SELECT *` or `COUNT(*)`).
- Use CTEs instead of subqueries, and use snake_case for naming them.

View File

@ -163,6 +163,7 @@ pub const METRIC_YML_SCHEMA: &str = r##"
# Note: Respond only with the time period, without explanation or additional copy.
# `sql`: The SQL query for the metric.
# - RULE: MUST use the pipe `|` block scalar style to preserve formatting and newlines.
# - NOTE: Remember to use fully qualified names: DATABASE_NAME.SCHEMA_NAME.TABLE_NAME for tables and table_alias.column for columns. This applies to all table and column references, including those within Common Table Expressions (CTEs) and when selecting from CTEs.
# - Example:
# sql: |
# SELECT ...
@ -231,6 +232,7 @@ properties:
SQL query using YAML pipe syntax (|).
The SQL query should be formatted with proper indentation using the YAML pipe (|) syntax.
This ensures the multi-line SQL is properly parsed while preserving whitespace and newlines.
IMPORTANT: Remember to use fully qualified names: DATABASE_NAME.SCHEMA_NAME.TABLE_NAME for tables and table_alias.column for columns. This rule is critical for all table and column references, including those within Common Table Expressions (CTEs) and when selecting from CTEs.
Example:
sql: |
SELECT column1, column2

View File

@ -4,7 +4,7 @@ use anyhow::Result;
use rand;
use sqlparser::ast::{
Cte, Expr, Join, JoinConstraint, JoinOperator, ObjectName, Query, SelectItem, SetExpr,
Statement, TableFactor, Visit, Visitor, WindowSpec,
Statement, TableFactor, Visit, Visitor, WindowSpec, TableAlias,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
@ -49,6 +49,7 @@ struct QueryAnalyzer {
vague_columns: Vec<String>,
vague_tables: Vec<String>,
ctes: Vec<CteSummary>,
current_select_list_aliases: HashSet<String>,
}
impl QueryAnalyzer {
@ -65,6 +66,7 @@ impl QueryAnalyzer {
current_from_relation_identifier: None,
current_scope_aliases: HashMap::new(),
parent_scope_aliases: HashMap::new(),
current_select_list_aliases: HashSet::new(),
}
}
@ -207,6 +209,7 @@ impl QueryAnalyzer {
// Process a SELECT query
fn process_select_query(&mut self, select: &sqlparser::ast::Select) {
self.current_scope_aliases.clear();
self.current_select_list_aliases.clear();
self.current_from_relation_identifier = None;
let mut join_conditions_to_visit: Vec<&Expr> = Vec::new();
@ -223,6 +226,14 @@ impl QueryAnalyzer {
}
}
// Populate select list aliases *before* processing expressions in WHERE, GROUP BY, HAVING, or SELECT list itself
// This makes them available for resolution in those clauses.
for item in &select.projection {
if let SelectItem::ExprWithAlias { alias, .. } = item {
self.current_select_list_aliases.insert(alias.value.clone());
}
}
// Process all parts of the query with collected context
let combined_aliases_for_visit = self
.current_scope_aliases
@ -1031,6 +1042,13 @@ impl QueryAnalyzer {
}
}
None => {
// Check if it's a known select list alias first
if self.current_select_list_aliases.contains(column) {
// It's a select list alias, consider it resolved for this scope.
// No need to add to vague_columns or assign to a table.
return;
}
// Special handling for nested fields without qualifier
// For example: "SELECT user.device.type" in BigQuery becomes "SELECT user__device__type"
if dialect_nested {
@ -1142,6 +1160,7 @@ impl QueryAnalyzer {
scope_stack: self.scope_stack.clone(),
parent_scope_aliases: HashMap::new(),
current_scope_aliases: HashMap::new(),
current_select_list_aliases: HashSet::new(),
..QueryAnalyzer::new()
}
}
@ -1172,12 +1191,11 @@ impl Visitor for QueryAnalyzer {
type Break = ();
fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
let available_aliases = self
.current_scope_aliases
.iter()
.chain(self.parent_scope_aliases.iter())
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let mut available_aliases = self.parent_scope_aliases.clone();
available_aliases.extend(self.current_scope_aliases.clone());
for alias in &self.current_select_list_aliases {
available_aliases.insert(alias.clone(), alias.clone()); // Select list aliases map to themselves
}
match expr {
Expr::Identifier(ident) => {
@ -1222,14 +1240,29 @@ impl QueryAnalyzer {
sqlparser::ast::FunctionArg::Unnamed(arg_expr) => {
if let sqlparser::ast::FunctionArgExpr::Expr(expr) = arg_expr {
self.visit_expr_with_parent_scope(expr, available_aliases);
} else if let sqlparser::ast::FunctionArgExpr::QualifiedWildcard(name) = arg_expr {
// Handle cases like COUNT(table.*)
let qualifier = name.0.first().map(|i| i.value.clone()).unwrap_or_default();
if !qualifier.is_empty() {
if !available_aliases.contains_key(&qualifier) && // Check against combined available_aliases
!self.tables.contains_key(&qualifier) &&
!self.is_known_cte_definition(&qualifier) {
self.vague_tables.push(qualifier);
}
}
} else if let sqlparser::ast::FunctionArgExpr::Wildcard = arg_expr {
// Handle COUNT(*) - no specific column to track here
}
}
sqlparser::ast::FunctionArg::Named { arg: named_arg, .. } => {
sqlparser::ast::FunctionArg::Named { name, arg: named_arg, operator: _ } => {
// Argument name itself might be an identifier (though less common in SQL for this context)
// self.add_column_reference(None, &name.value, &available_aliases);
if let sqlparser::ast::FunctionArgExpr::Expr(expr) = named_arg {
self.visit_expr_with_parent_scope(expr, available_aliases);
}
}
sqlparser::ast::FunctionArg::ExprNamed { arg: expr_named_arg, .. } => {
sqlparser::ast::FunctionArg::ExprNamed { name, arg: expr_named_arg, operator: _ } => {
// self.add_column_reference(None, &name.value, &available_aliases);
if let sqlparser::ast::FunctionArgExpr::Expr(expr) = expr_named_arg {
self.visit_expr_with_parent_scope(expr, available_aliases);
}
@ -1246,16 +1279,41 @@ impl QueryAnalyzer {
..
})) = &function.over
{
for expr in partition_by {
self.visit_expr_with_parent_scope(expr, available_aliases);
for expr_item in partition_by { // expr_item is &Expr
self.visit_expr_with_parent_scope(expr_item, available_aliases);
}
for order_expr in order_by {
self.visit_expr_with_parent_scope(&order_expr.expr, available_aliases);
for order_expr_item in order_by { // order_expr_item is &OrderByExpr
self.visit_expr_with_parent_scope(&order_expr_item.expr, available_aliases);
}
if let Some(frame) = window_frame {
frame.start_bound.visit(self);
// frame.start_bound and frame.end_bound are WindowFrameBound
// which can contain Expr that needs visiting.
// The default Visitor implementation should handle these if they are Expr.
// However, sqlparser::ast::WindowFrameBound is not directly visitable.
// We need to manually extract expressions from it.
// Example for start_bound:
match &frame.start_bound {
sqlparser::ast::WindowFrameBound::CurrentRow => {}
sqlparser::ast::WindowFrameBound::Preceding(Some(expr)) |
sqlparser::ast::WindowFrameBound::Following(Some(expr)) => {
self.visit_expr_with_parent_scope(expr, available_aliases);
}
sqlparser::ast::WindowFrameBound::Preceding(None) |
sqlparser::ast::WindowFrameBound::Following(None) => {}
}
// Example for end_bound:
if let Some(end_bound) = &frame.end_bound {
end_bound.visit(self);
match end_bound {
sqlparser::ast::WindowFrameBound::CurrentRow => {}
sqlparser::ast::WindowFrameBound::Preceding(Some(expr)) |
sqlparser::ast::WindowFrameBound::Following(Some(expr)) => {
self.visit_expr_with_parent_scope(expr, available_aliases);
}
sqlparser::ast::WindowFrameBound::Preceding(None) |
sqlparser::ast::WindowFrameBound::Following(None) => {}
}
}
}
}

View File

@ -28,6 +28,52 @@ async fn test_simple_query() {
assert!(table.columns.contains("name"), "Missing 'name' column");
}
#[tokio::test]
// Regression test: a CTE (`top5`) joined back into the main query, combined with a
// date-construction function (MAKE_DATE) and `::int` casts, should still yield a
// correct table/join/CTE summary from `analyze_query`.
async fn test_complex_cte_with_date_function() {
let sql = "WITH top5 AS (
SELECT ptr.product_name, SUM(ptr.metric_producttotalrevenue) AS total_revenue
FROM ont_ont.product_total_revenue AS ptr
GROUP BY ptr.product_name
ORDER BY total_revenue DESC
LIMIT 5
)
SELECT
MAKE_DATE(pqs.year::int, ((pqs.quarter - 1) * 3 + 1)::int, 1) AS quarter_start,
pqs.product_name,
SUM(pqs.metric_productquarterlysales) AS quarterly_revenue
FROM ont_ont.product_quarterly_sales AS pqs
JOIN top5 ON pqs.product_name = top5.product_name
GROUP BY quarter_start, pqs.product_name
ORDER BY quarter_start ASC, pqs.product_name;";
let result = analyze_query(sql.to_string()).await.unwrap();
// Check CTE: exactly one CTE, referencing a single base table and no joins.
assert_eq!(result.ctes.len(), 1);
let cte = &result.ctes[0];
assert_eq!(cte.name, "top5");
assert_eq!(cte.summary.tables.len(), 1);
assert_eq!(cte.summary.joins.len(), 0);
// Check main query tables: both base tables should be surfaced — the CTE
// reference is apparently resolved to its underlying table
// (product_total_revenue) rather than reported as "top5".
assert_eq!(result.tables.len(), 2);
let table_names: Vec<String> = result.tables.iter().map(|t| t.table_identifier.clone()).collect();
assert!(table_names.contains(&"product_quarterly_sales".to_string()));
assert!(table_names.contains(&"product_total_revenue".to_string()));
// Check joins: the JOIN against the CTE is expected to be reported in terms of
// the CTE's underlying table. NOTE(review): this assumes the analyzer resolves
// CTE names through to base tables — confirm against the analyzer's CTE handling.
assert_eq!(result.joins.len(), 1);
let join = result.joins.iter().next().unwrap();
assert_eq!(join.left_table, "product_quarterly_sales");
assert_eq!(join.right_table, "product_total_revenue");
// Check schema identifiers: every table in the query is schema-qualified as ont_ont.
for table in result.tables {
assert_eq!(table.schema_identifier, Some("ont_ont".to_string()));
}
}
#[tokio::test]
async fn test_joins() {
let sql =