"""Authentication and authorization helpers for FastAPI endpoints.

The module provides:

* admin API key verification (``X-Admin-Api-Key`` header),
* user identification from an ``x-api-key`` header or a ``Bearer`` JWT,
* authorization checks for threads, agents and sandboxes.
"""
import sentry
from fastapi import HTTPException, Request, Header
from typing import Optional
import jwt
from jwt.exceptions import PyJWTError
from core.utils.logger import structlog
from core.utils.config import config
import os
import base64
import hashlib
import hmac
from core.services.supabase import DBConnection
from core.services import redis

# Exception-message fragments that are reported to the client as HTTP 503
# ("Server is shutting down") instead of a generic HTTP 500.
_SHUTDOWN_ERROR_MARKERS = (
    "cannot schedule new futures after shutdown",
    "connection is closed",
)

# Lifetime, in seconds, of the cached account_id -> user_id mapping in Redis.
_ACCOUNT_USER_CACHE_TTL_SECONDS = 300


def _bind_user_context(user_id: str, **context) -> None:
    """Attach the authenticated user to Sentry and to the structlog context.

    Args:
        user_id: The authenticated user's ID.
        **context: Extra key/value pairs bound to the structlog context vars
            (for example ``auth_method``).
    """
    sentry.sentry.set_user({"id": user_id})
    structlog.contextvars.bind_contextvars(user_id=user_id, **context)


def _unexpected_error_to_http(exc: Exception, prefix: str) -> HTTPException:
    """Translate an unexpected exception into the HTTPException to raise.

    Args:
        exc: The exception that was caught.
        prefix: Text placed before the exception message in the 500 detail.

    Returns:
        A 503 ``HTTPException`` when the message contains one of
        ``_SHUTDOWN_ERROR_MARKERS``, otherwise a 500 ``HTTPException`` whose
        detail is ``"<prefix>: <message>"``.
    """
    error_msg = str(exc)
    if any(marker in error_msg for marker in _SHUTDOWN_ERROR_MARKERS):
        return HTTPException(status_code=503, detail="Server is shutting down")
    return HTTPException(status_code=500, detail=f"{prefix}: {error_msg}")


async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)) -> bool:
    """Validate the ``X-Admin-Api-Key`` header against the configured admin key.

    Args:
        x_admin_api_key: Value of the ``X-Admin-Api-Key`` request header.

    Returns:
        ``True`` when the supplied key matches ``config.KORTIX_ADMIN_API_KEY``.

    Raises:
        HTTPException: 500 if no admin key is configured, 401 if the header is
            missing, 403 if the key does not match.
    """
    if not config.KORTIX_ADMIN_API_KEY:
        raise HTTPException(
            status_code=500,
            detail="Admin API key not configured on server"
        )

    if not x_admin_api_key:
        raise HTTPException(
            status_code=401,
            detail="Admin API key required. Include X-Admin-Api-Key header."
        )

    # Constant-time comparison so response timing does not leak how much of the
    # key matched. Compare bytes: compare_digest rejects non-ASCII str input.
    provided_key = x_admin_api_key.encode("utf-8")
    expected_key = str(config.KORTIX_ADMIN_API_KEY).encode("utf-8")
    if not hmac.compare_digest(provided_key, expected_key):
        raise HTTPException(
            status_code=403,
            detail="Invalid admin API key"
        )

    return True


def _decode_jwt_safely(token: str) -> dict:
    """Decode a JWT and return its payload, checking expiry only.

    NOTE(review): the signature, audience and issuer are NOT verified here
    (``verify_signature`` is ``False``), so every caller trusts the ``sub``
    claim of an unverified token. Presumably the token is validated elsewhere
    before it reaches this service; confirm.

    Args:
        token: The encoded JWT.

    Returns:
        The decoded payload.

    Raises:
        PyJWTError: If the token is malformed or expired.
    """
    return jwt.decode(
        token,
        options={
            "verify_signature": False,
            "verify_exp": True,
            "verify_aud": False,
            "verify_iss": False
        }
    )


