From fd7b0d148449365d8ef59f89273cca14312d51ac Mon Sep 17 00:00:00 2001 From: Adam Cohen Hillel Date: Mon, 21 Apr 2025 15:08:28 +0100 Subject: [PATCH] skip user id if project is public --- backend/sandbox/api.py | 44 +++++++++++++++++++++++++------------ backend/utils/auth_utils.py | 32 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 14 deletions(-) diff --git a/backend/sandbox/api.py b/backend/sandbox/api.py index 390d44c9..c8f5479b 100644 --- a/backend/sandbox/api.py +++ b/backend/sandbox/api.py @@ -6,7 +6,7 @@ from fastapi.responses import Response, JSONResponse from pydantic import BaseModel from utils.logger import logger -from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth +from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id from sandbox.sandbox import get_or_start_sandbox from services.supabase import DBConnection @@ -31,14 +31,14 @@ class FileInfo(BaseModel): mod_time: str permissions: Optional[str] = None -async def verify_sandbox_access(client, sandbox_id: str, user_id: str): +async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str] = None): """ Verify that a user has access to a specific sandbox based on account membership. Args: client: The Supabase client sandbox_id: The sandbox ID to check access for - user_id: The user ID to check permissions for + user_id: The user ID to check permissions for. Can be None for public resource access. Returns: dict: Project data containing sandbox information @@ -57,6 +57,10 @@ async def verify_sandbox_access(client, sandbox_id: str, user_id: str): if project_data.get('is_public'): return project_data + # For private projects, we must have a user_id + if not user_id: + raise HTTPException(status_code=401, detail="Authentication required for this resource") + account_id = project_data.get('account_id') # Verify account membership @@ -72,7 +76,8 @@ async def create_file( sandbox_id: str, path: str = Form(...), file: UploadFile = File(...), - user_id: str = Depends(get_current_user_id) + request: Request = None, + user_id: Optional[str] = Depends(get_optional_user_id) ): """Create a file in the sandbox using direct file upload""" client = await db.client @@ -101,7 +106,8 @@ async def create_file( async def create_file_json( sandbox_id: str, file_request: dict, - user_id: str = Depends(get_current_user_id) + request: Request = None, + user_id: Optional[str] = Depends(get_optional_user_id) ): """Create a file in the sandbox using JSON (legacy support)""" client = await db.client @@ -137,7 +143,8 @@ async def create_file_json( async def list_files( sandbox_id: str, path: str, - user_id: str = Depends(get_current_user_id) + request: Request = None, + user_id: Optional[str] = Depends(get_optional_user_id) ): """List files and directories at the specified path""" client = await db.client @@ -176,7 +183,8 @@ async def list_files( async def read_file( sandbox_id: str, path: str, - user_id: str = Depends(get_current_user_id) + request: Request = None, + user_id: Optional[str] = Depends(get_optional_user_id) ): """Read a file from the sandbox""" client = await db.client @@ -205,7 +213,8 @@ async def read_file( @router.post("/project/{project_id}/sandbox/ensure-active") async def ensure_project_sandbox_active( project_id: str, - user_id: str = Depends(get_current_user_id) + request: Request = None, + user_id: Optional[str] = Depends(get_optional_user_id) ): """ Ensure that a project's sandbox is active and running. @@ -220,13 +229,20 @@ async def ensure_project_sandbox_active( raise HTTPException(status_code=404, detail="Project not found") project_data = project_result.data[0] - account_id = project_data.get('account_id') - # Verify account membership - if account_id: - account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute() - if not (account_user_result.data and len(account_user_result.data) > 0): - raise HTTPException(status_code=403, detail="Not authorized to access this project") + # For public projects, no authentication is needed + if not project_data.get('is_public'): + # For private projects, we must have a user_id + if not user_id: + raise HTTPException(status_code=401, detail="Authentication required for this resource") + + account_id = project_data.get('account_id') + + # Verify account membership + if account_id: + account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute() + if not (account_user_result.data and len(account_user_result.data) > 0): + raise HTTPException(status_code=403, detail="Not authorized to access this project") # Check if project has a sandbox sandbox_id = project_data.get('sandbox', {}).get('id') diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index efd35c32..44162cc6 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -140,3 +140,35 @@ async def verify_thread_access(client, thread_id: str, user_id: str): if account_user_result.data and len(account_user_result.data) > 0: return True raise HTTPException(status_code=403, detail="Not authorized to access this thread") + +async def get_optional_user_id(request: Request) -> Optional[str]: + """ + Extract the user ID from the JWT in the Authorization header if present, + but don't require authentication. Returns None if no valid token is found. + + This function is used for endpoints that support both authenticated and + unauthenticated access (like public projects). + + Args: + request: The FastAPI request object + + Returns: + Optional[str]: The user ID extracted from the JWT, or None if no valid token + """ + auth_header = request.headers.get('Authorization') + + if not auth_header or not auth_header.startswith('Bearer '): + return None + + token = auth_header.split(' ')[1] + + try: + # For Supabase JWT, we just need to decode and extract the user ID + payload = jwt.decode(token, options={"verify_signature": False}) + + # Supabase stores the user ID in the 'sub' claim + user_id = payload.get('sub') + + return user_id + except PyJWTError: + return None