diff --git a/apps/server/src/api/v2/index.ts b/apps/server/src/api/v2/index.ts index 448120f50..e9d7932b1 100644 --- a/apps/server/src/api/v2/index.ts +++ b/apps/server/src/api/v2/index.ts @@ -17,6 +17,7 @@ import s3IntegrationsRoutes from './s3-integrations'; import securityRoutes from './security'; import shortcutsRoutes from './shortcuts'; import slackRoutes from './slack'; +import sqlRoutes from './sql'; import supportRoutes from './support'; import titleRoutes from './title'; import userRoutes from './users'; @@ -33,6 +34,7 @@ const app = new Hono() .route('/metric_files', metricFilesRoutes) .route('/github', githubRoutes) .route('/slack', slackRoutes) + .route('/sql', sqlRoutes) .route('/support', supportRoutes) .route('/security', securityRoutes) .route('/shortcuts', shortcutsRoutes) diff --git a/apps/server/src/api/v2/sql/index.ts b/apps/server/src/api/v2/sql/index.ts new file mode 100644 index 000000000..33965d61b --- /dev/null +++ b/apps/server/src/api/v2/sql/index.ts @@ -0,0 +1,8 @@ +import { Hono } from 'hono'; +import run from './run'; + +const app = new Hono() + // Mount the /run subrouter + .route('/run', run); + +export default app; diff --git a/apps/server/src/api/v2/sql/run/POST.ts b/apps/server/src/api/v2/sql/run/POST.ts new file mode 100644 index 000000000..3f7585a44 --- /dev/null +++ b/apps/server/src/api/v2/sql/run/POST.ts @@ -0,0 +1,123 @@ +import { createPermissionErrorMessage, validateSqlPermissions } from '@buster/access-controls'; +import { executeMetricQuery } from '@buster/data-source'; +import type { Credentials } from '@buster/data-source'; +import type { User } from '@buster/database'; +import { + getDataSourceById, + getDataSourceCredentials, + getUserOrganizationId, +} from '@buster/database'; +import type { RunSqlRequest, RunSqlResponse } from '@buster/server-shared'; +import { HTTPException } from 'hono/http-exception'; + +/** + * Handler for running SQL queries against data sources + * + * This handler: + * 1. Validates user has access to the organization + * 2. Verifies data source belongs to user's organization + * 3. Validates SQL permissions against user's permissioned datasets + * 4. Executes the query with retry logic and timeout handling + * 5. 
Returns the data with metadata and pagination info
+ *
+ * @param request - The SQL query request containing data_source_id and sql
+ * @param user - The authenticated user
+ * @returns The query results with metadata
+ */
+export async function runSqlHandler(request: RunSqlRequest, user: User): Promise<RunSqlResponse> {
+  // Get user's organization
+  const userOrg = await getUserOrganizationId(user.id);
+
+  if (!userOrg) {
+    throw new HTTPException(403, {
+      message: 'You must be part of an organization to run SQL queries',
+    });
+  }
+
+  const { organizationId } = userOrg;
+
+  // Get data source details
+  const dataSource = await getDataSourceById(request.data_source_id);
+
+  if (!dataSource) {
+    throw new HTTPException(404, {
+      message: 'Data source not found',
+    });
+  }
+
+  // Verify data source belongs to user's organization
+  if (dataSource.organizationId !== organizationId) {
+    throw new HTTPException(403, {
+      message: 'You do not have permission to access this data source',
+    });
+  }
+
+  // Validate SQL against user's permissioned datasets
+  const permissionResult = await validateSqlPermissions(request.sql, user.id, dataSource.type);
+
+  if (!permissionResult.isAuthorized) {
+    const errorMessage =
+      permissionResult.error ||
+      createPermissionErrorMessage(
+        permissionResult.unauthorizedTables,
+        permissionResult.unauthorizedColumns
+      );
+
+    throw new HTTPException(403, {
+      message: errorMessage,
+    });
+  }
+
+  // Get data source credentials from vault
+  let credentials: Credentials;
+  try {
+    const rawCredentials = await getDataSourceCredentials({
+      dataSourceId: request.data_source_id,
+    });
+
+    // Ensure credentials have the correct type
+    credentials = {
+      ...rawCredentials,
+      type: rawCredentials.type || dataSource.type,
+    } as Credentials;
+  } catch (error) {
+    console.error('Failed to retrieve data source credentials:', error);
+    throw new HTTPException(500, {
+      message: 'Failed to access data source',
+    });
+  }
+
+  // Execute query using the shared utility with 5000 row limit
+  try {
+    // Request one extra row to detect if there are more records
+    const result = await executeMetricQuery(request.data_source_id, request.sql, credentials, {
+      maxRows: 5001,
+      timeout: 60000, // 60 seconds
+      retryDelays: [1000, 3000, 6000], // 1s, 3s, 6s
+    });
+
+    // Trim to 5000 rows and check if there are more records
+    const hasMore = result.data.length > 5000;
+    const trimmedData = result.data.slice(0, 5000);
+
+    const response: RunSqlResponse = {
+      data: trimmedData,
+      data_metadata: result.dataMetadata,
+      has_more_records: hasMore || result.hasMoreRecords,
+    };
+
+    return response;
+  } catch (error) {
+    console.error('Query execution failed:', error);
+
+    if (error instanceof Error) {
+      throw new HTTPException(500, {
+        message: `Query execution failed: ${error.message}`,
+      });
+    }
+
+    throw new HTTPException(500, {
+      message: 'Query execution failed',
+    });
+  }
+}
diff --git a/apps/server/src/api/v2/sql/run/index.ts b/apps/server/src/api/v2/sql/run/index.ts
new file mode 100644
index 000000000..ecc847099
--- /dev/null
+++ b/apps/server/src/api/v2/sql/run/index.ts
@@ -0,0 +1,36 @@
+import { RunSqlRequestSchema } from '@buster/server-shared';
+import { zValidator } from '@hono/zod-validator';
+import { Hono } from 'hono';
+import { HTTPException } from 'hono/http-exception';
+import { requireAuth } from '../../../../middleware/auth';
+import '../../../../types/hono.types';
+import { runSqlHandler } from './POST';
+
+const app = new Hono()
+  // Apply authentication middleware to all routes
+  .use('*', requireAuth)
+
+  // POST /sql/run - Execute SQL query against data source
+  .post('/', zValidator('json', RunSqlRequestSchema), async (c) => {
+    const request = c.req.valid('json');
+    const user = c.get('busterUser');
+
+    const response = await runSqlHandler(request, user);
+
+    return c.json(response);
+  })
+
+  // Error handler for SQL run routes
+  .onError((err, c) => {
+    console.error('SQL run API error:', err);
+
+    // Let HTTPException responses pass through
+    if (err instanceof HTTPException) {
+      return err.getResponse();
+    }
+
+    // Default error response
+    return c.json({ error: 'Internal server error' }, 500);
+  });
+
+export default app;
diff --git a/apps/web/src/api/buster_rest/sql/requests.ts b/apps/web/src/api/buster_rest/sql/requests.ts
index 0fc48d742..441d3dac7 100644
--- a/apps/web/src/api/buster_rest/sql/requests.ts
+++ b/apps/web/src/api/buster_rest/sql/requests.ts
@@ -1,6 +1,6 @@
 import type { RunSQLResponse } from '../../asset_interfaces/sql';
-import { mainApi } from '../instances';
+import { mainApiV2 } from '../instances';
 
 export const runSQL = async (params: { data_source_id: string; sql: string }) => {
-  return mainApi.post('/sql/run', params).then((res) => res.data);
+  return mainApiV2.post('/sql/run', params).then((res) => res.data);
 };
diff --git a/packages/access-controls/package.json b/packages/access-controls/package.json
index fe501a97b..2cf1bc9ad 100644
--- a/packages/access-controls/package.json
+++ b/packages/access-controls/package.json
@@ -30,10 +30,13 @@
   },
   "dependencies": {
     "@buster/database": "workspace:*",
+    "@buster/data-source": "workspace:*",
     "@buster/env-utils": "workspace:*",
     "@buster/typescript-config": "workspace:*",
     "@buster/vitest-config": "workspace:*",
     "lru-cache": "^11.1.0",
+    "node-sql-parser": "^5.3.12",
+    "yaml": "^2.8.1",
     "zod": "catalog:",
     "uuid": "catalog:",
     "drizzle-orm": "catalog:"
diff --git a/packages/access-controls/src/index.ts b/packages/access-controls/src/index.ts
index 221476e1c..34dcac339 100644
--- a/packages/access-controls/src/index.ts
+++ b/packages/access-controls/src/index.ts
@@ -28,6 +28,9 @@ export * from './datasets';
 // Export user utilities
 export * from './users';
 
+// Export SQL permissions
+export * from './sql-permissions';
+
 // Export cache functions separately
 export {
   clearAllCaches,
diff --git a/packages/access-controls/src/sql-permissions/execute-with-permission-check.ts b/packages/access-controls/src/sql-permissions/execute-with-permission-check.ts
new file mode 100644
index 000000000..0864ad15f
--- /dev/null
+++ b/packages/access-controls/src/sql-permissions/execute-with-permission-check.ts
@@ -0,0 +1,52 @@
+import { createPermissionErrorMessage, validateSqlPermissions } from './validator';
+
+export interface ExecuteWithPermissionResult<T> {
+  success: boolean;
+  data?: T;
+  error?: string;
+}
+
+/**
+ * Wraps SQL execution with permission validation
+ * Ensures user has access to all tables referenced in the query
+ */
+export async function executeWithPermissionCheck<T>(
+  sql: string,
+  userId: string,
+  executeFn: () => Promise<T>,
+  dataSourceSyntax?: string
+): Promise<ExecuteWithPermissionResult<T>> {
+  if (!userId) {
+    return {
+      success: false,
+      error: 'User authentication required for SQL execution',
+    };
+  }
+
+  // Validate permissions
+  const permissionResult = await validateSqlPermissions(sql, userId, dataSourceSyntax);
+
+  if (!permissionResult.isAuthorized) {
+    return {
+      success: false,
+      error: createPermissionErrorMessage(
+        permissionResult.unauthorizedTables,
+        permissionResult.unauthorizedColumns
+      ),
+    };
+  }
+
+  // Execute if authorized
+  try {
+    const result = await executeFn();
+    return {
+      success: true,
+      data: result,
+    };
+  } catch (error) {
+    return {
+      success: false,
+      error: error instanceof Error ? error.message : 'SQL execution failed',
+    };
+  }
+}
diff --git a/packages/access-controls/src/sql-permissions/index.ts b/packages/access-controls/src/sql-permissions/index.ts
new file mode 100644
index 000000000..ff92b2e46
--- /dev/null
+++ b/packages/access-controls/src/sql-permissions/index.ts
@@ -0,0 +1,3 @@
+export * from './parser-helpers';
+export * from './validator';
+export * from './execute-with-permission-check';
diff --git a/packages/access-controls/src/sql-permissions/parser-helpers.ts b/packages/access-controls/src/sql-permissions/parser-helpers.ts
new file mode 100644
index 000000000..6168fd01a
--- /dev/null
+++ b/packages/access-controls/src/sql-permissions/parser-helpers.ts
@@ -0,0 +1,1637 @@
+import pkg from 'node-sql-parser';
+const { Parser } = pkg;
+import type { BaseFrom, ColumnRefItem, Join, Select } from 'node-sql-parser';
+import * as yaml from 'yaml';
+// Re-export checkQueryIsReadOnly from the data-source package
+export { checkQueryIsReadOnly } from '@buster/data-source';
+export type { QueryTypeCheckResult } from '@buster/data-source';
+
+export interface ParsedTable {
+  database?: string;
+  schema?: string;
+  table: string;
+  fullName: string;
+  alias?: string;
+}
+
+export interface ParsedDataset {
+  database?: string;
+  schema?: string;
+  table: string;
+  fullName: string;
+  allowedColumns: Set<string>; // lowercase column names from dimensions and measures
+}
+
+// Type for statements that may have UNION (_next property)
+interface StatementWithNext extends Record<string, unknown> {
+  _next?: StatementWithNext;
+  type?: string;
+}
+
+export interface WildcardValidationResult {
+  isValid: boolean;
+  error?: string;
+  blockedTables?: string[];
+}
+
+// Map data source syntax to node-sql-parser dialect
+const DIALECT_MAPPING: Record<string, string> = {
+  // Direct mappings
+  mysql: 'mysql',
+  postgresql: 'postgresql',
+  sqlite: 'sqlite',
+  mariadb: 'mariadb',
+  bigquery: 'bigquery',
+  snowflake: 'snowflake',
+  redshift: 'postgresql', // Redshift uses PostgreSQL dialect
+  transactsql: 'transactsql',
+  flinksql: 'flinksql',
+  hive: 'hive',
+
+  // Alternative names
+  postgres: 'postgresql',
+  mssql: 'transactsql',
+  sqlserver: 'transactsql',
+  athena: 'postgresql', // Athena uses Presto/PostgreSQL syntax
+  db2: 'db2',
+  noql: 'mysql', // Default fallback for NoQL
+};
+
+function getParserDialect(dataSourceSyntax?: string): string {
+  if (!dataSourceSyntax) {
+    return 'postgresql';
+  }
+
+  const dialect = DIALECT_MAPPING[dataSourceSyntax.toLowerCase()];
+  if (!dialect) {
+    return 'postgresql';
+  }
+
+  return dialect;
+}
+
+/**
+ * Extracts physical tables from SQL query, excluding CTEs
+ * Returns database.schema.table references with proper qualification
+ */
+export function extractPhysicalTables(sql: string, dataSourceSyntax?: string): ParsedTable[] {
+  const dialect = getParserDialect(dataSourceSyntax);
+  const parser = new Parser();
+
+  try {
+    // Parse SQL into AST with the appropriate dialect
+    const ast = parser.astify(sql, { database: dialect });
+
+    // Get all table references from parser with the appropriate dialect
+    const allTables = parser.tableList(sql, { database: dialect });
+
+    // Extract CTE names to exclude them
+    const cteNames = new Set<string>();
+
+    // Handle single statement or array of statements
+    const statements = Array.isArray(ast) ? 
ast : [ast]; + + for (const statement of statements) { + // Type guard to check if statement has 'with' property + if ('with' in statement && statement.with && Array.isArray(statement.with)) { + for (const cte of statement.with) { + if (cte.name?.value) { + cteNames.add(cte.name.value.toLowerCase()); + } + } + } + } + + // Parse table references and filter out CTEs + const physicalTables: ParsedTable[] = []; + const processedTables = new Set(); + + for (const tableRef of allTables) { + const parsed = parseTableReference(tableRef); + + // Skip if it's a CTE + if (cteNames.has(parsed.table.toLowerCase())) { + continue; + } + + // Skip duplicates + const tableKey = `${parsed.database || ''}.${parsed.schema || ''}.${parsed.table}`; + if (processedTables.has(tableKey)) { + continue; + } + + processedTables.add(tableKey); + physicalTables.push(parsed); + } + + return physicalTables; + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error); + // Provide more specific guidance based on common parsing errors + if (errorMessage.includes('Expected')) { + throw new Error( + `SQL syntax error: ${errorMessage}. Please check your SQL syntax and ensure it's valid for the ${dialect} dialect.` + ); + } + if (errorMessage.includes('Unexpected token')) { + throw new Error( + `SQL parsing error: ${errorMessage}. This may be due to unsupported SQL features or incorrect syntax.` + ); + } + throw new Error( + `Failed to parse SQL query: ${errorMessage}. Please ensure your SQL is valid and uses standard ${dialect} syntax.` + ); + } +} + +/** + * Parses a table reference string into its components + * Handles formats like: + * - table + * - schema.table + * - database.schema.table + * - type::database::table (node-sql-parser format) + * - type::schema.table (node-sql-parser format) + */ +export function parseTableReference(tableRef: string): ParsedTable { + // Remove any quotes and trim + let cleanRef = tableRef.replace(/["'`\[\]]/g, '').trim(); + + // Handle node-sql-parser format: "type::database::table" or "type::table" + if (cleanRef.includes('::')) { + const parts = cleanRef.split('::'); + // Remove the type prefix (select, insert, update, etc.) + const firstPart = parts[0]; + if ( + parts.length >= 2 && + firstPart && + ['select', 'insert', 'update', 'delete', 'create', 'drop', 'alter'].includes(firstPart) + ) { + parts.shift(); // Remove type + } + cleanRef = parts.join('.'); + } + + // Split by . 
for schema/table + const parts = cleanRef.split('.').filter((p) => p && p !== 'null'); + + if (parts.length === 3) { + const [database, schema, table] = parts; + if (!database || !schema || !table) { + return { + table: cleanRef, + fullName: cleanRef, + }; + } + return { + database, + schema, + table, + fullName: `${database}.${schema}.${table}`, + }; + } + + if (parts.length === 2) { + const [schema, table] = parts; + if (!schema || !table) { + return { + table: cleanRef, + fullName: cleanRef, + }; + } + return { + schema, + table, + fullName: `${schema}.${table}`, + }; + } + + if (parts.length === 1) { + const [table] = parts; + if (!table) { + return { + table: cleanRef, + fullName: cleanRef, + }; + } + return { + table, + fullName: table, + }; + } + + return { + table: cleanRef, + fullName: cleanRef, + }; +} + +/** + * Normalizes a table identifier for comparison + * Converts to lowercase and handles different qualification levels + */ +export function normalizeTableIdentifier(identifier: ParsedTable): string { + const parts = []; + + if (identifier.database) { + parts.push(identifier.database.toLowerCase()); + } + if (identifier.schema) { + parts.push(identifier.schema.toLowerCase()); + } + parts.push(identifier.table.toLowerCase()); + + return parts.join('.'); +} + +/** + * Checks if two table identifiers match, considering different qualification levels + * For example, "schema.table" matches "database.schema.table" if schema and table match + */ +export function tablesMatch(queryTable: ParsedTable, permissionTable: ParsedTable): boolean { + // Exact table name must match + if (queryTable.table.toLowerCase() !== permissionTable.table.toLowerCase()) { + return false; + } + + // If permission specifies schema, query must match + if (permissionTable.schema && queryTable.schema) { + if (permissionTable.schema.toLowerCase() !== queryTable.schema.toLowerCase()) { + return false; + } + } + + // If permission specifies database, query must match + if (permissionTable.database && queryTable.database) { + if (permissionTable.database.toLowerCase() !== queryTable.database.toLowerCase()) { + return false; + } + } + + // If permission has schema but query doesn't, it's not a match + // (we require explicit schema matching for security) + if (permissionTable.schema && !queryTable.schema) { + return false; + } + + return true; +} + +/** + * Extracts table references from dataset YML content + * Handles multiple formats: + * 1. Flat format (top-level fields): + * name: table_name + * schema: schema_name + * database: database_name + * 2. 
Models array with separate fields: + * models: + * - name: table_name + * schema: schema_name + * database: database_name + */ +export function extractTablesFromYml(ymlContent: string): ParsedTable[] { + const tables: ParsedTable[] = []; + const processedTables = new Set(); + + try { + // Parse YML content + const parsed = yaml.parse(ymlContent); + + // Check for flat format (top-level name, schema, database) + if (parsed?.name && !parsed?.models && (parsed?.schema || parsed?.database)) { + const parsedTable: ParsedTable = { + table: parsed.name, + fullName: parsed.name, + }; + + // Add schema if present + if (parsed.schema) { + parsedTable.schema = parsed.schema; + parsedTable.fullName = `${parsed.schema}.${parsed.name}`; + } + + // Add database if present + if (parsed.database) { + parsedTable.database = parsed.database; + if (parsed.schema) { + parsedTable.fullName = `${parsed.database}.${parsed.schema}.${parsed.name}`; + } else { + parsedTable.fullName = `${parsed.database}.${parsed.name}`; + } + } + + const key = normalizeTableIdentifier(parsedTable); + if (!processedTables.has(key)) { + processedTables.add(key); + tables.push(parsedTable); + } + } + + // Look for models array + if (parsed?.models && Array.isArray(parsed.models)) { + for (const model of parsed.models) { + // Process models that have name and at least schema or database + if (model.name && (model.schema || model.database)) { + const parsedTable: ParsedTable = { + table: model.name, + fullName: model.name, + }; + + // Add schema if present + if (model.schema) { + parsedTable.schema = model.schema; + parsedTable.fullName = `${model.schema}.${model.name}`; + } + + // Add database if present + if (model.database) { + parsedTable.database = model.database; + if (model.schema) { + parsedTable.fullName = `${model.database}.${model.schema}.${model.name}`; + } else { + parsedTable.fullName = `${model.database}.${model.name}`; + } + } + + const key = normalizeTableIdentifier(parsedTable); + if (!processedTables.has(key)) { + processedTables.add(key); + tables.push(parsedTable); + } + } + } + } + } catch (error) { + // Log the error for debugging but don't throw - return empty array + // This is expected behavior when YML content is invalid or not a dataset + const errorMessage = error instanceof Error ? 
error.message : String(error); + console.warn(`Failed to parse YML content for table extraction: ${errorMessage}`); + } + + return tables; +} + +/** + * Extracts datasets with allowed columns from YML content + * Handles dataset format with dimensions and measures + */ +export function extractDatasetsFromYml(ymlContent: string): ParsedDataset[] { + const datasets: ParsedDataset[] = []; + const processedDatasets = new Set(); + + try { + // Parse YML content + const parsed = yaml.parse(ymlContent); + + // Check for dataset format with dimensions and measures + if (parsed?.name) { + const parsedDataset: ParsedDataset = { + table: parsed.name, + fullName: parsed.name, + allowedColumns: new Set(), + }; + + // Add schema if present + if (parsed.schema) { + parsedDataset.schema = parsed.schema; + parsedDataset.fullName = `${parsed.schema}.${parsed.name}`; + } + + // Add database if present + if (parsed.database) { + parsedDataset.database = parsed.database; + if (parsed.schema) { + parsedDataset.fullName = `${parsed.database}.${parsed.schema}.${parsed.name}`; + } else { + parsedDataset.fullName = `${parsed.database}.${parsed.name}`; + } + } + + // Extract columns from dimensions + if (parsed.dimensions && Array.isArray(parsed.dimensions)) { + for (const dimension of parsed.dimensions) { + if (dimension.name && typeof dimension.name === 'string') { + parsedDataset.allowedColumns.add(dimension.name.toLowerCase()); + } + } + } + + // Extract columns from measures + if (parsed.measures && Array.isArray(parsed.measures)) { + for (const measure of parsed.measures) { + if (measure.name && typeof measure.name === 'string') { + parsedDataset.allowedColumns.add(measure.name.toLowerCase()); + } + } + } + + const key = normalizeTableIdentifier(parsedDataset); + if (!processedDatasets.has(key)) { + processedDatasets.add(key); + datasets.push(parsedDataset); + } + } + + // Also check for models array format with dimensions/measures + if (parsed?.models && Array.isArray(parsed.models)) { + for (const model of parsed.models) { + if (model.name) { + const parsedDataset: ParsedDataset = { + table: model.name, + fullName: model.name, + allowedColumns: new Set(), + }; + + // Add schema if present + if (model.schema) { + parsedDataset.schema = model.schema; + parsedDataset.fullName = `${model.schema}.${model.name}`; + } + + // Add database if present + if (model.database) { + parsedDataset.database = model.database; + if (model.schema) { + parsedDataset.fullName = `${model.database}.${model.schema}.${model.name}`; + } else { + parsedDataset.fullName = `${model.database}.${model.name}`; + } + } + + // Extract columns from dimensions + if (model.dimensions && Array.isArray(model.dimensions)) { + for (const dimension of model.dimensions) { + if (dimension.name && typeof dimension.name === 'string') { + parsedDataset.allowedColumns.add(dimension.name.toLowerCase()); + } + } + } + + // Extract columns from measures + if (model.measures && Array.isArray(model.measures)) { + for (const measure of model.measures) { + if (measure.name && typeof measure.name === 'string') { + parsedDataset.allowedColumns.add(measure.name.toLowerCase()); + } + } + } + + const key = normalizeTableIdentifier(parsedDataset); + if (!processedDatasets.has(key)) { + processedDatasets.add(key); + datasets.push(parsedDataset); + } + } + } + } + } catch (error) { + // Log the error for debugging but don't throw - return empty array + // This is expected behavior when YML content is invalid or not a dataset + const errorMessage = error instanceof Error ? 
error.message : String(error); + console.warn(`Failed to parse YML content for dataset extraction: ${errorMessage}`); + } + + return datasets; +} + +/** + * Validates that wildcards (SELECT *) are not used on physical tables + * Allows wildcards on CTEs but blocks them on physical database tables + */ +export function validateWildcardUsage( + sql: string, + dataSourceSyntax?: string +): WildcardValidationResult { + const dialect = getParserDialect(dataSourceSyntax); + const parser = new Parser(); + + try { + // Parse SQL into AST with the appropriate dialect + const ast = parser.astify(sql, { database: dialect }); + + // Handle single statement or array of statements + const statements = Array.isArray(ast) ? ast : [ast]; + + // Extract CTE names to allow wildcards on them + const cteNames = new Set(); + for (const statement of statements) { + if ('with' in statement && statement.with && Array.isArray(statement.with)) { + for (const cte of statement.with) { + if (cte.name?.value) { + cteNames.add(cte.name.value.toLowerCase()); + } + } + } + } + + const tableList = parser.tableList(sql, { database: dialect }); + const tableAliasMap = new Map(); // alias -> table name + + if (Array.isArray(tableList)) { + for (const tableRef of tableList) { + if (typeof tableRef === 'string') { + // Simple table name + tableAliasMap.set(tableRef.toLowerCase(), tableRef); + } else if (tableRef && typeof tableRef === 'object') { + const tableRefObj = tableRef as Record; + const tableName = tableRefObj.table || tableRefObj.name; + const alias = tableRefObj.as || tableRefObj.alias; + if (tableName && typeof tableName === 'string') { + if (alias && typeof alias === 'string') { + tableAliasMap.set(alias.toLowerCase(), tableName); + } + tableAliasMap.set(tableName.toLowerCase(), tableName); + } + } + } + } + + // Check each statement for wildcard usage + const blockedTables: string[] = []; + + for (const statement of statements) { + if ('type' in statement && statement.type === 'select') { + const wildcardTables = findWildcardUsageOnPhysicalTables( + statement as unknown as Record, + cteNames + ); + blockedTables.push(...wildcardTables); + } + } + + if (blockedTables.length > 0) { + // Create a more helpful error message with specific tables and guidance + const tableList = + blockedTables.length > 1 + ? `tables: ${blockedTables.join(', ')}` + : `table: ${blockedTables[0]}`; + + return { + isValid: false, + error: `SELECT * is not allowed on physical ${tableList}. Please explicitly specify the column names you need instead of using wildcards. For example, use 'SELECT column1, column2 FROM table' instead of 'SELECT * FROM table'. This restriction helps ensure data security and prevents unintended data exposure.`, + blockedTables, + }; + } + + return { isValid: true }; + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error); + return { + isValid: false, + error: `Failed to validate wildcard usage in SQL query: ${errorMessage}. 
Please ensure your SQL syntax is correct and try specifying explicit column names instead of using SELECT *.`, + }; + } +} + +/** + * Recursively finds wildcard usage on physical tables in a SELECT statement + */ +function findWildcardUsageOnPhysicalTables( + selectStatement: Record, + cteNames: Set +): string[] { + const blockedTables: string[] = []; + + // Build alias mapping for this statement + const aliasToTableMap = new Map(); + if (selectStatement.from && Array.isArray(selectStatement.from)) { + for (const fromItem of selectStatement.from) { + const fromItemAny = fromItem as unknown as Record; + if (fromItemAny.table && fromItemAny.as) { + let tableName: string; + if (typeof fromItemAny.table === 'string') { + tableName = fromItemAny.table; + } else if (fromItemAny.table && typeof fromItemAny.table === 'object') { + const tableObj = fromItemAny.table as Record; + tableName = String( + tableObj.table || tableObj.name || tableObj.value || fromItemAny.table + ); + } else { + continue; + } + aliasToTableMap.set(String(fromItemAny.as).toLowerCase(), tableName.toLowerCase()); + } + + // Handle JOINs + if (fromItemAny.join && Array.isArray(fromItemAny.join)) { + for (const joinItem of fromItemAny.join) { + if (joinItem.table && joinItem.as) { + let tableName: string; + if (typeof joinItem.table === 'string') { + tableName = joinItem.table; + } else if (joinItem.table && typeof joinItem.table === 'object') { + const tableObj = joinItem.table as Record; + tableName = String( + tableObj.table || tableObj.name || tableObj.value || joinItem.table + ); + } else { + continue; + } + aliasToTableMap.set(String(joinItem.as).toLowerCase(), tableName.toLowerCase()); + } + } + } + } + } + + if (selectStatement.columns && Array.isArray(selectStatement.columns)) { + for (const column of selectStatement.columns) { + if (column.expr && column.expr.type === 'column_ref') { + // Check for unqualified wildcard (SELECT *) + if (column.expr.column === '*' && !column.expr.table) { + // Get all tables in FROM clause that are not CTEs + const physicalTables = getPhysicalTablesFromFrom( + selectStatement.from as unknown as Record[], + cteNames + ); + blockedTables.push(...physicalTables); + } + // Check for qualified wildcard (SELECT table.*) + else if (column.expr.column === '*' && column.expr.table) { + // Handle table reference - could be string or object + let tableName: string; + if (typeof column.expr.table === 'string') { + tableName = column.expr.table; + } else if (column.expr.table && typeof column.expr.table === 'object') { + // Handle object format - could have table property or be the table name itself + const tableRefObj = column.expr.table as Record; + tableName = String( + tableRefObj.table || tableRefObj.name || tableRefObj.value || column.expr.table + ); + } else { + continue; // Skip if we can't determine table name + } + + // Check if this is an alias that maps to a CTE + const actualTableName = aliasToTableMap.get(tableName.toLowerCase()); + const isAliasToCte = actualTableName && cteNames.has(actualTableName); + const isDirectCte = cteNames.has(tableName.toLowerCase()); + + if (!isAliasToCte && !isDirectCte) { + // Push the actual table name if it's an alias, otherwise push the table name itself + blockedTables.push(actualTableName || tableName); + } + } + } + } + } + + // Check CTEs for nested wildcard usage + if (selectStatement.with && Array.isArray(selectStatement.with)) { + for (const cte of selectStatement.with) { + const cteAny = cte as unknown as Record; + if (cteAny.stmt && typeof 
cteAny.stmt === 'object' && cteAny.stmt !== null) { + const stmt = cteAny.stmt as Record; + if (stmt.type === 'select') { + const subBlocked = findWildcardUsageOnPhysicalTables(stmt, cteNames); + blockedTables.push(...subBlocked); + } + } + } + } + + if (selectStatement.from && Array.isArray(selectStatement.from)) { + for (const fromItem of selectStatement.from) { + const fromItemAny = fromItem as unknown as Record; + if (fromItemAny.expr && typeof fromItemAny.expr === 'object' && fromItemAny.expr !== null) { + const expr = fromItemAny.expr as Record; + if (expr.type === 'select') { + const subBlocked = findWildcardUsageOnPhysicalTables(expr, cteNames); + blockedTables.push(...subBlocked); + } + } + } + } + + return blockedTables; +} + +/** + * Extracts physical table names from FROM clause, excluding CTEs + */ +function getPhysicalTablesFromFrom( + fromClause: Record[], + cteNames: Set +): string[] { + const tables: string[] = []; + + if (!fromClause || !Array.isArray(fromClause)) { + return tables; + } + + for (const fromItem of fromClause) { + // Extract table name from fromItem + if (fromItem.table) { + let tableName: string; + if (typeof fromItem.table === 'string') { + tableName = fromItem.table; + } else if (fromItem.table && typeof fromItem.table === 'object') { + const tableObj = fromItem.table as Record; + tableName = String(tableObj.table || tableObj.name || tableObj.value || fromItem.table); + } else { + continue; + } + + if (tableName && !cteNames.has(tableName.toLowerCase())) { + const aliasName = fromItem.as || tableName; + tables.push(String(aliasName)); + } + } + + // Handle JOINs + if (fromItem.join && Array.isArray(fromItem.join)) { + for (const joinItem of fromItem.join) { + if (joinItem.table) { + let tableName: string; + if (typeof joinItem.table === 'string') { + tableName = joinItem.table; + } else if (joinItem.table && typeof joinItem.table === 'object') { + const tableObj = joinItem.table as Record; + tableName = String(tableObj.table || tableObj.name || tableObj.value || joinItem.table); + } else { + continue; + } + + if (tableName && !cteNames.has(tableName.toLowerCase())) { + const aliasName = joinItem.as || tableName; + tables.push(String(aliasName)); + } + } + } + } + } + + return tables; +} + +/** + * Extracts column references from SQL query grouped by table + * Returns a map of table -> set of column names referenced + * Excludes CTE internal columns + */ +export function extractColumnReferences( + sql: string, + dataSourceSyntax?: string +): Map> { + const dialect = getParserDialect(dataSourceSyntax); + const parser = new Parser(); + const tableColumnMap = new Map>(); + + try { + // Parse SQL into AST + const ast = parser.astify(sql, { database: dialect }); + const statements = Array.isArray(ast) ? 
ast : [ast]; + + // Get CTEs to exclude from column validation + const cteNames = new Set(); + for (const statement of statements) { + if ('with' in statement && statement.with && Array.isArray(statement.with)) { + for (const cte of statement.with) { + if (cte.name?.value) { + cteNames.add(cte.name.value.toLowerCase()); + } + } + } + } + + // Process each statement + for (const statement of statements) { + // First process CTEs in the main statement + if ('with' in statement && statement.with && Array.isArray(statement.with)) { + for (const cte of statement.with) { + if (cte.stmt && typeof cte.stmt === 'object') { + const cteStmt = cte.stmt as Record; + // Handle CTEs with UNION/UNION ALL - they have an ast property + // Check for ast first since UNION CTEs don't have a direct type property + if (!cteStmt.type && cteStmt.ast && typeof cteStmt.ast === 'object') { + const ast = cteStmt.ast as Record; + if (ast.type === 'select') { + // Process the first SELECT + extractColumnsFromStatement(ast, tableColumnMap, cteNames); + + // Process UNION parts (_next chain) + let nextStmt = ast._next as Record | undefined; + while (nextStmt) { + if (nextStmt.type === 'select') { + extractColumnsFromStatement(nextStmt, tableColumnMap, cteNames); + } + nextStmt = nextStmt._next as Record | undefined; + } + } + } else if (cteStmt.type === 'select') { + // CTE with type 'select' - may have UNION via _next + extractColumnsFromStatement(cteStmt, tableColumnMap, cteNames); + + // Handle UNION parts (_next chain) for CTEs + const cteWithNext = cteStmt as StatementWithNext; + let nextStmt = cteWithNext._next; + while (nextStmt) { + if (nextStmt.type === 'select') { + extractColumnsFromStatement(nextStmt, tableColumnMap, cteNames); + } + nextStmt = nextStmt._next; + } + } + } + } + } + + if ('type' in statement && statement.type === 'select') { + // Process the main SELECT statement + extractColumnsFromStatement( + statement as unknown as Record, + tableColumnMap, + cteNames + ); + + // Handle UNION queries - they have a _next property for the next SELECT + const statementWithNext = statement as unknown as StatementWithNext; + let nextStatement = statementWithNext._next; + while (nextStatement) { + if (nextStatement.type === 'select') { + extractColumnsFromStatement(nextStatement, tableColumnMap, cteNames); + } + nextStatement = nextStatement._next; + } + } + } + + return tableColumnMap; + } catch (error) { + // Log the error for debugging but return empty map to allow validation to continue + const errorMessage = error instanceof Error ? error.message : String(error); + console.warn( + `Failed to extract column references from SQL: ${errorMessage}. Column-level permissions cannot be validated.` + ); + return new Map(); + } +} + +/** + * Helper function to extract columns from a SELECT statement + */ +function extractColumnsFromStatement( + statement: Record, + tableColumnMap: Map>, + cteNames: Set, + parentAliasMap?: Map +): void { + // Build table alias mapping and track all non-CTE tables + // Include parent aliases for subqueries to resolve outer references + const aliasToTableMap = new Map(parentAliasMap || []); + const physicalTables: string[] = []; + + // Process FROM clause + if (statement.from && Array.isArray(statement.from)) { + for (const fromItem of statement.from) { + processFromItem(fromItem, aliasToTableMap, cteNames, physicalTables); + } + } + + // Track column aliases defined in SELECT clause + // These should NOT be treated as physical columns when referenced in ORDER BY, GROUP BY, etc. 
+ const columnAliases = new Set(); + + // Extract columns from SELECT clause + // Important: We only extract from the expression (column.expr), not from the alias (column.as) + // This ensures column aliases like "AS total_count" are not treated as physical columns + if (statement.columns && Array.isArray(statement.columns)) { + for (const column of statement.columns) { + // Track the alias if it exists + if (column.as) { + columnAliases.add(String(column.as).toLowerCase()); + } + + // Only process the expression part, which contains the actual column references + // The 'as' property contains the alias which should not be treated as a column + if (column.expr && typeof column.expr === 'object') { + extractColumnFromExpression( + column.expr, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + } + + // Extract columns from WHERE clause + if (statement.where) { + extractColumnFromExpression( + statement.where, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + + // Extract columns from GROUP BY clause + if (statement.groupby && Array.isArray(statement.groupby)) { + for (const groupItem of statement.groupby) { + // Skip if this is a reference to a column alias + if (groupItem.type === 'column_ref' && groupItem.column && !groupItem.table) { + const columnName = + typeof groupItem.column === 'string' ? groupItem.column : String(groupItem.column); + if (columnAliases.has(columnName.toLowerCase())) { + continue; // Skip column aliases + } + } + extractColumnFromExpression( + groupItem, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + + // Extract columns from HAVING clause + if (statement.having) { + extractColumnFromExpression( + statement.having, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + + // Extract columns from ORDER BY clause + if (statement.orderby && Array.isArray(statement.orderby)) { + for (const orderItem of statement.orderby) { + if (orderItem.expr) { + // Check if this is a reference to a column alias + if ( + orderItem.expr.type === 'column_ref' && + orderItem.expr.column && + !orderItem.expr.table + ) { + const columnName = + typeof orderItem.expr.column === 'string' + ? 
orderItem.expr.column + : String(orderItem.expr.column); + if (columnAliases.has(columnName.toLowerCase())) { + continue; // Skip column aliases + } + } + extractColumnFromExpression( + orderItem.expr, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + } + + // Process nested CTEs + if (statement.with && Array.isArray(statement.with)) { + for (const cte of statement.with) { + if (cte.stmt && typeof cte.stmt === 'object') { + const cteStmt = cte.stmt as Record; + + // Handle CTEs with UNION/UNION ALL - they have an ast property + // Check for ast first since UNION CTEs don't have a direct type property + if (!cteStmt.type && cteStmt.ast && typeof cteStmt.ast === 'object') { + const ast = cteStmt.ast as Record; + if (ast.type === 'select') { + // Process the first SELECT + extractColumnsFromStatement(ast, tableColumnMap, cteNames, aliasToTableMap); + + // Process UNION parts (_next chain) + let nextStmt = ast._next as Record | undefined; + while (nextStmt) { + if (nextStmt.type === 'select') { + extractColumnsFromStatement(nextStmt, tableColumnMap, cteNames, aliasToTableMap); + } + nextStmt = nextStmt._next as Record | undefined; + } + } + } else if (cteStmt.type === 'select') { + // Regular CTE without UNION + extractColumnsFromStatement(cteStmt, tableColumnMap, cteNames, aliasToTableMap); + } + } + } + } + + // Process JOIN conditions + if (statement.from && Array.isArray(statement.from)) { + for (const fromItem of statement.from) { + // Process JOIN ON conditions (fromItem.join is the join type, fromItem.on is the condition) + if (fromItem.join && fromItem.on) { + extractColumnFromExpression( + fromItem.on, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + + // Process subqueries in FROM clause + if (fromItem.expr && typeof fromItem.expr === 'object' && fromItem.expr.type === 'select') { + extractColumnsFromStatement( + fromItem.expr as Record, + tableColumnMap, + cteNames, + aliasToTableMap + ); + } + } + } +} + +/** + * Process FROM item to build alias mapping + */ +function processFromItem( + fromItem: unknown, + aliasToTableMap: Map, + cteNames: Set, + physicalTables?: string[] +): void { + const item = fromItem as Record; + + if (item.table) { + const tableName = extractTableName(item.table); + // Handle schema-qualified names (parser puts schema in 'db' field) + const fullTableName = item.db ? `${item.db}.${tableName}` : tableName; + const alias = item.as ? String(item.as) : tableName; + + if (!cteNames.has(tableName.toLowerCase())) { + aliasToTableMap.set(alias.toLowerCase(), fullTableName); + if (physicalTables) { + physicalTables.push(fullTableName); + } + } + } + + // Process JOINs + if (item.join && Array.isArray(item.join)) { + for (const joinItem of item.join) { + if (joinItem.table) { + const tableName = extractTableName(joinItem.table); + // Handle schema-qualified names (parser puts schema in 'db' field) + const fullTableName = joinItem.db ? `${joinItem.db}.${tableName}` : tableName; + const alias = joinItem.as ? 
String(joinItem.as) : tableName; + + if (!cteNames.has(tableName.toLowerCase())) { + aliasToTableMap.set(alias.toLowerCase(), fullTableName); + if (physicalTables) { + physicalTables.push(fullTableName); + } + } + + // Extract columns from JOIN conditions + if (joinItem.on) { + // JOIN conditions will be processed with the main extraction + } + } + } + } +} + +/** + * Extract table name from various formats + */ +function extractTableName(table: unknown): string { + if (typeof table === 'string') { + return table; + } + + if (table && typeof table === 'object') { + const tableObj = table as Record; + // Try different property names that might contain the table name + const tableName = tableObj.table || tableObj.name || tableObj.value; + return tableName ? String(tableName) : ''; + } + + return ''; +} + +/** + * Extract column references from an expression + */ +function extractColumnFromExpression( + expr: unknown, + aliasToTableMap: Map, + tableColumnMap: Map>, + cteNames: Set, + physicalTables?: string[], + columnAliases?: Set +): void { + if (!expr || typeof expr !== 'object') return; + + const expression = expr as Record; + + // Skip string literals and other non-column types + if ( + expression.type === 'single_quote_string' || + expression.type === 'double_quote_string' || + expression.type === 'number' || + expression.type === 'bool' || + expression.type === 'null' + ) { + return; + } + + // Handle column references + if (expression.type === 'column_ref') { + const columnName = expression.column; + const tableRef = expression.table; + + if (columnName && columnName !== '*') { + // Get the actual column name - handle both string and nested object formats + let actualColumn: string; + if (typeof columnName === 'string') { + actualColumn = columnName.toLowerCase(); + } else if (typeof columnName === 'object' && columnName !== null) { + // Handle nested object format from parser + const colObj = columnName as Record; + + // Check if it has the nested expr structure + if (colObj.expr && typeof colObj.expr === 'object') { + const exprObj = colObj.expr as Record; + // Skip string literals and non-column types + if ( + exprObj.type === 'single_quote_string' || + exprObj.type === 'double_quote_string' || + exprObj.type === 'number' || + exprObj.type === 'bool' || + exprObj.type === 'null' + ) { + return; + } + // Make sure this is actually a column reference + if (exprObj.type === 'default' || exprObj.type === 'column_ref') { + const colValue = exprObj.value || exprObj.name; + if (colValue !== undefined && colValue !== null) { + actualColumn = String(colValue).toLowerCase(); + } else { + // If we can't extract a value, skip this column + return; + } + } else { + // Unknown type, skip + return; + } + } else { + // Try direct properties as fallback + const colValue = colObj.value || colObj.name || colObj.column; + if (colValue !== undefined && colValue !== null) { + actualColumn = String(colValue).toLowerCase(); + } else { + // If we can't extract a value, skip this column + return; + } + } + } else { + // Unknown format, skip + return; + } + + // Skip if this column is actually a column alias (not a physical column) + // This handles cases where aliases are referenced in ORDER BY, GROUP BY, etc. 
+ if (!tableRef && columnAliases && columnAliases.has(actualColumn)) { + return; // Skip column aliases + } + + if (tableRef) { + // Get table name from reference + const tableName = extractTableName(tableRef).toLowerCase(); + + // Check if it's an alias + const actualTable = aliasToTableMap.get(tableName) || tableName; + + // Only track if not a CTE + if (!cteNames.has(actualTable.toLowerCase())) { + if (!tableColumnMap.has(actualTable)) { + tableColumnMap.set(actualTable, new Set()); + } + const tableColumns = tableColumnMap.get(actualTable); + if (tableColumns) { + tableColumns.add(actualColumn); + } + } + } else if (physicalTables && physicalTables.length > 0) { + // If no table reference but we have physical tables, assign to first table + // This handles simple queries like SELECT id, name FROM users + const firstTable = physicalTables[0]; + if (firstTable && !cteNames.has(firstTable.toLowerCase())) { + if (!tableColumnMap.has(firstTable)) { + tableColumnMap.set(firstTable, new Set()); + } + const tableColumns = tableColumnMap.get(firstTable); + if (tableColumns) { + tableColumns.add(actualColumn); + } + } + } + } + } + + // Handle aggregate functions + if (expression.type === 'aggr_func' && expression.args) { + if (Array.isArray(expression.args)) { + for (const arg of expression.args) { + extractColumnFromExpression( + arg, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } else if (typeof expression.args === 'object') { + const argsObj = expression.args as Record; + if (argsObj.expr) { + // Handle nested expr structure in args + extractColumnFromExpression( + argsObj.expr, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } else { + extractColumnFromExpression( + expression.args, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } else { + extractColumnFromExpression( + expression.args, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + + // Handle binary expressions (e.g., col1 = col2) + if (expression.type === 'binary_expr') { + if (expression.left) { + extractColumnFromExpression( + expression.left, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + if (expression.right) { + extractColumnFromExpression( + expression.right, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + + // Handle expression lists (e.g., IN clauses with subqueries) + if (expression.type === 'expr_list' && expression.value && Array.isArray(expression.value)) { + for (const item of expression.value) { + if (item.ast && item.ast.type === 'select') { + // This is a subquery + extractColumnsFromStatement( + item.ast as Record, + tableColumnMap, + cteNames, + aliasToTableMap + ); + } else { + extractColumnFromExpression( + item, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + } + + // Handle direct subqueries + if (expression.type === 'select') { + extractColumnsFromStatement( + expression as Record, + tableColumnMap, + cteNames, + aliasToTableMap + ); + } + + // Handle subqueries with ast property (as in SELECT column subqueries) + if (expression.ast && typeof expression.ast === 'object') { + const astObj = expression.ast as Record; + if (astObj.type === 'select') { + extractColumnsFromStatement(astObj, tableColumnMap, cteNames, aliasToTableMap); + } + } + + // Handle window functions (e.g., LAG, LEAD, 
ROW_NUMBER with OVER clause) + if (expression.type === 'window_func') { + // Process the function arguments + if (expression.args) { + if (Array.isArray(expression.args)) { + for (const arg of expression.args) { + if (arg.value) { + extractColumnFromExpression( + arg.value, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } else { + extractColumnFromExpression( + arg, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } + } + } else if (typeof expression.args === 'object') { + const argsObj = expression.args as Record; + // Handle expr_list for window functions + if (argsObj.type === 'expr_list' && argsObj.value && Array.isArray(argsObj.value)) { + for (const item of argsObj.value) { + extractColumnFromExpression( + item, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } else { + extractColumnFromExpression( + expression.args, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } + } + } + + // Process the OVER clause + if (expression.over && typeof expression.over === 'object') { + const overObj = expression.over as Record; + + // Handle PARTITION BY + if (overObj.partitionby && Array.isArray(overObj.partitionby)) { + for (const partItem of overObj.partitionby) { + extractColumnFromExpression( + partItem, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + + // Handle ORDER BY + if (overObj.orderby && Array.isArray(overObj.orderby)) { + for (const orderItem of overObj.orderby) { + if (orderItem && typeof orderItem === 'object') { + const orderObj = orderItem as Record; + if (orderObj.expr) { + extractColumnFromExpression( + orderObj.expr, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } else { + extractColumnFromExpression( + orderItem, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + } + } + } + } + + // Handle function calls + if (expression.type === 'function' && expression.args) { + if (Array.isArray(expression.args)) { + for (const arg of expression.args) { + if (arg.value) { + extractColumnFromExpression( + arg.value, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } else { + extractColumnFromExpression( + arg, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } + } + } else if (typeof expression.args === 'object') { + const argsObj = expression.args as Record; + // Handle expr_list for functions like EXISTS, etc. + if (argsObj.type === 'expr_list' && argsObj.value && Array.isArray(argsObj.value)) { + for (const item of argsObj.value) { + // Check for subquery with ast property (EXISTS subqueries have this structure) + const itemObj = item as Record; + if (itemObj?.ast && typeof itemObj.ast === 'object') { + const astObj = itemObj.ast as Record; + if (astObj.type === 'select') { + extractColumnsFromStatement(astObj, tableColumnMap, cteNames, aliasToTableMap); + } + } else { + // Process any other expression type (including aggr_func, column_ref, etc.) 
+ extractColumnFromExpression( + item, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + } else { + extractColumnFromExpression( + expression.args, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } + } + + // Handle window functions (OVER clause) + if (expression.over && typeof expression.over === 'object') { + const overObj = expression.over as Record; + if (overObj.as_window_specification && typeof overObj.as_window_specification === 'object') { + const windowSpec = overObj.as_window_specification as Record; + if ( + windowSpec.window_specification && + typeof windowSpec.window_specification === 'object' + ) { + const spec = windowSpec.window_specification as Record; + + // Handle PARTITION BY + if (spec.partitionby && Array.isArray(spec.partitionby)) { + for (const partItem of spec.partitionby) { + extractColumnFromExpression( + partItem, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } + } + + // Handle ORDER BY + if (spec.orderby && Array.isArray(spec.orderby)) { + for (const orderItem of spec.orderby) { + if (orderItem.expr) { + extractColumnFromExpression( + orderItem.expr, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } else if (orderItem) { + // Sometimes the orderItem itself might be the expression + extractColumnFromExpression( + orderItem, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables + ); + } + } + } + } + } + } + } + + // Handle CASE expressions + if (expression.type === 'case') { + if (expression.expr) { + extractColumnFromExpression( + expression.expr, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + if (expression.args && Array.isArray(expression.args)) { + for (const arg of expression.args) { + // Handle WHEN conditions + if (arg.cond) { + extractColumnFromExpression( + arg.cond, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + // Also handle older format + if (arg.when) { + extractColumnFromExpression( + arg.when, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + // Handle THEN results (may contain columns) + if (arg.result) { + extractColumnFromExpression( + arg.result, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + if (arg.then) { + extractColumnFromExpression( + arg.then, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } + } + if (expression.else) { + extractColumnFromExpression( + expression.else, + aliasToTableMap, + tableColumnMap, + cteNames, + physicalTables, + columnAliases + ); + } + } +} diff --git a/packages/access-controls/src/sql-permissions/validator.ts b/packages/access-controls/src/sql-permissions/validator.ts new file mode 100644 index 000000000..349b63c08 --- /dev/null +++ b/packages/access-controls/src/sql-permissions/validator.ts @@ -0,0 +1,276 @@ +import { getPermissionedDatasets } from '../datasets/permissions'; +import { + type ParsedDataset, + type ParsedTable, + checkQueryIsReadOnly, + extractColumnReferences, + extractDatasetsFromYml, + extractPhysicalTables, + extractTablesFromYml, + tablesMatch, + validateWildcardUsage, +} from './parser-helpers'; + +export interface UnauthorizedColumn { + table: string; + column: string; +} + +export interface PermissionValidationResult { + isAuthorized: boolean; + unauthorizedTables: string[]; + unauthorizedColumns?: UnauthorizedColumn[]; + 
error?: string; +} + +/** + * Validates SQL query against user's permissioned datasets + * Checks that all tables and columns referenced in the query are accessible to the user + */ +export async function validateSqlPermissions( + sql: string, + userId: string, + dataSourceSyntax?: string +): Promise { + try { + // First check if query is read-only + const readOnlyCheck = checkQueryIsReadOnly(sql, dataSourceSyntax); + if (!readOnlyCheck.isReadOnly) { + return { + isAuthorized: false, + unauthorizedTables: [], + error: + readOnlyCheck.error || + 'Only SELECT statements are allowed for read-only access. Please modify your query to use SELECT instead of write operations like INSERT, UPDATE, DELETE, or DDL statements.', + }; + } + + const wildcardCheck = validateWildcardUsage(sql, dataSourceSyntax); + // Store the wildcard error but continue to validate columns to provide comprehensive feedback + let wildcardError: string | undefined; + if (!wildcardCheck.isValid) { + wildcardError = + wildcardCheck.error || + 'SELECT * is not allowed on physical tables. Please explicitly list the column names you need (e.g., SELECT id, name, email FROM users). This helps prevent unintended data exposure and improves query performance.'; + } + + // Extract physical tables from SQL + const tablesInQuery = extractPhysicalTables(sql, dataSourceSyntax); + + if (tablesInQuery.length === 0) { + // No tables referenced (might be a function call or constant select) + return { isAuthorized: true, unauthorizedTables: [] }; + } + + // Get user's permissioned datasets + const permissionedDatasets = await getPermissionedDatasets({ + userId, + page: 0, + pageSize: 1000, + }); + + // Extract all allowed tables and datasets from permissions + const allowedTables: ParsedTable[] = []; + const allowedDatasets: ParsedDataset[] = []; + + for (const dataset of permissionedDatasets.datasets) { + if (dataset.ymlContent) { + const tables = extractTablesFromYml(dataset.ymlContent); + allowedTables.push(...tables); + + const datasetsWithColumns = extractDatasetsFromYml(dataset.ymlContent); + allowedDatasets.push(...datasetsWithColumns); + } + } + + // Check each table in query against permissions + const unauthorizedTables: string[] = []; + + for (const queryTable of tablesInQuery) { + let isAuthorized = false; + + // Check if query table matches any allowed table + for (const allowedTable of allowedTables) { + const matches = tablesMatch(queryTable, allowedTable); + if (matches) { + isAuthorized = true; + break; + } + } + + if (!isAuthorized) { + unauthorizedTables.push(queryTable.fullName); + } + } + + // Continue to validate column-level permissions even if tables are unauthorized + // This allows us to report both unauthorized tables AND their columns + const columnReferences = extractColumnReferences(sql, dataSourceSyntax); + const unauthorizedColumns: UnauthorizedColumn[] = []; + + for (const [tableName, columns] of columnReferences) { + // Find the matching allowed dataset for this table + let matchingDataset: ParsedDataset | undefined; + + for (const dataset of allowedDatasets) { + // Check if table names match (case-insensitive) + const tableNameLower = tableName.toLowerCase(); + const datasetFullNameLower = dataset.fullName.toLowerCase(); + const datasetTableLower = dataset.table.toLowerCase(); + + // Handle different qualification levels - check both fullName and table + if ( + tableNameLower === datasetFullNameLower || + tableNameLower === datasetTableLower || + tableNameLower.endsWith(`.${datasetTableLower}`) || + 
+          datasetFullNameLower === tableNameLower
+        ) {
+          matchingDataset = dataset;
+          break;
+        }
+      }
+
+      if (matchingDataset) {
+        // Found a matching dataset - validate columns if it has restrictions
+        if (matchingDataset.allowedColumns.size > 0) {
+          for (const column of columns) {
+            if (!matchingDataset.allowedColumns.has(column.toLowerCase())) {
+              unauthorizedColumns.push({
+                table: tableName,
+                column: column,
+              });
+            }
+          }
+        }
+        // If dataset has no column restrictions, it's backward compatibility mode (allow all columns)
+      } else {
+        // No matching dataset found - this table is completely unauthorized
+        // Check if this table was already marked as unauthorized
+        const isTableUnauthorized = unauthorizedTables.some((t) => {
+          const tLower = t.toLowerCase();
+          const tableNameLower = tableName.toLowerCase();
+          return (
+            tLower === tableNameLower ||
+            tLower.endsWith(`.${tableNameLower.split('.').pop()}`) ||
+            tableNameLower.endsWith(`.${tLower.split('.').pop()}`)
+          );
+        });
+
+        if (isTableUnauthorized) {
+          // Table is unauthorized, so all its columns are also unauthorized
+          for (const column of columns) {
+            unauthorizedColumns.push({
+              table: tableName,
+              column: column,
+            });
+          }
+        }
+      }
+    }
+
+    const result: PermissionValidationResult = {
+      isAuthorized:
+        unauthorizedTables.length === 0 && unauthorizedColumns.length === 0 && !wildcardError,
+      unauthorizedTables,
+    };
+
+    if (unauthorizedColumns.length > 0) {
+      result.unauthorizedColumns = unauthorizedColumns;
+    }
+
+    if (wildcardError) {
+      result.error = wildcardError;
+    }
+
+    return result;
+  } catch (error) {
+    const errorMessage = error instanceof Error ? error.message : String(error);
+
+    // Provide more specific guidance based on the error
+    if (errorMessage.includes('parse')) {
+      return {
+        isAuthorized: false,
+        unauthorizedTables: [],
+        error: `Failed to validate SQL permissions due to parsing error: ${errorMessage}. Please check your SQL syntax and ensure it's valid.`,
+      };
+    }
+
+    if (errorMessage.includes('permission')) {
+      return {
+        isAuthorized: false,
+        unauthorizedTables: [],
+        error: `Permission check failed: ${errorMessage}. Please ensure you have the necessary access rights for the requested tables and columns.`,
+      };
+    }
+
+    return {
+      isAuthorized: false,
+      unauthorizedTables: [],
+      error: `Permission validation failed: ${errorMessage}. Please verify your SQL query syntax and ensure you have access to the requested resources.`,
+    };
+  }
+}
+
+/**
+ * Creates a detailed error message for unauthorized table or column access
+ */
+export function createPermissionErrorMessage(
+  unauthorizedTables: string[],
+  unauthorizedColumns?: UnauthorizedColumn[]
+): string {
+  const messages: string[] = [];
+
+  // Handle unauthorized tables with actionable guidance
+  if (unauthorizedTables.length > 0) {
+    const tableList = unauthorizedTables.join(', ');
+    if (unauthorizedTables.length === 1) {
+      messages.push(
+        `You do not have access to table: ${tableList}. Please request access to this table or use a different table that you have permissions for.`
+      );
+    } else {
+      messages.push(
+        `You do not have access to the following tables: ${tableList}. Please request access to these tables or modify your query to use only authorized tables.`
+      );
+    }
+  }
+
+  // Handle unauthorized columns
+  if (unauthorizedColumns && unauthorizedColumns.length > 0) {
+    // Group columns by table for better error messages
+    const columnsByTable = new Map<string, string[]>();
+
+    for (const { table, column } of unauthorizedColumns) {
+      if (!columnsByTable.has(table)) {
+        columnsByTable.set(table, []);
+      }
+      const tableColumns = columnsByTable.get(table);
+      if (tableColumns) {
+        tableColumns.push(column);
+      }
+    }
+
+    const columnMessages: string[] = [];
+    for (const [table, columns] of columnsByTable) {
+      const columnList = columns.join(', ');
+      columnMessages.push(
+        `Table '${table}': columns [${columnList}] are not available in your permitted dataset`
+      );
+    }
+
+    if (columnMessages.length === 1) {
+      messages.push(
+        `Unauthorized column access - ${columnMessages[0]}. Please use only the columns that are available in your permitted datasets, or request access to additional columns.`
+      );
+    } else {
+      messages.push(
+        `Unauthorized column access:\n${columnMessages.map((m) => ` - ${m}`).join('\n')}\n\nPlease modify your query to use only the columns available in your permitted datasets, or request access to the additional columns you need.`
+      );
+    }
+  }
+
+  if (messages.length === 0) {
+    return '';
+  }
+
+  return `Insufficient permissions: ${messages.join('. ')}`;
+}
diff --git a/packages/ai/src/tools/database-tools/execute-sql/execute-sql-execute.ts b/packages/ai/src/tools/database-tools/execute-sql/execute-sql-execute.ts
index d1f86e2fa..f83f93438 100644
--- a/packages/ai/src/tools/database-tools/execute-sql/execute-sql-execute.ts
+++ b/packages/ai/src/tools/database-tools/execute-sql/execute-sql-execute.ts
@@ -1,11 +1,8 @@
+import { createPermissionErrorMessage, validateSqlPermissions } from '@buster/access-controls';
 import { type DataSource, withRateLimit } from '@buster/data-source';
 import { updateMessageEntries } from '@buster/database';
 import { wrapTraced } from 'braintrust';
 import { getDataSource } from '../../../utils/get-data-source';
-import {
-  createPermissionErrorMessage,
-  validateSqlPermissions,
-} from '../../../utils/sql-permissions';
 import { createRawToolResultEntry } from '../../shared/create-raw-llm-tool-result-entry';
 import {
   EXECUTE_SQL_TOOL_NAME,
diff --git a/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql-execute.ts b/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql-execute.ts
index 301e62f38..f40dd5b7c 100644
--- a/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql-execute.ts
+++ b/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql-execute.ts
@@ -1,7 +1,7 @@
+import { checkQueryIsReadOnly } from '@buster/access-controls';
 import { type DataSource, withRateLimit } from '@buster/data-source';
 import { wrapTraced } from 'braintrust';
 import { getDataSource } from '../../../utils/get-data-source';
-import { checkQueryIsReadOnly } from '../../../utils/sql-permissions/sql-parser-helpers';
 import type {
   SuperExecuteSqlContext,
   SuperExecuteSqlInput,
diff --git a/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql.test.ts b/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql.test.ts
index c04dbdb00..48f402848 100644
--- a/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql.test.ts
+++ b/packages/ai/src/tools/database-tools/super-execute-sql/super-execute-sql.test.ts
@@ -1,7 +1,7 @@
+import { checkQueryIsReadOnly } from '@buster/access-controls';
 import type { DataSource } from '@buster/data-source';
 import { beforeEach, describe, expect, it, vi } from 'vitest';
 import { getDataSource } from '../../../utils/get-data-source';
-import { checkQueryIsReadOnly } from '../../../utils/sql-permissions/sql-parser-helpers';
 import type { SuperExecuteSqlContext, SuperExecuteSqlState } from './super-execute-sql';
 import { createSuperExecuteSqlExecute } from './super-execute-sql-execute';
 
@@ -10,7 +10,7 @@ vi.mock('../../../utils/get-data-source', () => ({
   getDataSource: vi.fn(),
 }));
 
-vi.mock('../../../utils/sql-permissions/sql-parser-helpers', () => ({
+vi.mock('@buster/access-controls', () => ({
   checkQueryIsReadOnly: vi.fn(),
 }));
diff --git a/packages/ai/src/tools/visualization-tools/metrics/create-metrics-tool/create-metrics-execute.ts b/packages/ai/src/tools/visualization-tools/metrics/create-metrics-tool/create-metrics-execute.ts
index 8da199152..4c93e2a21 100644
--- a/packages/ai/src/tools/visualization-tools/metrics/create-metrics-tool/create-metrics-execute.ts
+++ b/packages/ai/src/tools/visualization-tools/metrics/create-metrics-tool/create-metrics-execute.ts
@@ -1,4 +1,5 @@
 import { randomUUID } from 'node:crypto';
+import { createPermissionErrorMessage, validateSqlPermissions } from '@buster/access-controls';
 import type { Credentials } from '@buster/data-source';
 import { createMetadataFromResults, executeMetricQuery } from '@buster/data-source';
 import { assetPermissions, db, metricFiles, updateMessageEntries } from '@buster/database';
@@ -12,10 +13,6 @@ import { wrapTraced } from 'braintrust';
 import * as yaml from 'yaml';
 import { z } from 'zod';
 import { getDataSourceCredentials } from '../../../../utils/get-data-source';
-import {
-  createPermissionErrorMessage,
-  validateSqlPermissions,
-} from '../../../../utils/sql-permissions';
 import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
 import { trackFileAssociations } from '../../file-tracking-helper';
 import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator';
diff --git a/packages/ai/src/tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-execute.ts b/packages/ai/src/tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-execute.ts
index 0c287e2ce..cf777123a 100644
--- a/packages/ai/src/tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-execute.ts
+++ b/packages/ai/src/tools/visualization-tools/metrics/modify-metrics-tool/modify-metrics-execute.ts
@@ -1,3 +1,4 @@
+import { createPermissionErrorMessage, validateSqlPermissions } from '@buster/access-controls';
 import type { Credentials } from '@buster/data-source';
 import { createMetadataFromResults, executeMetricQuery } from '@buster/data-source';
 import { db, metricFiles, updateMessageEntries } from '@buster/database';
@@ -12,10 +13,6 @@ import { eq, inArray } from 'drizzle-orm';
 import * as yaml from 'yaml';
 import { z } from 'zod';
 import { getDataSourceCredentials } from '../../../../utils/get-data-source';
-import {
-  createPermissionErrorMessage,
-  validateSqlPermissions,
-} from '../../../../utils/sql-permissions';
 import { createRawToolResultEntry } from '../../../shared/create-raw-llm-tool-result-entry';
 import { trackFileAssociations } from '../../file-tracking-helper';
 import { validateAndAdjustBarLineAxes } from '../helpers/bar-line-axis-validator';
diff --git a/packages/database/src/queries/data-sources/get-data-source-by-id.ts b/packages/database/src/queries/data-sources/get-data-source-by-id.ts
new file mode 100644
index 000000000..6e5995593
--- /dev/null
+++ b/packages/database/src/queries/data-sources/get-data-source-by-id.ts
@@ -0,0 +1,36 @@
+import { and, eq, isNull } from 'drizzle-orm';
+import { z } from 'zod';
+import { db } from '../../connection';
+import { dataSources } from '../../schema';
+
+// Zod schema for the data source
+export const DataSourceSchema = z.object({
+  id: z.string(),
+  name: z.string(),
+  type: z.string(),
+  organizationId: z.string(),
+  secretId: z.string(),
+});
+
+export type DataSource = z.infer<typeof DataSourceSchema>;
+
+/**
+ * Fetches a data source by its ID
+ * @param dataSourceId - The ID of the data source to fetch
+ * @returns The data source or null if not found
+ */
+export async function getDataSourceById(dataSourceId: string): Promise<DataSource | null> {
+  const [result] = await db
+    .select({
+      id: dataSources.id,
+      name: dataSources.name,
+      type: dataSources.type,
+      organizationId: dataSources.organizationId,
+      secretId: dataSources.secretId,
+    })
+    .from(dataSources)
+    .where(and(eq(dataSources.id, dataSourceId), isNull(dataSources.deletedAt)))
+    .limit(1);
+
+  return result || null;
+}
diff --git a/packages/database/src/queries/data-sources/index.ts b/packages/database/src/queries/data-sources/index.ts
new file mode 100644
index 000000000..697c40b12
--- /dev/null
+++ b/packages/database/src/queries/data-sources/index.ts
@@ -0,0 +1,2 @@
+export * from './get-data-source-by-id';
+export * from './organizationDataSource';
diff --git a/packages/database/src/queries/dataSources/organizationDataSource.ts b/packages/database/src/queries/data-sources/organizationDataSource.ts
similarity index 100%
rename from packages/database/src/queries/dataSources/organizationDataSource.ts
rename to packages/database/src/queries/data-sources/organizationDataSource.ts
diff --git a/packages/database/src/queries/dataSources/index.ts b/packages/database/src/queries/dataSources/index.ts
index bb5299ade..bbc236ede 100644
--- a/packages/database/src/queries/dataSources/index.ts
+++ b/packages/database/src/queries/dataSources/index.ts
@@ -1 +1 @@
-export * from './organizationDataSource';
+export * from '../data-sources/organizationDataSource';
diff --git a/packages/database/src/queries/index.ts b/packages/database/src/queries/index.ts
index b1997fca3..2db9ce6a1 100644
--- a/packages/database/src/queries/index.ts
+++ b/packages/database/src/queries/index.ts
@@ -1,7 +1,7 @@
 export * from './api-keys';
 export * from './messages';
 export * from './users';
-export * from './dataSources';
+export * from './data-sources';
 export * from './datasets';
 export * from './assets';
 export * from './asset-permissions';
diff --git a/packages/server-shared/src/index.ts b/packages/server-shared/src/index.ts
index 2f77c2c5c..66e767041 100644
--- a/packages/server-shared/src/index.ts
+++ b/packages/server-shared/src/index.ts
@@ -30,3 +30,4 @@ export * from './type-utilities';
 export * from './user';
 export * from './shortcuts';
 export * from './healthcheck';
+export * from './sql';
diff --git a/packages/server-shared/src/sql/index.ts b/packages/server-shared/src/sql/index.ts
new file mode 100644
index 000000000..d6bdd23df
--- /dev/null
+++ b/packages/server-shared/src/sql/index.ts
@@ -0,0 +1,17 @@
+import { z } from 'zod';
+import type { DataMetadata } from '../metrics';
+
+// Request schema for running SQL queries
+export const RunSqlRequestSchema = z.object({
+  data_source_id: z.string().uuid('Data source ID must be a valid UUID'),
+  sql: z.string().min(1, 'SQL query cannot be empty'),
+});
+
+export type RunSqlRequest = z.infer<typeof RunSqlRequestSchema>;
+
+// Response type matching the structure from metric responses
+export interface RunSqlResponse {
+  data: Record<string, unknown>[];
+  data_metadata: DataMetadata;
+  has_more_records: boolean;
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index b73e07357..bb0a7d425 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -926,6 +926,9 @@ importers:
 
   packages/access-controls:
     dependencies:
+      '@buster/data-source':
+        specifier: workspace:*
+        version: link:../data-source
       '@buster/database':
        specifier: workspace:*
        version: link:../database
@@ -944,9 +947,15 @@ importers:
      lru-cache:
        specifier: ^11.1.0
        version: 11.1.0
+      node-sql-parser:
+        specifier: ^5.3.12
+        version: 5.3.12
      uuid:
        specifier: 'catalog:'
        version: 11.1.0
+      yaml:
+        specifier: ^2.8.1
+        version: 2.8.1
      zod:
        specifier: 'catalog:'
        version: 3.25.76
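
A minimal sketch of how the exports introduced in this diff compose at a call site. The validateRunSqlRequest wrapper, the userId value, and the request body are illustrative only and not part of the change; it assumes the new functions are re-exported from the package roots as wired up in the index files above.

import { createPermissionErrorMessage, validateSqlPermissions } from '@buster/access-controls';
import { getDataSourceById } from '@buster/database';
import { RunSqlRequestSchema, type RunSqlRequest } from '@buster/server-shared';

// Hypothetical helper showing the intended call sequence: parse the body,
// resolve the data source, then enforce table- and column-level permissions.
async function validateRunSqlRequest(body: unknown, userId: string): Promise<RunSqlRequest> {
  // Validate the raw request body against the shared Zod schema.
  const request: RunSqlRequest = RunSqlRequestSchema.parse(body);

  // getDataSourceById returns null for missing or soft-deleted rows.
  const dataSource = await getDataSourceById(request.data_source_id);
  if (!dataSource) {
    throw new Error('Data source not found');
  }

  // Check permissions before any query is executed; dataSource.type selects the SQL dialect.
  const permission = await validateSqlPermissions(request.sql, userId, dataSource.type);
  if (!permission.isAuthorized) {
    throw new Error(
      permission.error ??
        createPermissionErrorMessage(permission.unauthorizedTables, permission.unauthorizedColumns)
    );
  }

  return request;
}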