suna/backend/utils/auth_utils.py

129 lines
4.3 KiB
Python

from fastapi import HTTPException, Request, Depends
from typing import Optional
import jwt
from jwt.exceptions import PyJWTError
# Extracts the user ID from the Supabase JWT (the token's signature is not verified here)
async def get_current_user_id(request: Request) -> str:
    """
    Extract the user ID from the JWT in the Authorization header.

    This function is used as a dependency in FastAPI routes to reject
    requests that carry no usable bearer token and to provide the caller's
    user ID for authorization checks.

    SECURITY NOTE: the token is decoded with ``verify_signature`` disabled,
    so neither the signature nor the authenticity of the claims is checked
    here. The returned ID is only as trustworthy as whatever validates the
    same token downstream. The original author's comment attributes that
    validation to Supabase RLS.
    TODO(review): confirm every caller actually forwards this token to
    Supabase rather than trusting the returned ID on its own.

    Args:
        request: The FastAPI request object

    Returns:
        str: The value of the token's 'sub' claim

    Raises:
        HTTPException: 401 if the header is missing, is not a Bearer
            credential, cannot be decoded, or carries no 'sub' claim
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        raise HTTPException(
            status_code=401,
            detail="No valid authentication credentials found",
            headers={"WWW-Authenticate": "Bearer"}
        )

    # split() without a separator collapses runs of whitespace, so a header
    # such as "Bearer  <token>" still yields the token. A bare "Bearer "
    # yields an empty token, which is rejected by the decode step below.
    parts = auth_header.split()
    token = parts[1] if len(parts) > 1 else ''

    # Keep the try body limited to the one call that can raise PyJWTError.
    try:
        # Decode only: signature verification is deliberately disabled
        # (see the security note in the docstring).
        payload = jwt.decode(token, options={"verify_signature": False})
    except PyJWTError as exc:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"}
        ) from exc

    # Supabase stores the user ID in the 'sub' claim
    user_id = payload.get('sub')
    if not user_id:
        raise HTTPException(
            status_code=401,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return user_id
async def get_user_id_from_stream_auth(
    request: Request,
    token: Optional[str] = None
) -> str:
    """
    Extract the user ID from either the query parameter token or the Authorization header.

    This function is specifically designed for streaming endpoints that need
    to support both header-based and query parameter-based authentication
    (for EventSource, which cannot set request headers). The query parameter
    token is tried first; the Authorization header is the fallback.

    SECURITY NOTE: tokens are decoded with ``verify_signature`` disabled, so
    the authenticity of the claims is NOT checked here.
    TODO(review): confirm the token is validated elsewhere before the
    returned ID is trusted.

    Args:
        request: The FastAPI request object
        token: Optional token from query parameters

    Returns:
        str: The value of the 'sub' claim of the first token that decodes
            and carries a non-empty 'sub'

    Raises:
        HTTPException: 401 if neither source yields a decodable token with
            a 'sub' claim
    """
    # Candidate tokens in priority order: query parameter, then header.
    candidates = []
    if token:
        candidates.append(token)

    auth_header = request.headers.get('Authorization')
    if auth_header and auth_header.startswith('Bearer '):
        # split() without a separator tolerates repeated spaces after the
        # scheme; a bare "Bearer " contributes no candidate at all.
        parts = auth_header.split()
        if len(parts) > 1:
            candidates.append(parts[1])

    for candidate in candidates:
        try:
            payload = jwt.decode(candidate, options={"verify_signature": False})
        except PyJWTError:
            # An undecodable token is not fatal here: fall through to the
            # next candidate. Only JWT errors are suppressed, so unrelated
            # failures are no longer silently hidden.
            continue
        user_id = payload.get('sub')
        if user_id:
            return user_id

    # If we still don't have a user_id, return authentication error
    raise HTTPException(
        status_code=401,
        detail="No valid authentication credentials found",
        headers={"WWW-Authenticate": "Bearer"}
    )
async def verify_thread_access(client, thread_id: str, user_id: str) -> bool:
    """
    Verify that a user has access to a specific thread.

    Queries the 'threads' table for a row whose 'thread_id' and 'user_id'
    columns both match the given values.

    Args:
        client: The Supabase client
        thread_id: The thread ID to check access for
        user_id: The user ID to check permissions for

    Returns:
        bool: True if a matching row exists

    Raises:
        HTTPException: 403 if the query returns no matching row
    """
    result = await (
        client.table('threads')
        .select('thread_id')
        .eq('thread_id', thread_id)
        .eq('user_id', user_id)
        .execute()
    )
    # Truthiness covers both a missing (None) and an empty result set.
    if not result.data:
        raise HTTPException(status_code=403, detail="Not authorized to access this thread")
    return True