mirror of https://github.com/kortix-ai/suna.git
skip user id if project is public
This commit is contained in:
parent
7ce3d6f9cb
commit
fd7b0d1484
|
@ -6,7 +6,7 @@ from fastapi.responses import Response, JSONResponse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth
|
||||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id
|
||||
from sandbox.sandbox import get_or_start_sandbox
|
||||
from services.supabase import DBConnection
|
||||
|
||||
|
@ -31,14 +31,14 @@ class FileInfo(BaseModel):
|
|||
mod_time: str
|
||||
permissions: Optional[str] = None
|
||||
|
||||
async def verify_sandbox_access(client, sandbox_id: str, user_id: str):
|
||||
async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str] = None):
|
||||
"""
|
||||
Verify that a user has access to a specific sandbox based on account membership.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
sandbox_id: The sandbox ID to check access for
|
||||
user_id: The user ID to check permissions for
|
||||
user_id: The user ID to check permissions for. Can be None for public resource access.
|
||||
|
||||
Returns:
|
||||
dict: Project data containing sandbox information
|
||||
|
@ -57,6 +57,10 @@ async def verify_sandbox_access(client, sandbox_id: str, user_id: str):
|
|||
if project_data.get('is_public'):
|
||||
return project_data
|
||||
|
||||
# For private projects, we must have a user_id
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required for this resource")
|
||||
|
||||
account_id = project_data.get('account_id')
|
||||
|
||||
# Verify account membership
|
||||
|
@ -72,7 +76,8 @@ async def create_file(
|
|||
sandbox_id: str,
|
||||
path: str = Form(...),
|
||||
file: UploadFile = File(...),
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
request: Request = None,
|
||||
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||
):
|
||||
"""Create a file in the sandbox using direct file upload"""
|
||||
client = await db.client
|
||||
|
@ -101,7 +106,8 @@ async def create_file(
|
|||
async def create_file_json(
|
||||
sandbox_id: str,
|
||||
file_request: dict,
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
request: Request = None,
|
||||
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||
):
|
||||
"""Create a file in the sandbox using JSON (legacy support)"""
|
||||
client = await db.client
|
||||
|
@ -137,7 +143,8 @@ async def create_file_json(
|
|||
async def list_files(
|
||||
sandbox_id: str,
|
||||
path: str,
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
request: Request = None,
|
||||
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||
):
|
||||
"""List files and directories at the specified path"""
|
||||
client = await db.client
|
||||
|
@ -176,7 +183,8 @@ async def list_files(
|
|||
async def read_file(
|
||||
sandbox_id: str,
|
||||
path: str,
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
request: Request = None,
|
||||
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||
):
|
||||
"""Read a file from the sandbox"""
|
||||
client = await db.client
|
||||
|
@ -205,7 +213,8 @@ async def read_file(
|
|||
@router.post("/project/{project_id}/sandbox/ensure-active")
|
||||
async def ensure_project_sandbox_active(
|
||||
project_id: str,
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
request: Request = None,
|
||||
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||
):
|
||||
"""
|
||||
Ensure that a project's sandbox is active and running.
|
||||
|
@ -220,13 +229,20 @@ async def ensure_project_sandbox_active(
|
|||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
project_data = project_result.data[0]
|
||||
account_id = project_data.get('account_id')
|
||||
|
||||
# Verify account membership
|
||||
if account_id:
|
||||
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
||||
if not (account_user_result.data and len(account_user_result.data) > 0):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this project")
|
||||
# For public projects, no authentication is needed
|
||||
if not project_data.get('is_public'):
|
||||
# For private projects, we must have a user_id
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Authentication required for this resource")
|
||||
|
||||
account_id = project_data.get('account_id')
|
||||
|
||||
# Verify account membership
|
||||
if account_id:
|
||||
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
||||
if not (account_user_result.data and len(account_user_result.data) > 0):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this project")
|
||||
|
||||
# Check if project has a sandbox
|
||||
sandbox_id = project_data.get('sandbox', {}).get('id')
|
||||
|
|
|
@ -140,3 +140,35 @@ async def verify_thread_access(client, thread_id: str, user_id: str):
|
|||
if account_user_result.data and len(account_user_result.data) > 0:
|
||||
return True
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this thread")
|
||||
|
||||
async def get_optional_user_id(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract the user ID from the JWT in the Authorization header if present,
|
||||
but don't require authentication. Returns None if no valid token is found.
|
||||
|
||||
This function is used for endpoints that support both authenticated and
|
||||
unauthenticated access (like public projects).
|
||||
|
||||
Args:
|
||||
request: The FastAPI request object
|
||||
|
||||
Returns:
|
||||
Optional[str]: The user ID extracted from the JWT, or None if no valid token
|
||||
"""
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return None
|
||||
|
||||
token = auth_header.split(' ')[1]
|
||||
|
||||
try:
|
||||
# For Supabase JWT, we just need to decode and extract the user ID
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
# Supabase stores the user ID in the 'sub' claim
|
||||
user_id = payload.get('sub')
|
||||
|
||||
return user_id
|
||||
except PyJWTError:
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue