diff --git a/packages/ai/src/tools/database-tools/execute-sql.ts b/packages/ai/src/tools/database-tools/execute-sql.ts index c4f851c5d..ac05c6d4c 100644 --- a/packages/ai/src/tools/database-tools/execute-sql.ts +++ b/packages/ai/src/tools/database-tools/execute-sql.ts @@ -14,7 +14,7 @@ const executeSqlStatementInputSchema = z.object({ SELECT queries without a LIMIT clause will automatically have LIMIT 50 added for performance. Existing LIMIT clauses will be preserved. YOU MUST USE THE . syntax/qualifier for all table names. - NEVER use SELECT * - you must explicitly list the columns you want to query from the documentation provided. + NEVER use SELECT * on physical tables - for security purposes you must explicitly select the columns you intend to use. NOT ADHERING TO THESE INSTRUCTIONS WILL RETURN AN ERROR NEVER query system tables or use 'SHOW' statements as these will fail to execute. Queries without these requirements will fail to execute.` ), diff --git a/packages/ai/src/utils/sql-permissions/permission-validator.test.ts b/packages/ai/src/utils/sql-permissions/permission-validator.test.ts index fc92cb355..1dce6de07 100644 --- a/packages/ai/src/utils/sql-permissions/permission-validator.test.ts +++ b/packages/ai/src/utils/sql-permissions/permission-validator.test.ts @@ -32,7 +32,7 @@ describe('Permission Validator', () => { }, ] as any); - const result = await validateSqlPermissions('SELECT * FROM public.users', 'user123'); + const result = await validateSqlPermissions('SELECT id, name FROM public.users', 'user123'); expect(result).toEqual({ isAuthorized: true, @@ -51,7 +51,7 @@ describe('Permission Validator', () => { }, ] as any); - const result = await validateSqlPermissions('SELECT * FROM public.orders', 'user123'); + const result = await validateSqlPermissions('SELECT id, user_id FROM public.orders', 'user123'); expect(result).toEqual({ isAuthorized: false, @@ -73,7 +73,7 @@ describe('Permission Validator', () => { ] as any); const result = await validateSqlPermissions( - 'SELECT * FROM public.users u JOIN public.orders o ON u.id = o.user_id', + 'SELECT u.id, u.name, o.id, o.total FROM public.users u JOIN public.orders o ON u.id = o.user_id', 'user123' ); @@ -95,7 +95,7 @@ describe('Permission Validator', () => { ] as any); const result = await validateSqlPermissions( - 'SELECT * FROM public.users u JOIN sales.orders o ON u.id = o.user_id', + 'SELECT u.id, u.name, o.id, o.total FROM public.users u JOIN sales.orders o ON u.id = o.user_id', 'user123' ); @@ -124,7 +124,7 @@ describe('Permission Validator', () => { FROM ont_ont.product_total_revenue AS ptr GROUP BY ptr.product_name ) - SELECT pqs.*, t.total_revenue + SELECT pqs.product_name, pqs.quarter, t.total_revenue FROM ont_ont.product_quarterly_sales AS pqs JOIN top5 t ON pqs.product_name = t.product_name `; @@ -151,7 +151,7 @@ describe('Permission Validator', () => { ] as any); const sql = ` - SELECT * FROM public.users u + SELECT u.id, u.name FROM public.users u WHERE u.id IN ( SELECT user_id FROM public.orders WHERE total > 100 ) @@ -178,7 +178,7 @@ describe('Permission Validator', () => { // Query has full qualification, permission has partial // Note: Parser may not support database.schema.table in FROM clause - const result = await validateSqlPermissions('SELECT * FROM public.users', 'user123'); + const result = await validateSqlPermissions('SELECT id, name FROM public.users', 'user123'); expect(result).toEqual({ isAuthorized: true, @@ -198,7 +198,7 @@ describe('Permission Validator', () => { ] as any); // Query missing schema that permission requires - const result = await validateSqlPermissions('SELECT * FROM users', 'user123'); + const result = await validateSqlPermissions('SELECT id, name FROM users', 'user123'); expect(result.isAuthorized).toBe(false); expect(result.unauthorizedTables).toContain('users'); @@ -209,7 +209,7 @@ describe('Permission Validator', () => { new Error('Database connection failed') ); - const result = await validateSqlPermissions('SELECT * FROM users', 'user123'); + const result = await validateSqlPermissions('SELECT id, name FROM users', 'user123'); expect(result).toEqual({ isAuthorized: false, @@ -323,7 +323,7 @@ describe('Permission Validator', () => { ] as any); const result = await validateSqlPermissions( - 'SELECT * FROM public.users u JOIN public.orders o ON u.id = o.user_id', + 'SELECT u.id, u.name, o.id, o.total FROM public.users u JOIN public.orders o ON u.id = o.user_id', 'user123' ); diff --git a/packages/ai/src/utils/sql-permissions/permission-validator.ts b/packages/ai/src/utils/sql-permissions/permission-validator.ts index cf559916b..b2e5429ea 100644 --- a/packages/ai/src/utils/sql-permissions/permission-validator.ts +++ b/packages/ai/src/utils/sql-permissions/permission-validator.ts @@ -5,6 +5,7 @@ import { extractPhysicalTables, extractTablesFromYml, tablesMatch, + validateWildcardUsage, } from './sql-parser-helpers'; export interface PermissionValidationResult { @@ -33,6 +34,15 @@ export async function validateSqlPermissions( }; } + const wildcardCheck = validateWildcardUsage(sql, dataSourceSyntax); + if (!wildcardCheck.isValid) { + return { + isAuthorized: false, + unauthorizedTables: wildcardCheck.blockedTables || [], + error: wildcardCheck.error || 'Wildcard usage on physical tables is not allowed', + }; + } + // Extract physical tables from SQL const tablesInQuery = extractPhysicalTables(sql, dataSourceSyntax); diff --git a/packages/ai/src/utils/sql-permissions/sql-parser-helpers.test.ts b/packages/ai/src/utils/sql-permissions/sql-parser-helpers.test.ts index 05f743022..22371bdc8 100644 --- a/packages/ai/src/utils/sql-permissions/sql-parser-helpers.test.ts +++ b/packages/ai/src/utils/sql-permissions/sql-parser-helpers.test.ts @@ -7,6 +7,7 @@ import { normalizeTableIdentifier, parseTableReference, tablesMatch, + validateWildcardUsage, } from './sql-parser-helpers'; describe('SQL Parser Helpers', () => { @@ -420,6 +421,110 @@ models: }); }); + describe('validateWildcardUsage', () => { + it('should block unqualified wildcard on physical table', () => { + const sql = 'SELECT * FROM users'; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.error).toContain('Wildcard usage on physical tables is not allowed'); + expect(result.blockedTables).toContain('users'); + }); + + it('should block qualified wildcard on physical table', () => { + const sql = 'SELECT u.* FROM users u'; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.error).toContain('Wildcard usage on physical tables is not allowed'); + expect(result.blockedTables).toContain('u'); + }); + + it('should allow wildcard on CTE', () => { + const sql = ` + WITH user_stats AS ( + SELECT user_id, COUNT(*) as count FROM orders GROUP BY user_id + ) + SELECT * FROM user_stats + `; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(true); + expect(result.error).toBeUndefined(); + }); + + it('should allow qualified wildcard on CTE', () => { + const sql = ` + WITH user_stats AS ( + SELECT user_id, COUNT(*) as count FROM orders GROUP BY user_id + ) + SELECT us.* FROM user_stats us + `; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(true); + }); + + it('should block wildcard when CTE uses wildcard on physical table', () => { + const sql = ` + WITH user_cte AS ( + SELECT * FROM users + ) + SELECT * FROM user_cte + `; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.error).toContain('Wildcard usage on physical tables is not allowed'); + expect(result.blockedTables).toContain('users'); + }); + + it('should allow wildcard when CTE uses explicit columns', () => { + const sql = ` + WITH user_cte AS ( + SELECT id, name FROM users + ) + SELECT * FROM user_cte + `; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(true); + }); + + it('should block wildcard on physical tables in JOIN', () => { + const sql = ` + WITH orders_cte AS ( + SELECT order_id FROM orders + ) + SELECT oc.*, u.* FROM orders_cte oc JOIN users u ON oc.order_id = u.id + `; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.blockedTables).toContain('u'); + }); + + it('should allow explicit column selection', () => { + const sql = 'SELECT id, name, email FROM users'; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(true); + }); + + it('should handle multiple physical tables with wildcards', () => { + const sql = 'SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id'; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.blockedTables).toEqual(expect.arrayContaining(['u', 'o'])); + }); + + it('should handle schema-qualified tables', () => { + const sql = 'SELECT * FROM public.users'; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.error).toContain('Wildcard usage on physical tables is not allowed'); + }); + + it('should handle invalid SQL gracefully', () => { + const sql = 'NOT VALID SQL'; + const result = validateWildcardUsage(sql); + expect(result.isValid).toBe(false); + expect(result.error).toContain('Failed to validate wildcard usage'); + }); + }); + describe('checkQueryIsReadOnly', () => { it('should allow SELECT statements', () => { const result = checkQueryIsReadOnly('SELECT * FROM users'); diff --git a/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts b/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts index 22c04062e..ef6d4e2dc 100644 --- a/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts +++ b/packages/ai/src/utils/sql-permissions/sql-parser-helpers.ts @@ -1,4 +1,4 @@ -import { Parser } from 'node-sql-parser'; +import { BaseFrom, ColumnRefItem, Join, Parser, type Select } from 'node-sql-parser'; import * as yaml from 'yaml'; export interface ParsedTable { @@ -15,6 +15,12 @@ export interface QueryTypeCheckResult { error?: string; } +export interface WildcardValidationResult { + isValid: boolean; + error?: string; + blockedTables?: string[]; +} + // Map data source syntax to node-sql-parser dialect const DIALECT_MAPPING: Record = { // Direct mappings @@ -338,6 +344,269 @@ export function extractTablesFromYml(ymlContent: string): ParsedTable[] { return tables; } +/** + * Validates that wildcards (SELECT *) are not used on physical tables + * Allows wildcards on CTEs but blocks them on physical database tables + */ +export function validateWildcardUsage( + sql: string, + dataSourceSyntax?: string +): WildcardValidationResult { + const dialect = getParserDialect(dataSourceSyntax); + const parser = new Parser(); + + try { + // Parse SQL into AST with the appropriate dialect + const ast = parser.astify(sql, { database: dialect }); + + // Handle single statement or array of statements + const statements = Array.isArray(ast) ? ast : [ast]; + + // Extract CTE names to allow wildcards on them + const cteNames = new Set(); + for (const statement of statements) { + if ('with' in statement && statement.with && Array.isArray(statement.with)) { + for (const cte of statement.with) { + if (cte.name?.value) { + cteNames.add(cte.name.value.toLowerCase()); + } + } + } + } + + const tableList = parser.tableList(sql, { database: dialect }); + const tableAliasMap = new Map(); // alias -> table name + + if (Array.isArray(tableList)) { + for (const tableRef of tableList) { + if (typeof tableRef === 'string') { + // Simple table name + tableAliasMap.set(tableRef.toLowerCase(), tableRef); + } else if (tableRef && typeof tableRef === 'object') { + const tableRefObj = tableRef as Record; + const tableName = tableRefObj.table || tableRefObj.name; + const alias = tableRefObj.as || tableRefObj.alias; + if (tableName && typeof tableName === 'string') { + if (alias && typeof alias === 'string') { + tableAliasMap.set(alias.toLowerCase(), tableName); + } + tableAliasMap.set(tableName.toLowerCase(), tableName); + } + } + } + } + + // Check each statement for wildcard usage + const blockedTables: string[] = []; + + for (const statement of statements) { + if ('type' in statement && statement.type === 'select') { + const wildcardTables = findWildcardUsageOnPhysicalTables( + statement as unknown as Record, + cteNames + ); + blockedTables.push(...wildcardTables); + } + } + + if (blockedTables.length > 0) { + const tableList = blockedTables.join(', '); + return { + isValid: false, + error: `Wildcard usage on physical tables is not allowed: ${tableList}. Please specify explicit column names.`, + blockedTables, + }; + } + + return { isValid: true }; + } catch (error) { + return { + isValid: false, + error: `Failed to validate wildcard usage: ${error instanceof Error ? error.message : 'Unknown error'}`, + }; + } +} + +/** + * Recursively finds wildcard usage on physical tables in a SELECT statement + */ +function findWildcardUsageOnPhysicalTables( + selectStatement: Record, + cteNames: Set +): string[] { + const blockedTables: string[] = []; + + // Build alias mapping for this statement + const aliasToTableMap = new Map(); + if (selectStatement.from && Array.isArray(selectStatement.from)) { + for (const fromItem of selectStatement.from) { + const fromItemAny = fromItem as unknown as Record; + if (fromItemAny.table && fromItemAny.as) { + let tableName: string; + if (typeof fromItemAny.table === 'string') { + tableName = fromItemAny.table; + } else if (fromItemAny.table && typeof fromItemAny.table === 'object') { + const tableObj = fromItemAny.table as Record; + tableName = String( + tableObj.table || tableObj.name || tableObj.value || fromItemAny.table + ); + } else { + continue; + } + aliasToTableMap.set(String(fromItemAny.as).toLowerCase(), tableName.toLowerCase()); + } + + // Handle JOINs + if (fromItemAny.join && Array.isArray(fromItemAny.join)) { + for (const joinItem of fromItemAny.join) { + if (joinItem.table && joinItem.as) { + let tableName: string; + if (typeof joinItem.table === 'string') { + tableName = joinItem.table; + } else if (joinItem.table && typeof joinItem.table === 'object') { + const tableObj = joinItem.table as Record; + tableName = String( + tableObj.table || tableObj.name || tableObj.value || joinItem.table + ); + } else { + continue; + } + aliasToTableMap.set(String(joinItem.as).toLowerCase(), tableName.toLowerCase()); + } + } + } + } + } + + if (selectStatement.columns && Array.isArray(selectStatement.columns)) { + for (const column of selectStatement.columns) { + if (column.expr && column.expr.type === 'column_ref') { + // Check for unqualified wildcard (SELECT *) + if (column.expr.column === '*' && !column.expr.table) { + // Get all tables in FROM clause that are not CTEs + const physicalTables = getPhysicalTablesFromFrom( + selectStatement.from as unknown as Record[], + cteNames + ); + blockedTables.push(...physicalTables); + } + // Check for qualified wildcard (SELECT table.*) + else if (column.expr.column === '*' && column.expr.table) { + // Handle table reference - could be string or object + let tableName: string; + if (typeof column.expr.table === 'string') { + tableName = column.expr.table; + } else if (column.expr.table && typeof column.expr.table === 'object') { + // Handle object format - could have table property or be the table name itself + const tableRefObj = column.expr.table as Record; + tableName = String( + tableRefObj.table || tableRefObj.name || tableRefObj.value || column.expr.table + ); + } else { + continue; // Skip if we can't determine table name + } + + // Check if this is an alias that maps to a CTE + const actualTableName = aliasToTableMap.get(tableName.toLowerCase()); + const isAliasToCte = actualTableName && cteNames.has(actualTableName); + const isDirectCte = cteNames.has(tableName.toLowerCase()); + + if (!isAliasToCte && !isDirectCte) { + blockedTables.push(tableName); + } + } + } + } + } + + // Check CTEs for nested wildcard usage + if (selectStatement.with && Array.isArray(selectStatement.with)) { + for (const cte of selectStatement.with) { + const cteAny = cte as unknown as Record; + if (cteAny.stmt && typeof cteAny.stmt === 'object' && cteAny.stmt !== null) { + const stmt = cteAny.stmt as Record; + if (stmt.type === 'select') { + const subBlocked = findWildcardUsageOnPhysicalTables(stmt, cteNames); + blockedTables.push(...subBlocked); + } + } + } + } + + if (selectStatement.from && Array.isArray(selectStatement.from)) { + for (const fromItem of selectStatement.from) { + const fromItemAny = fromItem as unknown as Record; + if (fromItemAny.expr && typeof fromItemAny.expr === 'object' && fromItemAny.expr !== null) { + const expr = fromItemAny.expr as Record; + if (expr.type === 'select') { + const subBlocked = findWildcardUsageOnPhysicalTables(expr, cteNames); + blockedTables.push(...subBlocked); + } + } + } + } + + return blockedTables; +} + +/** + * Extracts physical table names from FROM clause, excluding CTEs + */ +function getPhysicalTablesFromFrom( + fromClause: Record[], + cteNames: Set +): string[] { + const tables: string[] = []; + + if (!fromClause || !Array.isArray(fromClause)) { + return tables; + } + + for (const fromItem of fromClause) { + // Extract table name from fromItem + if (fromItem.table) { + let tableName: string; + if (typeof fromItem.table === 'string') { + tableName = fromItem.table; + } else if (fromItem.table && typeof fromItem.table === 'object') { + const tableObj = fromItem.table as Record; + tableName = String(tableObj.table || tableObj.name || tableObj.value || fromItem.table); + } else { + continue; + } + + if (tableName && !cteNames.has(tableName.toLowerCase())) { + const aliasName = fromItem.as || tableName; + tables.push(String(aliasName)); + } + } + + // Handle JOINs + if (fromItem.join && Array.isArray(fromItem.join)) { + for (const joinItem of fromItem.join) { + if (joinItem.table) { + let tableName: string; + if (typeof joinItem.table === 'string') { + tableName = joinItem.table; + } else if (joinItem.table && typeof joinItem.table === 'object') { + const tableObj = joinItem.table as Record; + tableName = String(tableObj.table || tableObj.name || tableObj.value || joinItem.table); + } else { + continue; + } + + if (tableName && !cteNames.has(tableName.toLowerCase())) { + const aliasName = joinItem.as || tableName; + tables.push(String(aliasName)); + } + } + } + } + } + + return tables; +} + /** * Checks if a SQL query is read-only (SELECT statements only) * Returns error if query contains write operations