suna/backend/core/utils/auth_utils.py

613 lines
22 KiB
Python

import sentry
from fastapi import HTTPException, Request, Header
from typing import Optional
import jwt
from jwt.exceptions import PyJWTError
from core.utils.logger import structlog
from core.utils.config import config
import os
import base64
import hashlib
import hmac
from core.services.supabase import DBConnection
from core.services import redis
async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)):
    """Validate the ``X-Admin-Api-Key`` header against the configured admin key.

    Intended for use as a FastAPI dependency on admin-only routes.

    Args:
        x_admin_api_key: Value of the ``X-Admin-Api-Key`` request header, or
            ``None`` when the header is absent.

    Returns:
        True when the supplied key matches ``config.KORTIX_ADMIN_API_KEY``.

    Raises:
        HTTPException: 500 if no admin key is configured on the server,
            401 if the header is missing, 403 if the key does not match.
    """
    expected_key = config.KORTIX_ADMIN_API_KEY
    if not expected_key:
        raise HTTPException(
            status_code=500,
            detail="Admin API key not configured on server"
        )
    if not x_admin_api_key:
        raise HTTPException(
            status_code=401,
            detail="Admin API key required. Include X-Admin-Api-Key header."
        )
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Bytes are compared because
    # hmac.compare_digest raises TypeError for non-ASCII str arguments.
    if not hmac.compare_digest(
        x_admin_api_key.encode("utf-8"),
        expected_key.encode("utf-8"),
    ):
        raise HTTPException(
            status_code=403,
            detail="Invalid admin API key"
        )
    return True
def _decode_jwt_safely(token: str) -> dict:
    """Decode a JWT and return its claims payload.

    Only the ``exp`` claim is validated. The signature, audience and issuer
    are NOT verified.

    SECURITY NOTE(review): because the signature is not checked, any claim in
    the returned payload (including ``sub``) is unauthenticated unless the
    token was verified somewhere else before reaching this service.
    TODO confirm where signature verification happens.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dictionary.

    Raises:
        PyJWTError: if the token cannot be decoded or its ``exp`` has passed.
    """
    decode_options = dict(
        verify_signature=False,
        verify_exp=True,
        verify_aud=False,
        verify_iss=False,
    )
    return jwt.decode(token, options=decode_options)
async def get_account_id_from_thread(thread_id: str, db: "DBConnection") -> str:
    """
    Look up the account that owns a thread.

    Args:
        thread_id: ID of the thread to resolve.
        db: Database connection whose ``client`` is used for the query.

    Returns:
        The thread's ``account_id``.

    Raises:
        ValueError: If no thread matches ``thread_id`` or its account_id is empty.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        client = await db.client
        query = (
            client.table('threads')
            .select('account_id')
            .eq('thread_id', thread_id)
            .limit(1)
        )
        rows = (await query.execute()).data
        if not rows:
            raise ValueError(f"Could not find thread with ID: {thread_id}")
        owner_account = rows[0]['account_id']
        if owner_account:
            return owner_account
        raise ValueError("Thread has no associated account_id")
    except Exception as exc:
        # Log every failure (including the ValueErrors above) before propagating.
        structlog.get_logger().error(f"Error getting account_id from thread: {exc}")
        raise
async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
    """Resolve an account's primary owner user id, with Redis as a read-through cache.

    Lookup order:
      1. Redis key ``account_user:{account_id}``.
      2. ``basejump.accounts.primary_owner_user_id`` in the database. A
         non-empty result is written back to Redis with a 300-second TTL.

    Redis failures are logged as warnings and never abort the lookup.
    Database failures are logged and reported as ``None``.

    Args:
        account_id: The account to resolve.

    Returns:
        The owner's user id, or ``None`` if the account was not found or the
        database lookup failed.
    """
    cache_key = f"account_user:{account_id}"
    logger = structlog.get_logger()
    # Stays None when Redis is unreachable, so the write-back below is skipped
    # instead of touching an unbound name.
    redis_client = None
    try:
        redis_client = await redis.get_client()
        cached_user_id = await redis_client.get(cache_key)
        if cached_user_id:
            return cached_user_id.decode('utf-8') if isinstance(cached_user_id, bytes) else cached_user_id
    except Exception as e:
        logger.warning(f"Redis cache lookup failed for account {account_id}: {e}")
    try:
        db = DBConnection()
        await db.initialize()
        client = await db.client
        user_result = await client.schema('basejump').table('accounts').select(
            'primary_owner_user_id'
        ).eq('id', account_id).limit(1).execute()
        if not user_result.data:
            return None
        user_id = user_result.data[0]['primary_owner_user_id']
        # Only cache real values; a NULL owner must not be written to Redis.
        if user_id and redis_client is not None:
            try:
                await redis_client.setex(cache_key, 300, user_id)
            except Exception as e:
                logger.warning(f"Failed to cache user lookup: {e}")
        return user_id
    except Exception as e:
        logger.error(f"Database lookup failed for account {account_id}: {e}")
        return None
async def _verify_api_key_header(x_api_key: str) -> str:
    """Authenticate an ``x-api-key`` header of the form ``pk_xxx:sk_xxx``.

    Args:
        x_api_key: Raw header value.

    Returns:
        The user id of the primary owner of the key's account.

    Raises:
        HTTPException: 401 if the key is malformed, rejected by
            ``APIKeyService.validate_api_key``, maps to no account owner, or
            if validation fails unexpectedly (the cause is logged).
    """
    try:
        if ':' not in x_api_key:
            raise HTTPException(
                status_code=401,
                detail="Invalid API key format. Expected format: pk_xxx:sk_xxx",
                headers={"WWW-Authenticate": "Bearer"}
            )
        public_key, secret_key = x_api_key.split(':', 1)

        # Kept as a function-local import like the original layout
        # (presumably to avoid an import cycle -- TODO confirm).
        from core.services.api_keys import APIKeyService
        db = DBConnection()
        await db.initialize()
        validation_result = await APIKeyService(db).validate_api_key(public_key, secret_key)
        if not validation_result.is_valid:
            raise HTTPException(
                status_code=401,
                detail=f"Invalid API key: {validation_result.error_message}",
                headers={"WWW-Authenticate": "Bearer"}
            )

        user_id = await _get_user_id_from_account_cached(str(validation_result.account_id))
        if not user_id:
            raise HTTPException(
                status_code=401,
                detail="API key account not found",
                headers={"WWW-Authenticate": "Bearer"}
            )

        sentry.sentry.set_user({"id": user_id})
        structlog.contextvars.bind_contextvars(
            user_id=user_id,
            auth_method="api_key",
            api_key_id=str(validation_result.key_id),
            public_key=public_key
        )
        return user_id
    except HTTPException:
        raise
    except Exception as e:
        structlog.get_logger().error(f"Error validating API key: {e}")
        raise HTTPException(
            status_code=401,
            detail="API key validation failed",
            headers={"WWW-Authenticate": "Bearer"}
        ) from e


async def verify_and_get_user_id_from_jwt(request: Request) -> str:
    """Authenticate a request and return the caller's user id.

    An ``x-api-key`` header takes precedence. When it is present, only
    API-key authentication is attempted. Otherwise an
    ``Authorization: Bearer <jwt>`` header is required, and the user id is
    the token's ``sub`` claim as returned by ``_decode_jwt_safely``.

    Also tags the Sentry user and binds ``user_id``/``auth_method`` into the
    structlog context.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The authenticated user id.

    Raises:
        HTTPException: 401 for missing, malformed or invalid credentials.
    """
    x_api_key = request.headers.get('x-api-key')
    if x_api_key:
        return await _verify_api_key_header(x_api_key)

    # RFC 7235: the auth-scheme name is case-insensitive. Surrounding
    # whitespace around the credentials is tolerated.
    scheme, _, token = (request.headers.get('Authorization') or '').partition(' ')
    token = token.strip()
    if scheme.lower() != 'bearer' or not token:
        raise HTTPException(
            status_code=401,
            detail="No valid authentication credentials found",
            headers={"WWW-Authenticate": "Bearer"}
        )

    try:
        payload = _decode_jwt_safely(token)
    except PyJWTError as e:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"}
        ) from e

    user_id = payload.get('sub')
    if not user_id:
        raise HTTPException(
            status_code=401,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"}
        )
    sentry.sentry.set_user({"id": user_id})
    structlog.contextvars.bind_contextvars(
        user_id=user_id,
        auth_method="jwt"
    )
    return user_id
async def get_user_id_from_stream_auth(
    request: Request,
    token: Optional[str] = None
) -> str:
    """Authenticate a streaming request, with a fallback to an explicit token.

    Header authentication (``x-api-key`` or ``Authorization: Bearer``) via
    ``verify_and_get_user_id_from_jwt`` is tried first. If it fails, ``token``
    is decoded as a JWT and its ``sub`` claim is used. ``token`` is
    presumably a query-string value for clients that cannot set headers;
    TODO confirm with callers.

    Args:
        request: The incoming FastAPI request.
        token: Optional JWT supplied outside the headers.

    Returns:
        The authenticated user id.

    Raises:
        HTTPException: 401 if neither method yields a user id, 503 if the
            backend appears to be shutting down, 500 for other unexpected errors.
    """
    try:
        try:
            return await verify_and_get_user_id_from_jwt(request)
        except HTTPException:
            # Header credentials missing or invalid: fall back to ``token``.
            pass
        if token:
            try:
                payload = _decode_jwt_safely(token)
            except PyJWTError:
                # An undecodable or expired fallback token is just "no credentials".
                payload = {}
            user_id = payload.get('sub')
            if user_id:
                sentry.sentry.set_user({"id": user_id})
                structlog.contextvars.bind_contextvars(
                    user_id=user_id,
                    auth_method="jwt_query"
                )
                return user_id
        raise HTTPException(
            status_code=401,
            detail="No valid authentication credentials found",
            headers={"WWW-Authenticate": "Bearer"}
        )
    except HTTPException:
        raise
    except Exception as e:
        error_msg = str(e)
        if "cannot schedule new futures after shutdown" in error_msg or "connection is closed" in error_msg:
            raise HTTPException(
                status_code=503,
                detail="Server is shutting down"
            ) from e
        structlog.get_logger().error(f"Unexpected error during stream authentication: {error_msg}")
        raise HTTPException(
            status_code=500,
            detail=f"Error during authentication: {error_msg}"
        ) from e
async def get_optional_user_id(request: Request) -> Optional[str]:
    """Return the user id from an ``Authorization: Bearer`` JWT, if one is usable.

    Unlike ``verify_and_get_user_id_from_jwt``, this never raises for missing
    or invalid credentials. It returns ``None`` instead, so endpoints can
    serve anonymous callers. The ``x-api-key`` header is not consulted.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The token's ``sub`` claim, or ``None`` when there is no Bearer header,
        the token is empty, or it fails to decode.
    """
    # RFC 7235: the auth-scheme name is case-insensitive. Surrounding
    # whitespace around the credentials is tolerated.
    scheme, _, token = (request.headers.get('Authorization') or '').partition(' ')
    token = token.strip()
    if scheme.lower() != 'bearer' or not token:
        return None
    try:
        payload = _decode_jwt_safely(token)
    except PyJWTError:
        return None
    user_id = payload.get('sub')
    if user_id:
        sentry.sentry.set_user({"id": user_id})
        structlog.contextvars.bind_contextvars(
            user_id=user_id
        )
    return user_id


# Backwards-compatible alias for existing importers.
get_optional_current_user_id_from_jwt = get_optional_user_id
async def verify_and_get_agent_authorization(client, agent_id: str, user_id: str) -> dict:
    """Fetch an agent row, requiring that it belongs to ``user_id``.

    The query matches the agent's ``account_id`` column against ``user_id``.

    Args:
        client: Database client used for the query.
        agent_id: ID of the agent to fetch.
        user_id: ID of the user requesting access.

    Returns:
        The full agent row (all columns).

    Raises:
        HTTPException: 404 if no matching agent is owned by ``user_id``,
            500 on any other failure.
    """
    try:
        agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
        if not agent_result.data:
            raise HTTPException(status_code=404, detail="Agent not found or access denied")
        return agent_result.data[0]
    except HTTPException:
        raise
    except Exception as e:
        # Log through a logger instance like the rest of this module. The
        # structlog module itself has no ``error`` function, so calling it
        # would raise AttributeError here and mask the intended 500.
        structlog.get_logger().error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to verify agent access") from e
async def verify_and_authorize_thread_access(client, thread_id: str, user_id: str):
    """Authorize ``user_id`` to access ``thread_id``.

    Access is granted (the function returns True) by the first rule that matches:
      1. the user has an ``admin`` or ``super_admin`` row in ``user_roles``
         (the thread then only has to exist);
      2. the thread's ``account_id`` equals ``user_id``;
      3. the thread's project has ``is_public`` set;
      4. the user has a row in ``basejump.account_user`` for the thread's account.

    Args:
        client: Database client used for the queries.
        thread_id: ID of the thread being accessed.
        user_id: ID of the authenticated user.

    Returns:
        True when access is granted.

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if no rule grants
            access, 503 if the backend appears to be shutting down, 500 for
            other unexpected errors.
    """
    try:
        admin_result = await client.table('user_roles').select('role').eq('user_id', user_id).execute()
        # The query may return several role rows; any admin role grants access.
        admin_role = next(
            (
                row.get('role')
                for row in (admin_result.data or [])
                if row.get('role') in ('admin', 'super_admin')
            ),
            None,
        )
        if admin_role is not None:
            structlog.get_logger().debug(f"Admin access granted for thread {thread_id}", user_role=admin_role)
            # Admins bypass ownership checks; just verify the thread exists.
            thread_check = await client.table('threads').select('thread_id').eq('thread_id', thread_id).execute()
            if not thread_check.data:
                raise HTTPException(status_code=404, detail="Thread not found")
            return True

        thread_result = await client.table('threads').select('*').eq('thread_id', thread_id).execute()
        if not thread_result.data:
            raise HTTPException(status_code=404, detail="Thread not found")
        thread_data = thread_result.data[0]

        account_id = thread_data.get('account_id')
        if account_id == user_id:
            return True

        project_id = thread_data.get('project_id')
        if project_id:
            project_result = await client.table('projects').select('is_public').eq('project_id', project_id).execute()
            if project_result.data and project_result.data[0].get('is_public'):
                return True

        if account_id:
            account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
            if account_user_result.data:
                return True

        raise HTTPException(status_code=403, detail="Not authorized to access this thread")
    except HTTPException:
        raise
    except Exception as e:
        error_msg = str(e)
        if "cannot schedule new futures after shutdown" in error_msg or "connection is closed" in error_msg:
            raise HTTPException(
                status_code=503,
                detail="Server is shutting down"
            ) from e
        structlog.get_logger().error(f"Unexpected error verifying access to thread {thread_id}: {error_msg}")
        raise HTTPException(
            status_code=500,
            detail=f"Error verifying thread access: {error_msg}"
        ) from e
async def get_authorized_user_for_thread(
    thread_id: str,
    request: Request
) -> str:
    """
    FastAPI dependency that authenticates the caller and authorizes thread access.

    Args:
        thread_id: The thread ID to authorize access for
        request: The FastAPI request object
    Returns:
        str: The authenticated and authorized user ID
    Raises:
        HTTPException: If authentication fails or user lacks thread access
    """
    # Authenticate before touching the thread, so bad credentials produce a
    # 401 rather than a thread-related 404/403.
    user_id = await verify_and_get_user_id_from_jwt(request)
    db = DBConnection()
    # Initialize before reading ``client``, as the other call sites in this
    # module that construct a DBConnection do.
    await db.initialize()
    client = await db.client
    await verify_and_authorize_thread_access(client, thread_id, user_id)
    return user_id
async def get_authorized_user_for_agent(
    agent_id: str,
    request: Request
) -> tuple[str, dict]:
    """
    FastAPI dependency that authenticates the caller and authorizes agent access.

    Args:
        agent_id: The agent ID to authorize access for
        request: The FastAPI request object
    Returns:
        tuple[str, dict]: The authenticated user ID and the agent row
    Raises:
        HTTPException: If authentication fails or user lacks agent access
    """
    # Authenticate before touching the agent, so bad credentials produce a
    # 401 rather than an agent-related 404.
    user_id = await verify_and_get_user_id_from_jwt(request)
    db = DBConnection()
    # Initialize before reading ``client``, as the other call sites in this
    # module that construct a DBConnection do.
    await db.initialize()
    client = await db.client
    agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
    return user_id, agent_data
class AuthorizedThreadAccess:
    """
    Result of a successful authentication plus thread authorization check.

    Instances are produced by the ``require_thread_access`` dependency. Do not
    use a bare ``Depends()`` on this class: FastAPI would then treat
    ``user_id`` as a request (query) parameter instead of authenticating the
    caller.

    Usage:
        @router.get("/threads/{thread_id}/messages")
        async def get_messages(
            thread_id: str,
            auth: AuthorizedThreadAccess = Depends(require_thread_access)
        ):
            user_id = auth.user_id  # Authenticated and authorized user

    Attributes:
        user_id: ID of the authenticated user who may access the thread.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id

    def __repr__(self) -> str:
        return f"{type(self).__name__}(user_id={self.user_id!r})"