async def get_account_id_from_thread(thread_id: str, db: "DBConnection") -> str:
    """
    Get account_id from thread_id.

    Args:
        thread_id: ID of the thread to look up.
        db: Database connection whose ``client`` is used for the query.

    Returns:
        The ``account_id`` stored on the thread row.

    Raises:
        ValueError: If thread not found or has no account_id
        Exception: Any other error from the lookup is logged and re-raised.
    """
    try:
        client = await db.client
        thread_result = (
            await client.table('threads')
            .select('account_id')
            .eq('thread_id', thread_id)
            .limit(1)
            .execute()
        )

        if not thread_result.data:
            raise ValueError(f"Could not find thread with ID: {thread_id}")

        account_id = thread_result.data[0]['account_id']
        if not account_id:
            raise ValueError("Thread has no associated account_id")

        return account_id
    except Exception as e:
        structlog.get_logger().error(f"Error getting account_id from thread: {e}")
        raise


async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
    """Resolve an account's primary owner user ID, using Redis as a cache.

    The Redis lookup and the cache write are both best-effort: failures are
    logged as warnings and the database is used instead.

    Args:
        account_id: The account ID to resolve.

    Returns:
        The ``primary_owner_user_id`` of the account, or ``None`` when the
        account does not exist or the database lookup fails.
    """
    cache_key = f"account_user:{account_id}"

    # Stays None when the Redis client cannot be obtained, so the cache write
    # below is skipped instead of failing on an unbound name.
    redis_client = None
    try:
        redis_client = await redis.get_client()
        cached_user_id = await redis_client.get(cache_key)
        if cached_user_id:
            if isinstance(cached_user_id, bytes):
                return cached_user_id.decode('utf-8')
            return cached_user_id
    except Exception as e:
        structlog.get_logger().warning(f"Redis cache lookup failed for account {account_id}: {e}")

    try:
        db = DBConnection()
        await db.initialize()
        client = await db.client

        user_result = (
            await client.schema('basejump')
            .table('accounts')
            .select('primary_owner_user_id')
            .eq('id', account_id)
            .limit(1)
            .execute()
        )

        if not user_result.data:
            return None

        user_id = user_result.data[0]['primary_owner_user_id']

        if redis_client is not None:
            try:
                await redis_client.setex(cache_key, _ACCOUNT_USER_CACHE_TTL_SECONDS, user_id)
            except Exception as e:
                structlog.get_logger().warning(f"Failed to cache user lookup: {e}")

        return user_id
    except Exception as e:
        structlog.get_logger().error(f"Database lookup failed for account {account_id}: {e}")
        return None


async def _get_user_id_from_api_key(x_api_key: str) -> str:
    """Authenticate an ``x-api-key`` header value and return the user ID.

    Args:
        x_api_key: Header value in the form ``pk_xxx:sk_xxx``.

    Returns:
        The user ID of the primary owner of the key's account.

    Raises:
        HTTPException: 401 when the format is wrong, the key is invalid, the
            account cannot be resolved, or validation fails unexpectedly.
    """
    try:
        if ':' not in x_api_key:
            raise HTTPException(
                status_code=401,
                detail="Invalid API key format. Expected format: pk_xxx:sk_xxx",
                headers={"WWW-Authenticate": "Bearer"}
            )

        public_key, secret_key = x_api_key.split(':', 1)

        from core.services.api_keys import APIKeyService
        db = DBConnection()
        await db.initialize()
        api_key_service = APIKeyService(db)

        validation_result = await api_key_service.validate_api_key(public_key, secret_key)
        if not validation_result.is_valid:
            raise HTTPException(
                status_code=401,
                detail=f"Invalid API key: {validation_result.error_message}",
                headers={"WWW-Authenticate": "Bearer"}
            )

        user_id = await _get_user_id_from_account_cached(str(validation_result.account_id))
        if not user_id:
            raise HTTPException(
                status_code=401,
                detail="API key account not found",
                headers={"WWW-Authenticate": "Bearer"}
            )

        _bind_user_context(
            user_id,
            auth_method="api_key",
            api_key_id=str(validation_result.key_id),
            public_key=public_key
        )
        return user_id
    except HTTPException:
        raise
    except Exception as e:
        structlog.get_logger().error(f"Error validating API key: {e}")
        raise HTTPException(
            status_code=401,
            detail="API key validation failed",
            headers={"WWW-Authenticate": "Bearer"}
        )


