mirror of https://github.com/kortix-ai/suna.git
skip user id if project is public
This commit is contained in:
parent
7ce3d6f9cb
commit
fd7b0d1484
|
@ -6,7 +6,7 @@ from fastapi.responses import Response, JSONResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth
|
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id
|
||||||
from sandbox.sandbox import get_or_start_sandbox
|
from sandbox.sandbox import get_or_start_sandbox
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
|
|
||||||
|
@ -31,14 +31,14 @@ class FileInfo(BaseModel):
|
||||||
mod_time: str
|
mod_time: str
|
||||||
permissions: Optional[str] = None
|
permissions: Optional[str] = None
|
||||||
|
|
||||||
async def verify_sandbox_access(client, sandbox_id: str, user_id: str):
|
async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Verify that a user has access to a specific sandbox based on account membership.
|
Verify that a user has access to a specific sandbox based on account membership.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: The Supabase client
|
client: The Supabase client
|
||||||
sandbox_id: The sandbox ID to check access for
|
sandbox_id: The sandbox ID to check access for
|
||||||
user_id: The user ID to check permissions for
|
user_id: The user ID to check permissions for. Can be None for public resource access.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Project data containing sandbox information
|
dict: Project data containing sandbox information
|
||||||
|
@ -57,6 +57,10 @@ async def verify_sandbox_access(client, sandbox_id: str, user_id: str):
|
||||||
if project_data.get('is_public'):
|
if project_data.get('is_public'):
|
||||||
return project_data
|
return project_data
|
||||||
|
|
||||||
|
# For private projects, we must have a user_id
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required for this resource")
|
||||||
|
|
||||||
account_id = project_data.get('account_id')
|
account_id = project_data.get('account_id')
|
||||||
|
|
||||||
# Verify account membership
|
# Verify account membership
|
||||||
|
@ -72,7 +76,8 @@ async def create_file(
|
||||||
sandbox_id: str,
|
sandbox_id: str,
|
||||||
path: str = Form(...),
|
path: str = Form(...),
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
user_id: str = Depends(get_current_user_id)
|
request: Request = None,
|
||||||
|
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||||
):
|
):
|
||||||
"""Create a file in the sandbox using direct file upload"""
|
"""Create a file in the sandbox using direct file upload"""
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
@ -101,7 +106,8 @@ async def create_file(
|
||||||
async def create_file_json(
|
async def create_file_json(
|
||||||
sandbox_id: str,
|
sandbox_id: str,
|
||||||
file_request: dict,
|
file_request: dict,
|
||||||
user_id: str = Depends(get_current_user_id)
|
request: Request = None,
|
||||||
|
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||||
):
|
):
|
||||||
"""Create a file in the sandbox using JSON (legacy support)"""
|
"""Create a file in the sandbox using JSON (legacy support)"""
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
@ -137,7 +143,8 @@ async def create_file_json(
|
||||||
async def list_files(
|
async def list_files(
|
||||||
sandbox_id: str,
|
sandbox_id: str,
|
||||||
path: str,
|
path: str,
|
||||||
user_id: str = Depends(get_current_user_id)
|
request: Request = None,
|
||||||
|
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||||
):
|
):
|
||||||
"""List files and directories at the specified path"""
|
"""List files and directories at the specified path"""
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
@ -176,7 +183,8 @@ async def list_files(
|
||||||
async def read_file(
|
async def read_file(
|
||||||
sandbox_id: str,
|
sandbox_id: str,
|
||||||
path: str,
|
path: str,
|
||||||
user_id: str = Depends(get_current_user_id)
|
request: Request = None,
|
||||||
|
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||||
):
|
):
|
||||||
"""Read a file from the sandbox"""
|
"""Read a file from the sandbox"""
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
@ -205,7 +213,8 @@ async def read_file(
|
||||||
@router.post("/project/{project_id}/sandbox/ensure-active")
|
@router.post("/project/{project_id}/sandbox/ensure-active")
|
||||||
async def ensure_project_sandbox_active(
|
async def ensure_project_sandbox_active(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
user_id: str = Depends(get_current_user_id)
|
request: Request = None,
|
||||||
|
user_id: Optional[str] = Depends(get_optional_user_id)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Ensure that a project's sandbox is active and running.
|
Ensure that a project's sandbox is active and running.
|
||||||
|
@ -220,13 +229,20 @@ async def ensure_project_sandbox_active(
|
||||||
raise HTTPException(status_code=404, detail="Project not found")
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
project_data = project_result.data[0]
|
project_data = project_result.data[0]
|
||||||
account_id = project_data.get('account_id')
|
|
||||||
|
|
||||||
# Verify account membership
|
# For public projects, no authentication is needed
|
||||||
if account_id:
|
if not project_data.get('is_public'):
|
||||||
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
# For private projects, we must have a user_id
|
||||||
if not (account_user_result.data and len(account_user_result.data) > 0):
|
if not user_id:
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to access this project")
|
raise HTTPException(status_code=401, detail="Authentication required for this resource")
|
||||||
|
|
||||||
|
account_id = project_data.get('account_id')
|
||||||
|
|
||||||
|
# Verify account membership
|
||||||
|
if account_id:
|
||||||
|
account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
|
||||||
|
if not (account_user_result.data and len(account_user_result.data) > 0):
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized to access this project")
|
||||||
|
|
||||||
# Check if project has a sandbox
|
# Check if project has a sandbox
|
||||||
sandbox_id = project_data.get('sandbox', {}).get('id')
|
sandbox_id = project_data.get('sandbox', {}).get('id')
|
||||||
|
|
|
@ -140,3 +140,35 @@ async def verify_thread_access(client, thread_id: str, user_id: str):
|
||||||
if account_user_result.data and len(account_user_result.data) > 0:
|
if account_user_result.data and len(account_user_result.data) > 0:
|
||||||
return True
|
return True
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to access this thread")
|
raise HTTPException(status_code=403, detail="Not authorized to access this thread")
|
||||||
|
|
||||||
|
async def get_optional_user_id(request: Request) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract the user ID from the JWT in the Authorization header if present,
|
||||||
|
but don't require authentication. Returns None if no valid token is found.
|
||||||
|
|
||||||
|
This function is used for endpoints that support both authenticated and
|
||||||
|
unauthenticated access (like public projects).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The user ID extracted from the JWT, or None if no valid token
|
||||||
|
"""
|
||||||
|
auth_header = request.headers.get('Authorization')
|
||||||
|
|
||||||
|
if not auth_header or not auth_header.startswith('Bearer '):
|
||||||
|
return None
|
||||||
|
|
||||||
|
token = auth_header.split(' ')[1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# For Supabase JWT, we just need to decode and extract the user ID
|
||||||
|
payload = jwt.decode(token, options={"verify_signature": False})
|
||||||
|
|
||||||
|
# Supabase stores the user ID in the 'sub' claim
|
||||||
|
user_id = payload.get('sub')
|
||||||
|
|
||||||
|
return user_id
|
||||||
|
except PyJWTError:
|
||||||
|
return None
|
||||||
|
|
Loading…
Reference in New Issue