diff --git a/apps/server/src/api/v2/tools/metadata/GET.ts b/apps/server/src/api/v2/tools/metadata/GET.ts index 1ada2c882..69a7da0b2 100644 --- a/apps/server/src/api/v2/tools/metadata/GET.ts +++ b/apps/server/src/api/v2/tools/metadata/GET.ts @@ -1,17 +1,178 @@ -import { getDatasetMetadata } from '@buster/database/queries'; -import type { GetMetadataRequest, GetMetadataResponse } from '@buster/server-shared'; +import { type Credentials, DataSourceType, createAdapter } from '@buster/data-source'; +import { getDataSourceCredentials } from '@buster/database/queries'; +import type { + GetMetadataRequest, + GetMetadataResponse, + GetTableStatisticsOutput, +} from '@buster/server-shared'; import type { ApiKeyContext } from '@buster/server-shared'; +import { runs, tasks } from '@trigger.dev/sdk'; import { HTTPException } from 'hono/http-exception'; +/** + * Validates that identifier doesn't contain SQL injection attempts + * Allows alphanumeric, underscores, hyphens, and dots (for qualified names) + */ +function validateIdentifier(identifier: string, fieldName: string): void { + // Allow alphanumeric, underscores, hyphens, dots, and spaces (for some database names) + const validPattern = /^[a-zA-Z0-9_\-\.\s]+$/; + + if (!validPattern.test(identifier)) { + throw new HTTPException(400, { + message: `Invalid ${fieldName}: contains disallowed characters. Only alphanumeric, underscores, hyphens, dots, and spaces are allowed.`, + }); + } + + // Block common SQL injection keywords + const sqlKeywords = + /(\b(DROP|DELETE|INSERT|UPDATE|ALTER|CREATE|EXEC|EXECUTE|UNION|SELECT|WHERE|FROM|JOIN)\b)/i; + if (sqlKeywords.test(identifier)) { + throw new HTTPException(400, { + message: `Invalid ${fieldName}: contains disallowed SQL keywords.`, + }); + } +} + +function isCredentials(value: unknown): value is Credentials { + if (!value || typeof value !== 'object') return false; + const type = (value as { type?: unknown }).type; + if (typeof type !== 'string') return false; + return (Object.values(DataSourceType) as string[]).includes(type); +} + +/** + * Quick metadata lookup from warehouse information schema + * Returns table size, row count, and type without full sampling + */ +async function getQuickTableMetadata( + dataSourceId: string, + database: string, + schema: string, + tableName: string +): Promise<{ rowCount: number; sizeBytes?: number; type: string }> { + let adapter = null; + + try { + // Fetch credentials + const credentials = await getDataSourceCredentials({ dataSourceId }); + if (!isCredentials(credentials)) { + throw new Error('Invalid credentials returned from vault'); + } + + // Create adapter + adapter = await createAdapter(credentials); + + // Query information schema based on database type + // This is fast and doesn't require sampling + let query: string; + let params: unknown[] = []; + + switch (credentials.type) { + case DataSourceType.PostgreSQL: + case DataSourceType.Redshift: + query = ` + SELECT + COALESCE(s.n_live_tup, 0) as row_count, + pg_relation_size(quote_ident($1)||'.'||quote_ident($2)) as size_bytes, + t.table_type + FROM information_schema.tables t + LEFT JOIN pg_stat_user_tables s + ON s.schemaname = t.table_schema + AND s.relname = t.table_name + WHERE t.table_schema = $1 + AND t.table_name = $2 + LIMIT 1 + `; + params = [schema, tableName]; + break; + + case DataSourceType.Snowflake: + query = ` + SELECT + row_count, + bytes as size_bytes, + table_type + FROM ${database}.information_schema.tables + WHERE table_schema = ? + AND table_name = ? + LIMIT 1 + `; + params = [schema, tableName]; + break; + + case DataSourceType.BigQuery: + query = ` + SELECT + row_count, + size_bytes, + table_type + FROM \`${database}.${schema}.INFORMATION_SCHEMA.TABLES\` + WHERE table_name = ? + LIMIT 1 + `; + params = [tableName]; + break; + + case DataSourceType.MySQL: + query = ` + SELECT + table_rows as row_count, + data_length as size_bytes, + table_type + FROM information_schema.tables + WHERE table_schema = ? + AND table_name = ? + LIMIT 1 + `; + params = [schema, tableName]; + break; + + default: + // Fallback for unknown types - use reasonable defaults + return { rowCount: 100000, type: 'TABLE' }; + } + + const result = await adapter.query(query, params as (string | number)[]); + + if (!result.rows || result.rows.length === 0) { + throw new HTTPException(404, { + message: `Table not found: ${database}.${schema}.${tableName}`, + }); + } + + const row = result.rows[0] as Record; + + const rowCount = typeof row.row_count === 'number' ? row.row_count : 100000; + const sizeBytes = typeof row.size_bytes === 'number' ? row.size_bytes : undefined; + const tableType = typeof row.table_type === 'string' ? row.table_type : 'TABLE'; + + return { + rowCount, + ...(sizeBytes !== undefined && { sizeBytes }), + type: tableType, + }; + } finally { + if (adapter) { + await adapter.close().catch(() => { + // Ignore cleanup errors + }); + } + } +} + /** * Handler for retrieving dataset metadata via API key authentication * * This handler: * 1. Validates API key has access to the organization - * 2. Queries for dataset matching database, schema, name, and organization - * 3. Returns the dataset's metadata column + * 2. Validates identifiers for SQL injection protection + * 3. Triggers get-table-statistics task to compute fresh metadata + * 4. Waits for task completion (max 2 minutes) + * 5. Transforms and returns the metadata * - * @param request - The metadata request containing database, schema, and name + * Note: This works on ANY table in the data source, not just registered datasets + * + * @param request - The metadata request containing dataSourceId, database, schema, and name * @param apiKeyContext - The authenticated API key context * @returns The dataset metadata */ @@ -20,22 +181,127 @@ export async function getMetadataHandler( apiKeyContext: ApiKeyContext ): Promise { const { organizationId } = apiKeyContext; + const { dataSourceId, database, schema, name } = request; - // Get dataset metadata - const result = await getDatasetMetadata({ - database: request.database, - schema: request.schema, - name: request.name, - organizationId, - }); + // Validate identifiers for SQL injection protection + validateIdentifier(database, 'database'); + validateIdentifier(schema, 'schema'); + validateIdentifier(name, 'table name'); - if (!result || !result.metadata) { - throw new HTTPException(404, { - message: `Dataset not found: ${request.database}.${request.schema}.${request.name}`, + try { + // Quick lookup from warehouse information schema to get accurate row count and size + // This is fast (< 1 second) and doesn't require sampling + const quickMetadata = await getQuickTableMetadata(dataSourceId, database, schema, name); + + // Determine sample size based on actual row count + // Use up to 100k samples for large tables, scale down for smaller ones + const sampleSize = Math.min(100000, Math.max(10000, Math.floor(quickMetadata.rowCount * 0.1))); + + // Map table type to enum + const tableType: 'TABLE' | 'VIEW' | 'MATERIALIZED_VIEW' | 'EXTERNAL_TABLE' | 'TEMPORARY_TABLE' = + quickMetadata.type === 'VIEW' + ? 'VIEW' + : quickMetadata.type === 'MATERIALIZED VIEW' + ? 'MATERIALIZED_VIEW' + : quickMetadata.type === 'FOREIGN TABLE' + ? 'EXTERNAL_TABLE' + : quickMetadata.type === 'LOCAL TEMPORARY' + ? 'TEMPORARY_TABLE' + : 'TABLE'; + + // Trigger the get-table-statistics task with idempotency + // If the same table is requested within 5 minutes, return existing task + const handle = await tasks.trigger( + 'get-table-statistics', + { + dataSourceId, + table: { + name, + schema, + database, + rowCount: quickMetadata.rowCount, + sizeBytes: quickMetadata.sizeBytes, + type: tableType, + }, + sampleSize, + }, + { + idempotencyKey: `metadata-${organizationId}-${dataSourceId}-${database}-${schema}-${name}`, + idempotencyKeyTTL: '5m', // 5 minutes TTL + } + ); + + // Poll for task completion with timeout + const startTime = Date.now(); + const timeout = 120000; // 2 minutes + const pollInterval = 2000; // Poll every 2 seconds + + let run: Awaited>; + while (true) { + run = await runs.retrieve(handle.id); + + // Check if task completed, failed, or was canceled + if (run.status === 'COMPLETED' || run.status === 'FAILED' || run.status === 'CANCELED') { + break; + } + + // Check for timeout + if (Date.now() - startTime > timeout) { + throw new HTTPException(504, { + message: 'Metadata collection took too long to complete. Please try again.', + }); + } + + // Wait before next poll + await new Promise((resolve) => setTimeout(resolve, pollInterval)); + } + + // Check task status + if (run.status === 'FAILED' || run.status === 'CANCELED') { + throw new HTTPException(500, { + message: `Metadata collection task ${run.status.toLowerCase()}`, + }); + } + + // Check if task completed successfully + if (!run.output) { + throw new HTTPException(500, { + message: 'Metadata collection task did not return any output', + }); + } + + const output = run.output as GetTableStatisticsOutput; + + if (!output.success) { + throw new HTTPException(500, { + message: output.error || 'Metadata collection failed', + }); + } + + // Transform GetTableStatisticsOutput to DatasetMetadata format + const metadata = { + rowCount: output.totalRows, + sizeBytes: quickMetadata.sizeBytes, + sampleSize: output.actualSamples, + samplingMethod: output.samplingMethod, + columnProfiles: output.columnProfiles || [], + introspectedAt: new Date().toISOString(), + }; + + return { + metadata, + }; + } catch (error) { + // Re-throw HTTPException as-is + if (error instanceof HTTPException) { + throw error; + } + + // Log unexpected errors + console.error('Unexpected error during metadata retrieval:', error); + + throw new HTTPException(500, { + message: 'An unexpected error occurred during metadata retrieval', }); } - - return { - metadata: result.metadata, - }; } diff --git a/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.test.ts b/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.test.ts index c1fe42c12..7015b271c 100644 --- a/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.test.ts +++ b/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.test.ts @@ -35,6 +35,7 @@ describe('retrieve-metadata-execute error handling', () => { } as Response); const result = await executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'public', name: 'users', @@ -42,7 +43,7 @@ describe('retrieve-metadata-execute error handling', () => { expect(result).toEqual(mockResponse); expect(fetch).toHaveBeenCalledWith( - 'http://localhost:3000/api/v2/tools/metadata?database=test_db&schema=public&name=users', + 'http://localhost:3000/api/v2/tools/metadata?dataSourceId=test-ds-id&database=test_db&schema=public&name=users', expect.objectContaining({ method: 'GET', headers: { @@ -64,6 +65,7 @@ describe('retrieve-metadata-execute error handling', () => { await expect( executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'public', name: 'nonexistent_table', @@ -78,6 +80,7 @@ describe('retrieve-metadata-execute error handling', () => { await expect( executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'public', name: 'users', @@ -97,6 +100,7 @@ describe('retrieve-metadata-execute error handling', () => { await expect( executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'restricted', name: 'sensitive_table', @@ -116,6 +120,7 @@ describe('retrieve-metadata-execute error handling', () => { await expect( executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'public', name: 'users', @@ -137,6 +142,7 @@ describe('retrieve-metadata-execute error handling', () => { await expect( executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'public', name: 'users', @@ -151,6 +157,7 @@ describe('retrieve-metadata-execute error handling', () => { await expect( executeHandler({ + dataSourceId: 'test-ds-id', database: 'test_db', schema: 'public', name: 'users', diff --git a/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.ts b/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.ts index 277ea277e..03a659f94 100644 --- a/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.ts +++ b/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata-execute.ts @@ -10,6 +10,7 @@ import { * Retrieve dataset metadata via API endpoint */ async function executeApiRequest( + dataSourceId: string, database: string, schema: string, name: string, @@ -22,6 +23,7 @@ async function executeApiRequest( try { // Build query string const params = new URLSearchParams({ + dataSourceId, database, schema, name, @@ -66,10 +68,10 @@ async function executeApiRequest( export function createRetrieveMetadataExecute(context: RetrieveMetadataContext) { return wrapTraced( async (input: RetrieveMetadataInput): Promise => { - const { database, schema, name } = input; + const { dataSourceId, database, schema, name } = input; // Execute API request - const result = await executeApiRequest(database, schema, name, context); + const result = await executeApiRequest(dataSourceId, database, schema, name, context); if (result.success && result.data) { return result.data; diff --git a/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata.ts b/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata.ts index 954a26fbc..85cece254 100644 --- a/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata.ts +++ b/packages/ai/src/tools/database-tools/retrieve-metadata/retrieve-metadata.ts @@ -5,6 +5,7 @@ import { createRetrieveMetadataExecute } from './retrieve-metadata-execute'; export const RETRIEVE_METADATA_TOOL_NAME = 'retrieveMetadata'; export const RetrieveMetadataInputSchema = z.object({ + dataSourceId: z.string().min(1).describe('Data source identifier'), database: z.string().min(1).describe('Database name where the dataset resides'), schema: z.string().min(1).describe('Schema name where the dataset resides'), name: z.string().min(1).describe('Dataset/table name'), diff --git a/packages/server-shared/src/metadata/index.ts b/packages/server-shared/src/metadata/index.ts index 4b1792ddb..458dbbec9 100644 --- a/packages/server-shared/src/metadata/index.ts +++ b/packages/server-shared/src/metadata/index.ts @@ -3,6 +3,7 @@ import { z } from 'zod'; // Request schema for getting dataset metadata export const GetMetadataRequestSchema = z.object({ + dataSourceId: z.string().min(1, 'Data source ID cannot be empty'), database: z.string().min(1, 'Database name cannot be empty'), schema: z.string().min(1, 'Schema name cannot be empty'), name: z.string().min(1, 'Dataset name cannot be empty'), @@ -14,3 +15,10 @@ export type GetMetadataRequest = z.infer; export interface GetMetadataResponse { metadata: DatasetMetadata; } + +// Re-export trigger task types for use in server +export type { + GetTableStatisticsInput, + GetTableStatisticsOutput, + ColumnProfile, +} from './trigger-task-types'; diff --git a/packages/server-shared/src/metadata/trigger-task-types.ts b/packages/server-shared/src/metadata/trigger-task-types.ts new file mode 100644 index 000000000..fcf5a256c --- /dev/null +++ b/packages/server-shared/src/metadata/trigger-task-types.ts @@ -0,0 +1,116 @@ +import { z } from 'zod'; + +/** + * Types for trigger.dev tasks + * These are duplicated from @buster-app/trigger to avoid circular dependencies + * IMPORTANT: Keep in sync with apps/trigger/src/tasks/introspect-data/types/index.ts + */ + +// Column profile schema +export const ColumnProfileSchema = z.object({ + columnName: z.string(), + dataType: z.string(), + + // Basic Statistics + nullRate: z.number().min(0).max(1), + distinctCount: z.number().int().nonnegative(), + uniquenessRatio: z.number().min(0).max(1), + emptyStringRate: z.number().min(0).max(1), + + // Distribution + topValues: z.array( + z.object({ + value: z.unknown(), + count: z.number(), + percentage: z.number(), + }) + ), + entropy: z.number(), + giniCoefficient: z.number().min(0).max(1), + + // Sample values + sampleValues: z.array(z.unknown()), + + // Numeric-specific + numericStats: z + .object({ + mean: z.number(), + median: z.number(), + stdDev: z.number(), + skewness: z.number(), + percentiles: z.object({ + p25: z.number(), + p50: z.number(), + p75: z.number(), + p95: z.number(), + p99: z.number(), + }), + outlierRate: z.number().min(0).max(1), + }) + .optional(), + + // Classification + classification: z.object({ + isLikelyEnum: z.boolean(), + isLikelyIdentifier: z.boolean(), + identifierType: z + .enum(['primary_key', 'foreign_key', 'natural_key', 'sequential', 'uuid_like']) + .optional(), + enumValues: z.array(z.string()).optional(), + }), + + // Dynamic metadata based on detected column semantics + dynamicMetadata: z + .union([ + z.object({ type: z.literal('datetime') }).passthrough(), + z.object({ type: z.literal('numeric') }).passthrough(), + z.object({ type: z.literal('identifier') }).passthrough(), + z.object({ type: z.literal('url') }).passthrough(), + z.object({ type: z.literal('email') }).passthrough(), + z.object({ type: z.literal('json') }).passthrough(), + ]) + .optional(), +}); + +export type ColumnProfile = z.infer; + +// Get table statistics task input +export const GetTableStatisticsInputSchema = z.object({ + dataSourceId: z.string().min(1, 'Data source ID is required'), + table: z.object({ + name: z.string(), + schema: z.string(), + database: z.string(), + rowCount: z.number(), + sizeBytes: z.number().optional(), + type: z.enum(['TABLE', 'VIEW', 'MATERIALIZED_VIEW', 'EXTERNAL_TABLE', 'TEMPORARY_TABLE']), + }), + sampleSize: z.number().int().positive(), +}); + +export type GetTableStatisticsInput = z.infer; + +// Get table statistics task output +export const GetTableStatisticsOutputSchema = z.object({ + success: z.boolean(), + tableId: z.string(), + totalRows: z.number(), + sampleSize: z.number(), + actualSamples: z.number(), + samplingMethod: z.string(), + + // Statistical analysis results + columnProfiles: z.array(ColumnProfileSchema).optional(), + tableMetadata: z + .object({ + sampleSize: z.number(), + totalRows: z.number(), + samplingRate: z.number(), + analysisTimeMs: z.number(), + }) + .optional(), + + error: z.string().optional(), +}); + +export type GetTableStatisticsOutput = z.infer;