async def verify_and_get_user_id_from_jwt(request: Request) -> str:
    """Identify the caller from an ``x-api-key`` header or a Bearer JWT.

    An ``x-api-key`` header, when present, takes precedence and the
    ``Authorization`` header is not consulted at all.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The authenticated user ID.

    Raises:
        HTTPException: 401 when no usable credentials are supplied or they
            are invalid.
    """
    x_api_key = request.headers.get('x-api-key')
    if x_api_key:
        return await _get_user_id_from_api_key(x_api_key)

    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        raise HTTPException(
            status_code=401,
            detail="No valid authentication credentials found",
            headers={"WWW-Authenticate": "Bearer"}
        )

    token = auth_header.split(' ')[1]

    try:
        payload = _decode_jwt_safely(token)
    except PyJWTError:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"}
        )

    user_id = payload.get('sub')
    if not user_id:
        raise HTTPException(
            status_code=401,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"}
        )

    _bind_user_context(user_id, auth_method="jwt")
    return user_id


async def get_user_id_from_stream_auth(
    request: Request,
    token: Optional[str] = None
) -> str:
    """Authenticate a streaming request via headers or a fallback token.

    Header-based authentication is tried first. If it fails and ``token`` is
    given, the ``sub`` claim of that token is used instead.

    Args:
        request: The incoming FastAPI request.
        token: Optional JWT supplied outside the headers.

    Returns:
        The authenticated user ID.

    Raises:
        HTTPException: 401 when neither method yields a user ID, 503 when the
            failure looks like a server shutdown, 500 for other errors.
    """
    try:
        try:
            return await verify_and_get_user_id_from_jwt(request)
        except HTTPException:
            # Header auth failed; fall through to the token parameter.
            pass

        if token:
            try:
                payload = _decode_jwt_safely(token)
                user_id = payload.get('sub')
                if user_id:
                    _bind_user_context(user_id, auth_method="jwt_query")
                    return user_id
            except Exception as e:
                # Best-effort fallback: an unusable token ends in the 401 below.
                structlog.get_logger().debug(f"Stream token authentication failed: {e}")

        raise HTTPException(
            status_code=401,
            detail="No valid authentication credentials found",
            headers={"WWW-Authenticate": "Bearer"}
        )
    except HTTPException:
        raise
    except Exception as e:
        raise _unexpected_error_to_http(e, "Error during authentication")