class AuthorizedAgentAccess:
    """
    Result of a successful authentication plus agent authorization check.

    Instances are produced by the ``require_agent_access`` dependency. Do not
    use a bare ``Depends()`` on this class: FastAPI would then treat
    ``user_id``/``agent_data`` as request parameters instead of
    authenticating the caller.

    Usage:
        @router.get("/agents/{agent_id}/config")
        async def get_agent_config(
            agent_id: str,
            auth: AuthorizedAgentAccess = Depends(require_agent_access)
        ):
            user_id = auth.user_id        # Authenticated and authorized user
            agent_data = auth.agent_data  # Agent row from the authorization check

    Attributes:
        user_id: ID of the authenticated user.
        agent_data: The agent row returned by the authorization query.
    """

    def __init__(self, user_id: str, agent_data: dict):
        self.user_id = user_id
        self.agent_data = agent_data

    def __repr__(self) -> str:
        # Show only the agent id so full agent configuration isn't dumped into logs.
        agent_id = self.agent_data.get('agent_id') if isinstance(self.agent_data, dict) else None
        return f"{type(self).__name__}(user_id={self.user_id!r}, agent_id={agent_id!r})"
async def require_thread_access(
    thread_id: str,
    request: Request
) -> AuthorizedThreadAccess:
    """
    Dependency: authenticate the caller and confirm they may access ``thread_id``.

    Args:
        thread_id: Thread ID taken from the path parameter
        request: The incoming FastAPI request
    Returns:
        AuthorizedThreadAccess: Wrapper holding the authorized user_id
    Raises:
        HTTPException: On failed authentication or missing thread access
    """
    return AuthorizedThreadAccess(
        await get_authorized_user_for_thread(thread_id, request)
    )
async def require_agent_access(
    agent_id: str,
    request: Request
) -> AuthorizedAgentAccess:
    """
    Dependency: authenticate the caller and confirm they may access ``agent_id``.

    Args:
        agent_id: Agent ID taken from the path parameter
        request: The incoming FastAPI request
    Returns:
        AuthorizedAgentAccess: Wrapper holding the user_id and the agent row
    Raises:
        HTTPException: On failed authentication or missing agent access
    """
    authorized_user_id, agent_row = await get_authorized_user_for_agent(agent_id, request)
    return AuthorizedAgentAccess(user_id=authorized_user_id, agent_data=agent_row)
# ============================================================================
# Sandbox Authorization Functions
# ============================================================================
async def verify_sandbox_access(client, sandbox_id: str, user_id: str):
    """
    Check that ``user_id`` may use ``sandbox_id``, based on the owning project.

    Access rules:
    - The sandbox belongs to the first project whose ``sandbox->>id`` matches.
    - If that project is public, anyone is allowed.
    - Otherwise the user must be a member of the project's account
      (a row in ``basejump.account_user``).

    Args:
        client: The Supabase client
        sandbox_id: Sandbox whose access is being checked
        user_id: User requesting access (always required here)
    Returns:
        dict: The owning project's row
    Raises:
        HTTPException: 404 if no project owns the sandbox, 500 if the project
            has no account, 403 if the user is not an account member
    """
    log = structlog.get_logger()

    lookup = await client.table('projects').select('*').filter('sandbox->>id', 'eq', sandbox_id).execute()
    owning_projects = lookup.data
    if not owning_projects:
        raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox")

    project = owning_projects[0]
    project_id = project.get('project_id')
    is_public = project.get('is_public', False)
    log.debug(
        "Checking sandbox access via project ownership",
        sandbox_id=sandbox_id,
        project_id=project_id,
        is_public=is_public,
        user_id=user_id
    )

    if is_public:
        log.debug("Allowing access to public project sandbox", project_id=project_id)
        return project

    account_id = project.get('account_id')
    if not account_id:
        raise HTTPException(status_code=500, detail="Project has no associated account")

    membership = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
    if not membership.data:
        log.warning(
            "User denied access to private project sandbox",
            sandbox_id=sandbox_id,
            project_id=project_id,
            user_id=user_id,
            account_id=account_id
        )
        raise HTTPException(status_code=403, detail="Not authorized to access this project's sandbox")

    log.debug(
        "User has access to private project sandbox",
        project_id=project_id,
        user_role=membership.data[0].get('account_role')
    )
    return project