async def get_optional_user_id(request: Request) -> Optional[str]:
    """Return the user ID from a Bearer JWT if one is present and decodable.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The token's ``sub`` claim, or ``None`` when there is no Bearer header,
        the token cannot be decoded, or the claim is absent.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None

    token = auth_header.split(' ')[1]

    try:
        payload = _decode_jwt_safely(token)
    except PyJWTError:
        return None

    user_id = payload.get('sub')
    if user_id:
        _bind_user_context(user_id)
    return user_id


get_optional_current_user_id_from_jwt = get_optional_user_id


async def verify_and_get_agent_authorization(client, agent_id: str, user_id: str) -> dict:
    """Fetch an agent row, requiring that it belongs to ``user_id``.

    Args:
        client: The Supabase client.
        agent_id: ID of the agent to fetch.
        user_id: ID matched against the agent's ``account_id`` column.

    Returns:
        The first matching row from the ``agents`` table.

    Raises:
        HTTPException: 404 if no matching agent exists, 500 if the query fails.
    """
    try:
        agent_result = (
            await client.table('agents')
            .select('*')
            .eq('agent_id', agent_id)
            .eq('account_id', user_id)
            .execute()
        )

        if not agent_result.data:
            raise HTTPException(status_code=404, detail="Agent not found or access denied")

        return agent_result.data[0]
    except HTTPException:
        raise
    except Exception as e:
        # Use a logger instance, as elsewhere in this file: the structlog
        # module itself has no ``error`` function.
        structlog.get_logger().error(
            f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}"
        )
        raise HTTPException(status_code=500, detail="Failed to verify agent access")


async def verify_and_authorize_thread_access(client, thread_id: str, user_id: str) -> bool:
    """Check that ``user_id`` may access the thread ``thread_id``.

    Access is granted when any of the following holds, checked in order:
    the user has role ``admin`` or ``super_admin``; the thread's
    ``account_id`` equals ``user_id``; the thread's project is public; the
    user has a row in ``basejump.account_user`` for the thread's account.

    Args:
        client: The Supabase client.
        thread_id: ID of the thread being accessed.
        user_id: ID of the authenticated user.

    Returns:
        ``True`` when access is granted.

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if access is
            denied, 503 if the failure looks like a server shutdown, 500 for
            other errors.
    """
    try:
        # Check if user is an admin first (admins have access to all threads)
        admin_result = (
            await client.table('user_roles')
            .select('role')
            .eq('user_id', user_id)
            .execute()
        )
        if admin_result.data:
            role = admin_result.data[0].get('role')
            if role in ('admin', 'super_admin'):
                structlog.get_logger().debug(
                    f"Admin access granted for thread {thread_id}", user_role=role
                )
                # Just verify thread exists
                thread_check = (
                    await client.table('threads')
                    .select('thread_id')
                    .eq('thread_id', thread_id)
                    .execute()
                )
                if not thread_check.data:
                    raise HTTPException(status_code=404, detail="Thread not found")
                return True

        thread_result = (
            await client.table('threads')
            .select('*')
            .eq('thread_id', thread_id)
            .execute()
        )
        if not thread_result.data:
            raise HTTPException(status_code=404, detail="Thread not found")

        thread_data = thread_result.data[0]

        if thread_data['account_id'] == user_id:
            return True

        project_id = thread_data.get('project_id')
        if project_id:
            project_result = (
                await client.table('projects')
                .select('is_public')
                .eq('project_id', project_id)
                .execute()
            )
            if project_result.data and project_result.data[0].get('is_public'):
                return True

        account_id = thread_data.get('account_id')
        if account_id:
            account_user_result = (
                await client.schema('basejump')
                .from_('account_user')
                .select('account_role')
                .eq('user_id', user_id)
                .eq('account_id', account_id)
                .execute()
            )
            if account_user_result.data:
                return True

        raise HTTPException(status_code=403, detail="Not authorized to access this thread")
    except HTTPException:
        raise
    except Exception as e:
        raise _unexpected_error_to_http(e, "Error verifying thread access")


async def get_authorized_user_for_thread(
    thread_id: str,
    request: Request
) -> str:
    """
    FastAPI dependency that verifies JWT and authorizes thread access.

    Args:
        thread_id: The thread ID to authorize access for
        request: The FastAPI request object

    Returns:
        str: The authenticated and authorized user ID

    Raises:
        HTTPException: If authentication fails or user lacks thread access
    """
    from core.services.supabase import DBConnection

    # First, authenticate the user
    user_id = await verify_and_get_user_id_from_jwt(request)

    # Then, authorize thread access
    # NOTE(review): db.initialize() is not awaited here, unlike other call
    # sites in this file - presumably the connection is already initialized
    # at application startup; confirm.
    db = DBConnection()
    client = await db.client
    await verify_and_authorize_thread_access(client, thread_id, user_id)

    return user_id


async def get_authorized_user_for_agent(
    agent_id: str,
    request: Request
) -> tuple[str, dict]:
    """
    FastAPI dependency that verifies JWT and authorizes agent access.

    Args:
        agent_id: The agent ID to authorize access for
        request: The FastAPI request object

    Returns:
        tuple[str, dict]: The authenticated user ID and agent data

    Raises:
        HTTPException: If authentication fails or user lacks agent access
    """
    from core.services.supabase import DBConnection

    # First, authenticate the user
    user_id = await verify_and_get_user_id_from_jwt(request)

    # Then, authorize agent access and get agent data
    # NOTE(review): db.initialize() is not awaited here, unlike other call
    # sites in this file - presumably the connection is already initialized
    # at application startup; confirm.
    db = DBConnection()
    client = await db.client
    agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)

    return user_id, agent_data


class AuthorizedThreadAccess:
    """
    Result of a successful authentication plus thread authorization.

    Instances are produced by the ``require_thread_access`` dependency.

    Usage:
        @router.get("/threads/{thread_id}/messages")
        async def get_messages(
            thread_id: str,
            auth: AuthorizedThreadAccess = Depends(require_thread_access)
        ):
            user_id = auth.user_id  # Authenticated and authorized user
    """

    def __init__(self, user_id: str):
        self.user_id = user_id


class AuthorizedAgentAccess:
    """
    Result of a successful authentication plus agent authorization.

    Instances are produced by the ``require_agent_access`` dependency.

    Usage:
        @router.get("/agents/{agent_id}/config")
        async def get_agent_config(
            agent_id: str,
            auth: AuthorizedAgentAccess = Depends(require_agent_access)
        ):
            user_id = auth.user_id  # Authenticated and authorized user
            agent_data = auth.agent_data  # Agent data from authorization check
    """

    def __init__(self, user_id: str, agent_data: dict):
        self.user_id = user_id
        self.agent_data = agent_data


async def require_thread_access(
    thread_id: str,
    request: Request
) -> AuthorizedThreadAccess:
    """
    FastAPI dependency that verifies JWT and authorizes thread access.

    Args:
        thread_id: The thread ID from the path parameter
        request: The FastAPI request object

    Returns:
        AuthorizedThreadAccess: Object containing authenticated user_id

    Raises:
        HTTPException: If authentication fails or user lacks thread access
    """
    user_id = await get_authorized_user_for_thread(thread_id, request)
    return AuthorizedThreadAccess(user_id)


async def require_agent_access(
    agent_id: str,
    request: Request
) -> AuthorizedAgentAccess:
    """
    FastAPI dependency that verifies JWT and authorizes agent access.

    Args:
        agent_id: The agent ID from the path parameter
        request: The FastAPI request object

    Returns:
        AuthorizedAgentAccess: Object containing user_id and agent_data

    Raises:
        HTTPException: If authentication fails or user lacks agent access
    """
    user_id, agent_data = await get_authorized_user_for_agent(agent_id, request)
    return AuthorizedAgentAccess(user_id, agent_data)


# ============================================================================
# Sandbox Authorization Functions
# ============================================================================

async def _get_project_for_sandbox(client, sandbox_id: str) -> dict:
    """Return the project row whose ``sandbox->>id`` equals ``sandbox_id``.

    Raises:
        HTTPException: 404 if no project owns the sandbox.
    """
    project_result = (
        await client.table('projects')
        .select('*')
        .filter('sandbox->>id', 'eq', sandbox_id)
        .execute()
    )

    if not project_result.data:
        raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox")

    return project_result.data[0]


async def _authorize_private_sandbox_member(
    client,
    project_data: dict,
    sandbox_id: str,
    user_id: Optional[str]
) -> dict:
    """Require that ``user_id`` is a member of the private project's account.

    Returns:
        ``project_data`` when the user has a row in ``basejump.account_user``
        for the project's account.

    Raises:
        HTTPException: 500 if the project has no ``account_id``, 403 if the
            user is not a member of that account.
    """
    project_id = project_data.get('project_id')

    account_id = project_data.get('account_id')
    if not account_id:
        raise HTTPException(status_code=500, detail="Project has no associated account")

    # Check if user is a member of the project's account
    account_user_result = (
        await client.schema('basejump')
        .from_('account_user')
        .select('account_role')
        .eq('user_id', user_id)
        .eq('account_id', account_id)
        .execute()
    )

    if account_user_result.data:
        user_role = account_user_result.data[0].get('account_role')
        structlog.get_logger().debug(
            "User has access to private project sandbox",
            project_id=project_id,
            user_role=user_role
        )
        return project_data

    structlog.get_logger().warning(
        "User denied access to private project sandbox",
        sandbox_id=sandbox_id,
        project_id=project_id,
        user_id=user_id,
        account_id=account_id
    )
    raise HTTPException(status_code=403, detail="Not authorized to access this project's sandbox")


async def verify_sandbox_access(client, sandbox_id: str, user_id: str) -> dict:
    """
    Verify that a user has access to a specific sandbox by checking project ownership and permissions.

    This function implements project-based access control:
    - Public projects: Allow access to anyone
    - Private projects: Only allow access to account members

    Args:
        client: The Supabase client
        sandbox_id: The sandbox ID to check access for
        user_id: The user ID to check permissions for (required for all operations)

    Returns:
        dict: Project data containing sandbox information

    Raises:
        HTTPException: If the user doesn't have access to the project/sandbox or sandbox doesn't exist
    """
    # Find the project that owns this sandbox
    project_data = await _get_project_for_sandbox(client, sandbox_id)
    project_id = project_data.get('project_id')
    is_public = project_data.get('is_public', False)

    structlog.get_logger().debug(
        "Checking sandbox access via project ownership",
        sandbox_id=sandbox_id,
        project_id=project_id,
        is_public=is_public,
        user_id=user_id
    )

    # Public projects: Allow access regardless of authentication
    if is_public:
        structlog.get_logger().debug("Allowing access to public project sandbox", project_id=project_id)
        return project_data

    # Private projects: Verify the user is a member of the project's account
    return await _authorize_private_sandbox_member(client, project_data, sandbox_id, user_id)


async def verify_sandbox_access_optional(client, sandbox_id: str, user_id: Optional[str] = None) -> dict:
    """
    Verify that a user has access to a specific sandbox by checking project ownership and permissions.
    This function supports optional authentication for read-only operations.

    This function implements project-based access control:
    - Public projects: Allow access to anyone (no authentication required)
    - Private projects: Require authentication and account membership

    Args:
        client: The Supabase client
        sandbox_id: The sandbox ID to check access for
        user_id: The user ID to check permissions for. Can be None for public project access.

    Returns:
        dict: Project data containing sandbox information

    Raises:
        HTTPException: If the user doesn't have access to the project/sandbox or sandbox doesn't exist
    """
    # Find the project that owns this sandbox
    project_data = await _get_project_for_sandbox(client, sandbox_id)
    project_id = project_data.get('project_id')
    is_public = project_data.get('is_public', False)

    structlog.get_logger().debug(
        "Checking optional sandbox access via project ownership",
        sandbox_id=sandbox_id,
        project_id=project_id,
        is_public=is_public,
        user_id=user_id
    )

    # Public projects: Allow access regardless of authentication
    if is_public:
        structlog.get_logger().debug("Allowing access to public project sandbox", project_id=project_id)
        return project_data

    # Private projects: Require authentication
    if not user_id:
        structlog.get_logger().warning(
            "Authentication required for private project sandbox access",
            project_id=project_id,
            sandbox_id=sandbox_id
        )
        raise HTTPException(status_code=401, detail="Authentication required for this private project")

    # Verify the user is a member of the project's account
    return await _authorize_private_sandbox_member(client, project_data, sandbox_id, user_id)