async def verify_sandbox_access_optional(client, sandbox_id: str, user_id: Optional[str] = None):
    """
    Like ``verify_sandbox_access``, but ``user_id`` may be omitted (for read-only use).

    Access rules:
    - The sandbox belongs to the first project whose ``sandbox->>id`` matches.
    - If that project is public, access is allowed even with no user.
    - If it is private, a user is required and must be a member of the
      project's account (a row in ``basejump.account_user``).

    Args:
        client: The Supabase client
        sandbox_id: Sandbox whose access is being checked
        user_id: User requesting access, or None for an anonymous caller
    Returns:
        dict: The owning project's row
    Raises:
        HTTPException: 404 if no project owns the sandbox, 401 if a private
            project is accessed without a user, 500 if the project has no
            account, 403 if the user is not an account member
    """
    log = structlog.get_logger()

    lookup = await client.table('projects').select('*').filter('sandbox->>id', 'eq', sandbox_id).execute()
    owning_projects = lookup.data
    if not owning_projects:
        raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox")

    project = owning_projects[0]
    project_id = project.get('project_id')
    is_public = project.get('is_public', False)
    log.debug(
        "Checking optional sandbox access via project ownership",
        sandbox_id=sandbox_id,
        project_id=project_id,
        is_public=is_public,
        user_id=user_id
    )

    if is_public:
        log.debug("Allowing access to public project sandbox", project_id=project_id)
        return project

    # Private project: an anonymous caller is rejected before any membership query.
    if not user_id:
        log.warning(
            "Authentication required for private project sandbox access",
            project_id=project_id,
            sandbox_id=sandbox_id
        )
        raise HTTPException(status_code=401, detail="Authentication required for this private project")

    account_id = project.get('account_id')
    if not account_id:
        raise HTTPException(status_code=500, detail="Project has no associated account")

    membership = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
    if not membership.data:
        log.warning(
            "User denied access to private project sandbox",
            sandbox_id=sandbox_id,
            project_id=project_id,
            user_id=user_id,
            account_id=account_id
        )
        raise HTTPException(status_code=403, detail="Not authorized to access this project's sandbox")

    log.debug(
        "User has access to private project sandbox",
        project_id=project_id,
        user_role=membership.data[0].get('account_role')
    )
    return project