diff --git a/backend/agent/handlers/agent_crud.py b/backend/agent/handlers/agent_crud.py index e2bfbfcd..0c43d6ab 100644 --- a/backend/agent/handlers/agent_crud.py +++ b/backend/agent/handlers/agent_crud.py @@ -1,7 +1,7 @@ from typing import Optional from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Query -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger from utils.config import config, EnvMode from utils.pagination import PaginationParams @@ -21,7 +21,7 @@ router = APIRouter() async def update_agent( agent_id: str, agent_data: AgentUpdateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Updating agent {agent_id} for user: {user_id}") @@ -436,7 +436,7 @@ async def update_agent( raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}") @router.delete("/agents/{agent_id}") -async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def delete_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): logger.debug(f"Deleting agent: {agent_id}") client = await utils.db.client @@ -502,7 +502,7 @@ async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id @router.get("/agents", response_model=AgentsResponse) async def get_agents( - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"), limit: Optional[int] = Query(20, ge=1, le=100, description="Number of items per page"), search: Optional[str] = Query(None, description="Search in name and description"), @@ -570,7 +570,7 @@ async def get_agents( raise HTTPException(status_code=500, detail=f"Failed to fetch agents: {str(e)}") @router.get("/agents/{agent_id}", response_model=AgentResponse) -async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def get_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): logger.debug(f"Fetching agent {agent_id} for user: {user_id}") @@ -697,7 +697,7 @@ async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_fr @router.post("/agents", response_model=AgentResponse) async def create_agent( agent_data: AgentCreateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Creating new agent for user: {user_id}") client = await utils.db.client diff --git a/backend/agent/handlers/agent_json.py b/backend/agent/handlers/agent_json.py index bb5d0476..61596dbe 100644 --- a/backend/agent/handlers/agent_json.py +++ b/backend/agent/handlers/agent_json.py @@ -3,7 +3,7 @@ from typing import Optional, Dict, List, Any from fastapi import APIRouter, HTTPException, Depends from uuid import uuid4 -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger from templates.template_service import MCPRequirementValue, ConfigType, ProfileId, QualifiedName @@ -343,7 +343,7 @@ class JsonImportService: raise # Re-raise the exception to ensure import fails if version creation fails @router.get("/agents/{agent_id}/export") -async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def export_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): """Export an agent configuration as JSON""" logger.debug(f"Exporting agent {agent_id} for user: {user_id}") @@ -419,7 +419,7 @@ async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id @router.post("/agents/json/analyze", response_model=JsonAnalysisResponse) async def analyze_json_for_import( request: JsonAnalysisRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Analyze imported JSON to determine required credentials and configurations""" logger.debug(f"Analyzing JSON for import - user: {user_id}") @@ -439,7 +439,7 @@ async def analyze_json_for_import( @router.post("/agents/json/import", response_model=JsonImportResponse) async def import_agent_from_json( request: JsonImportRequestModel, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Importing agent from JSON - user: {user_id}") diff --git a/backend/agent/handlers/agent_runs.py b/backend/agent/handlers/agent_runs.py index 299131e9..115e0719 100644 --- a/backend/agent/handlers/agent_runs.py +++ b/backend/agent/handlers/agent_runs.py @@ -8,7 +8,7 @@ from typing import Optional, List from fastapi import APIRouter, HTTPException, Depends, Request, Body, File, UploadFile, Form from fastapi.responses import StreamingResponse -from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access +from utils.auth_utils import verify_and_get_user_id_from_jwt, get_user_id_from_stream_auth, verify_and_authorize_thread_access from utils.logger import logger, structlog from services.billing import check_billing_status, can_use_model from utils.config import config @@ -32,7 +32,7 @@ router = APIRouter() async def start_agent( thread_id: str, body: AgentStartRequest = Body(...), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Start an agent for a specific thread in the background""" structlog.contextvars.bind_contextvars( @@ -67,7 +67,7 @@ async def start_agent( thread_metadata = thread_data.get('metadata', {}) if account_id != user_id: - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) structlog.contextvars.bind_contextvars( project_id=project_id, @@ -241,7 +241,7 @@ async def start_agent( return {"agent_run_id": agent_run_id, "status": "running"} @router.post("/agent-run/{agent_run_id}/stop") -async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def stop_agent(agent_run_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): """Stop a running agent.""" structlog.contextvars.bind_contextvars( agent_run_id=agent_run_id, @@ -253,20 +253,20 @@ async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_ return {"status": "stopped"} @router.get("/thread/{thread_id}/agent-runs") -async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def get_agent_runs(thread_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): """Get all agent runs for a thread.""" structlog.contextvars.bind_contextvars( thread_id=thread_id, ) logger.debug(f"Fetching agent runs for thread: {thread_id}") client = await utils.db.client - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) agent_runs = await client.table('agent_runs').select('id, thread_id, status, started_at, completed_at, error, created_at, updated_at').eq("thread_id", thread_id).order('created_at', desc=True).execute() logger.debug(f"Found {len(agent_runs.data)} agent runs for thread: {thread_id}") return {"agent_runs": agent_runs.data} @router.get("/agent-run/{agent_run_id}") -async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def get_agent_run(agent_run_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): """Get agent run status and responses.""" structlog.contextvars.bind_contextvars( agent_run_id=agent_run_id, @@ -285,7 +285,7 @@ async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_us } @router.get("/thread/{thread_id}/agent", response_model=ThreadAgentResponse) -async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)): +async def get_thread_agent(thread_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)): """Get the agent details for a specific thread. Since threads are fully agent-agnostic, this returns the most recently used agent from agent_runs only.""" structlog.contextvars.bind_contextvars( @@ -296,7 +296,7 @@ async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_us try: # Verify thread access and get thread data - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute() if not thread_result.data: @@ -621,7 +621,7 @@ async def initiate_agent_with_files( enable_context_manager: Optional[bool] = Form(False), agent_id: Optional[str] = Form(None), # Add agent_id parameter files: List[UploadFile] = File(default=[]), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """ Initiate a new agent session with optional file attachments. diff --git a/backend/agent/handlers/agent_tools.py b/backend/agent/handlers/agent_tools.py index bcd753a4..0f2ff183 100644 --- a/backend/agent/handlers/agent_tools.py +++ b/backend/agent/handlers/agent_tools.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, Dict, Any from fastapi import APIRouter, HTTPException, Depends, Request, Body -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger from sandbox.sandbox import get_or_start_sandbox from services.supabase import DBConnection @@ -19,7 +19,7 @@ router = APIRouter() async def get_custom_mcp_tools_for_agent( agent_id: str, request: Request, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Getting custom MCP tools for agent {agent_id}, user {user_id}") try: @@ -98,7 +98,7 @@ async def get_custom_mcp_tools_for_agent( async def update_custom_mcp_tools_for_agent( agent_id: str, request: dict, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Updating custom MCP tools for agent {agent_id}, user {user_id}") @@ -209,7 +209,7 @@ async def update_custom_mcp_tools_for_agent( async def update_agent_custom_mcps( agent_id: str, request: dict, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Updating agent {agent_id} custom MCPs for user {user_id}") @@ -316,7 +316,7 @@ async def update_agent_custom_mcps( @router.get("/agents/{agent_id}/tools") async def get_agent_tools( agent_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Fetching enabled tools for agent: {agent_id} by user: {user_id}") diff --git a/backend/agent/handlers/threads.py b/backend/agent/handlers/threads.py index 6d96eaa2..9059f3ab 100644 --- a/backend/agent/handlers/threads.py +++ b/backend/agent/handlers/threads.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from typing import Optional from fastapi import APIRouter, HTTPException, Depends, Form, Query -from utils.auth_utils import get_current_user_id_from_jwt, verify_thread_access +from utils.auth_utils import verify_and_get_user_id_from_jwt, verify_and_authorize_thread_access, require_thread_access, AuthorizedThreadAccess from utils.logger import logger from sandbox.sandbox import create_sandbox, delete_sandbox @@ -16,7 +16,7 @@ router = APIRouter() @router.get("/threads") async def get_user_threads( - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"), limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)") ): @@ -121,14 +121,15 @@ async def get_user_threads( @router.get("/threads/{thread_id}") async def get_thread( thread_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + auth: AuthorizedThreadAccess = Depends(require_thread_access) ): """Get a specific thread by ID with complete related data.""" logger.debug(f"Fetching thread: {thread_id}") client = await utils.db.client + user_id = auth.user_id # Already authenticated and authorized! try: - await verify_thread_access(client, thread_id, user_id) + # No need for manual authorization - it's already done in the dependency! # Get the thread data thread_result = await client.table('threads').select('*').eq('thread_id', thread_id).execute() @@ -200,7 +201,7 @@ async def get_thread( @router.post("/threads", response_model=CreateThreadResponse) async def create_thread( name: Optional[str] = Form(None), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """ Create a new thread without starting an agent run. @@ -303,13 +304,13 @@ async def create_thread( @router.get("/threads/{thread_id}/messages") async def get_thread_messages( thread_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'") ): """Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries.""" logger.debug(f"Fetching all messages for thread: {thread_id}, order={order}") client = await utils.db.client - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) try: batch_size = 1000 offset = 0 @@ -333,7 +334,7 @@ async def get_thread_messages( @router.get("/agent-runs/{agent_run_id}") async def get_agent_run( agent_run_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), ): """ [DEPRECATED] Get an agent run by ID. @@ -355,12 +356,12 @@ async def get_agent_run( async def add_message_to_thread( thread_id: str, message: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), ): """Add a message to a thread""" logger.debug(f"Adding message to thread: {thread_id}") client = await utils.db.client - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) try: message_result = await client.table('messages').insert({ 'thread_id': thread_id, @@ -380,14 +381,14 @@ async def add_message_to_thread( async def create_message( thread_id: str, message_data: MessageCreateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Create a new message in a thread.""" logger.debug(f"Creating message in thread: {thread_id}") client = await utils.db.client try: - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) message_payload = { "role": "user" if message_data.type == "user" else "assistant", @@ -421,12 +422,12 @@ async def create_message( async def delete_message( thread_id: str, message_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Delete a message from a thread.""" logger.debug(f"Deleting message from thread: {thread_id}") client = await utils.db.client - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) try: # Don't allow users to delete the "status" messages await client.table('messages').delete().eq('message_id', message_id).eq('is_llm_message', True).eq('thread_id', thread_id).execute() diff --git a/backend/agent/handlers/versioning/api.py b/backend/agent/handlers/versioning/api.py index dd604f1b..a1052f75 100644 --- a/backend/agent/handlers/versioning/api.py +++ b/backend/agent/handlers/versioning/api.py @@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any from pydantic import BaseModel from utils.logger import logger -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from .version_service import ( get_version_service, @@ -61,7 +61,7 @@ class VersionComparisonResponse(BaseModel): @router.get("/agents/{agent_id}/versions", response_model=List[VersionResponse]) async def get_versions( agent_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: @@ -80,7 +80,7 @@ async def get_versions( async def create_version( agent_id: str, request: CreateVersionRequest, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: @@ -112,7 +112,7 @@ async def create_version( async def get_version( agent_id: str, version_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: @@ -131,7 +131,7 @@ async def get_version( async def activate_version( agent_id: str, version_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: @@ -153,7 +153,7 @@ async def compare_versions( agent_id: str, version1_id: str, version2_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: @@ -179,7 +179,7 @@ async def compare_versions( async def rollback_to_version( agent_id: str, version_id: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: @@ -202,7 +202,7 @@ async def update_version_details( agent_id: str, version_id: str, request: UpdateVersionDetailsRequest, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), version_service: VersionService = Depends(get_version_service) ): try: diff --git a/backend/agent/handlers/versioning/version_service.py b/backend/agent/handlers/versioning/version_service.py index 2797ce4c..b09e0775 100644 --- a/backend/agent/handlers/versioning/version_service.py +++ b/backend/agent/handlers/versioning/version_service.py @@ -80,7 +80,7 @@ class VersionService: async def _get_client(self): return await self.db.client - async def _verify_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]: + async def _verify_and_authorize_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]: if user_id == "system": return True, True @@ -205,7 +205,7 @@ class VersionService: logger.debug(f"Creating version for agent {agent_id}") client = await self.db.client - is_owner, _ = await self._verify_agent_access(agent_id, user_id) + is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner: raise UnauthorizedError("Unauthorized to create version for this agent") @@ -298,7 +298,7 @@ class VersionService: return version async def get_version(self, agent_id: str, version_id: str, user_id: str) -> AgentVersion: - is_owner, is_public = await self._verify_agent_access(agent_id, user_id) + is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner and not is_public: raise UnauthorizedError("You don't have permission to view this version") @@ -314,7 +314,7 @@ class VersionService: return self._version_from_db_row(result.data[0]) async def get_active_version(self, agent_id: str, user_id: str = "system") -> Optional[AgentVersion]: - is_owner, is_public = await self._verify_agent_access(agent_id, user_id) + is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner and not is_public: raise UnauthorizedError("You don't have permission to view this agent") @@ -341,7 +341,7 @@ class VersionService: return version async def get_all_versions(self, agent_id: str, user_id: str) -> List[AgentVersion]: - is_owner, is_public = await self._verify_agent_access(agent_id, user_id) + is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner and not is_public: raise UnauthorizedError("You don't have permission to view versions") @@ -355,7 +355,7 @@ class VersionService: return versions async def activate_version(self, agent_id: str, version_id: str, user_id: str) -> None: - is_owner, _ = await self._verify_agent_access(agent_id, user_id) + is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner: raise UnauthorizedError("You don't have permission to activate versions") @@ -458,7 +458,7 @@ class VersionService: ) -> AgentVersion: version_to_restore = await self.get_version(agent_id, version_id, user_id) - is_owner, _ = await self._verify_agent_access(agent_id, user_id) + is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner: raise UnauthorizedError("You don't have permission to rollback versions") @@ -483,7 +483,7 @@ class VersionService: version_name: Optional[str] = None, change_description: Optional[str] = None ) -> AgentVersion: - is_owner, _ = await self._verify_agent_access(agent_id, user_id) + is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id) if not is_owner: raise UnauthorizedError("You don't have permission to update this version") diff --git a/backend/agent/run.py b/backend/agent/run.py index c4560a45..2146a669 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -21,7 +21,7 @@ from agent.tools.expand_msg_tool import ExpandMessageTool from agent.prompts.prompt import get_system_prompt from utils.logger import logger -from utils.auth_utils import get_account_id_from_thread + from services.billing import check_billing_status from agent.tools.sb_vision_tool import SandboxVisionTool from agent.tools.sb_image_edit_tool import SandboxImageEditTool @@ -448,9 +448,16 @@ class AgentRunner: ) self.client = await self.thread_manager.db.client - self.account_id = await get_account_id_from_thread(self.client, self.config.thread_id) + + response = await self.client.table('threads').select('account_id').eq('thread_id', self.config.thread_id).execute() + + if not response.data or len(response.data) == 0: + raise ValueError(f"Thread {self.config.thread_id} not found") + + self.account_id = response.data[0].get('account_id') + if not self.account_id: - raise ValueError("Could not determine account ID for thread") + raise ValueError(f"Thread {self.config.thread_id} has no associated account") project = await self.client.table('projects').select('*').eq('project_id', self.config.project_id).execute() if not project.data or len(project.data) == 0: diff --git a/backend/agent/utils.py b/backend/agent/utils.py index 4e6bb92f..d5368c3f 100644 --- a/backend/agent/utils.py +++ b/backend/agent/utils.py @@ -7,7 +7,7 @@ from fastapi import HTTPException from utils.cache import Cache from utils.logger import logger from utils.config import config -from utils.auth_utils import verify_thread_access +from utils.auth_utils import verify_and_authorize_thread_access from services import redis from services.supabase import DBConnection from services.llm import make_llm_api_call @@ -121,7 +121,7 @@ async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: st account_id = agent_run_data['threads']['account_id'] if account_id == user_id: return agent_run_data - await verify_thread_access(client, thread_id, user_id) + await verify_and_authorize_thread_access(client, thread_id, user_id) return agent_run_data async def generate_and_update_project_name(project_id: str, prompt: str): diff --git a/backend/composio_integration/api.py b/backend/composio_integration/api.py index f931b1c5..6d80253f 100644 --- a/backend/composio_integration/api.py +++ b/backend/composio_integration/api.py @@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse from typing import Dict, Any, Optional from pydantic import BaseModel from uuid import uuid4 -from utils.auth_utils import get_current_user_id_from_jwt, get_optional_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt, get_optional_current_user_id_from_jwt from utils.logger import logger from services.supabase import DBConnection from datetime import datetime @@ -215,7 +215,7 @@ class ProfileResponse(BaseModel): @router.get("/categories") async def list_categories( - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: logger.debug("Fetching Composio categories") @@ -240,7 +240,7 @@ async def list_toolkits( cursor: Optional[str] = Query(None), search: Optional[str] = Query(None), category: Optional[str] = Query(None), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: logger.debug(f"Fetching Composio toolkits with limit: {limit}, cursor: {cursor}, search: {search}, category: {category}") @@ -270,7 +270,7 @@ async def list_toolkits( @router.get("/toolkits/{toolkit_slug}/details") async def get_toolkit_details( toolkit_slug: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: logger.debug(f"Fetching detailed toolkit info for: {toolkit_slug}") @@ -296,7 +296,7 @@ async def get_toolkit_details( @router.post("/integrate", response_model=IntegrationStatusResponse) async def integrate_toolkit( request: IntegrateToolkitRequest, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> IntegrationStatusResponse: try: integration_user_id = str(uuid4()) @@ -333,7 +333,7 @@ async def integrate_toolkit( @router.post("/profiles", response_model=ProfileResponse) async def create_profile( request: CreateProfileRequest, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> ProfileResponse: try: integration_user_id = str(uuid4()) @@ -378,7 +378,7 @@ async def create_profile( @router.get("/profiles") async def get_profiles( toolkit_slug: Optional[str] = Query(None), - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: profile_service = ComposioProfileService(db) @@ -403,7 +403,7 @@ async def get_profiles( @router.get("/profiles/{profile_id}/mcp-config") async def get_profile_mcp_config( profile_id: str, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: profile_service = ComposioProfileService(db) @@ -423,7 +423,7 @@ async def get_profile_mcp_config( @router.get("/profiles/{profile_id}") async def get_profile_info( profile_id: str, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: profile_service = ComposioProfileService(db) @@ -452,7 +452,7 @@ async def get_profile_info( @router.get("/integration/{connected_account_id}/status") async def get_integration_status( connected_account_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: service = get_integration_service() @@ -466,7 +466,7 @@ async def get_integration_status( @router.post("/profiles/{profile_id}/discover-tools") async def discover_composio_tools( profile_id: str, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: try: profile_service = ComposioProfileService(db) @@ -509,7 +509,7 @@ async def discover_composio_tools( @router.post("/discover-tools/{profile_id}") async def discover_tools_post( profile_id: str, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ) -> Dict[str, Any]: return await discover_composio_tools(profile_id, current_user_id) @@ -544,7 +544,7 @@ async def get_toolkit_icon( @router.post("/tools/list") async def list_toolkit_tools( request: ToolsListRequest, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: logger.debug(f"User {current_user_id} requesting tools for toolkit: {request.toolkit_slug}") @@ -572,7 +572,7 @@ async def list_toolkit_tools( @router.get("/triggers/apps") async def list_apps_with_triggers( - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), ) -> Dict[str, Any]: try: trigger_service = ComposioTriggerService() @@ -588,7 +588,7 @@ async def list_apps_with_triggers( @router.get("/triggers/apps/{toolkit_slug}") async def list_triggers_for_app( toolkit_slug: str, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), ) -> Dict[str, Any]: try: trigger_service = ComposioTriggerService() @@ -617,7 +617,7 @@ class CreateComposioTriggerRequest(BaseModel): @router.post("/triggers/create") -async def create_composio_trigger(req: CreateComposioTriggerRequest, current_user_id: str = Depends(get_current_user_id_from_jwt)) -> Dict[str, Any]: +async def create_composio_trigger(req: CreateComposioTriggerRequest, current_user_id: str = Depends(verify_and_get_user_id_from_jwt)) -> Dict[str, Any]: try: client_db = await db.client agent_check = await client_db.table('agents').select('agent_id').eq('agent_id', req.agent_id).eq('account_id', current_user_id).execute() diff --git a/backend/credentials/api.py b/backend/credentials/api.py index 97363c81..f03cb49b 100644 --- a/backend/credentials/api.py +++ b/backend/credentials/api.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, validator import urllib.parse from utils.logger import logger -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from services.supabase import DBConnection from .credential_service import ( @@ -116,7 +116,7 @@ def initialize(database: DBConnection): @router.post("/credentials", response_model=CredentialResponse) async def store_credential( request: StoreCredentialRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: credential_service = get_credential_service(db) @@ -151,7 +151,7 @@ async def store_credential( @router.get("/credentials", response_model=List[CredentialResponse]) async def get_user_credentials( - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: credential_service = get_credential_service(db) @@ -178,7 +178,7 @@ async def get_user_credentials( @router.delete("/credentials/{mcp_qualified_name:path}") async def delete_credential( mcp_qualified_name: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: decoded_name = decode_mcp_qualified_name(mcp_qualified_name) @@ -201,7 +201,7 @@ async def delete_credential( @router.post("/credential-profiles", response_model=CredentialProfileResponse) async def store_credential_profile( request: StoreCredentialProfileRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -240,7 +240,7 @@ async def store_credential_profile( @router.get("/credential-profiles", response_model=List[CredentialProfileResponse]) async def get_user_credential_profiles( - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -269,7 +269,7 @@ async def get_user_credential_profiles( @router.get("/credential-profiles/{mcp_qualified_name:path}", response_model=List[CredentialProfileResponse]) async def get_credential_profiles_for_mcp( mcp_qualified_name: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: decoded_name = decode_mcp_qualified_name(mcp_qualified_name) @@ -300,7 +300,7 @@ async def get_credential_profiles_for_mcp( @router.get("/credential-profiles/profile/{profile_id}", response_model=CredentialProfileResponse) async def get_credential_profile( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -331,7 +331,7 @@ async def get_credential_profile( @router.put("/credential-profiles/{profile_id}/set-default") async def set_default_credential_profile( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -350,7 +350,7 @@ async def set_default_credential_profile( @router.delete("/credential-profiles/{profile_id}") async def delete_credential_profile( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -369,7 +369,7 @@ async def delete_credential_profile( @router.post("/credential-profiles/bulk-delete", response_model=BulkDeleteProfilesResponse) async def bulk_delete_credential_profiles( request: BulkDeleteProfilesRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -399,7 +399,7 @@ async def bulk_delete_credential_profiles( @router.get("/composio-profiles", response_model=ComposioCredentialsResponse) async def get_composio_profiles( - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: profile_service = get_profile_service(db) @@ -482,7 +482,7 @@ async def get_composio_profiles( @router.get("/composio-profiles/{profile_id}/mcp-url", response_model=ComposioMcpUrlResponse) async def get_composio_mcp_url( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: from composio_integration.composio_profile_service import ComposioProfileService diff --git a/backend/google/google_slides_api.py b/backend/google/google_slides_api.py index 4b875cab..c156624f 100644 --- a/backend/google/google_slides_api.py +++ b/backend/google/google_slides_api.py @@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Depends from fastapi.responses import RedirectResponse from pydantic import BaseModel, Field -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger from services.supabase import DBConnection from .google_slides_service import GoogleSlidesService, OAuthTokenService @@ -92,7 +92,7 @@ presentation_router = APIRouter(prefix="/presentation-tools", tags=["presentatio @oauth_router.get("/auth-url", response_model=AuthURLResponse) async def get_google_auth_url( - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), return_url: Optional[str] = Query(None, description="URL to redirect to after OAuth"), google_service: GoogleSlidesService = Depends(get_google_slides_service) ): @@ -172,7 +172,7 @@ async def google_oauth_callback( # UNUSED: Disconnect endpoint - frontend never calls this # @oauth_router.post("/disconnect") # async def disconnect_google( -# user_id: str = Depends(get_current_user_id_from_jwt), +# user_id: str = Depends(verify_and_get_user_id_from_jwt), # google_service: GoogleSlidesService = Depends(get_google_slides_service) # ): # """ @@ -196,7 +196,7 @@ async def google_oauth_callback( @presentation_router.post("/convert-and-upload-to-slides", response_model=ConvertToSlidesResponse) async def convert_and_upload_to_google_slides( request: ConvertToSlidesRequest, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), google_service: GoogleSlidesService = Depends(get_google_slides_service) ): """ diff --git a/backend/knowledge_base/api.py b/backend/knowledge_base/api.py index 46d67d2f..659d8542 100644 --- a/backend/knowledge_base/api.py +++ b/backend/knowledge_base/api.py @@ -2,7 +2,7 @@ import json from typing import List, Optional from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, BackgroundTasks from pydantic import BaseModel, Field, HttpUrl -from utils.auth_utils import get_current_user_id_from_jwt, verify_agent_access +from utils.auth_utils import verify_and_get_user_id_from_jwt, verify_and_get_agent_authorization, require_agent_access, AuthorizedAgentAccess from services.supabase import DBConnection from knowledge_base.file_processor import FileProcessor from utils.logger import logger @@ -69,15 +69,16 @@ db = DBConnection() async def get_agent_knowledge_base( agent_id: str, include_inactive: bool = False, - user_id: str = Depends(get_current_user_id_from_jwt) + auth: AuthorizedAgentAccess = Depends(require_agent_access) ): """Get all knowledge base entries for an agent""" try: client = await db.client + user_id = auth.user_id # Already authenticated and authorized! + agent_data = auth.agent_data # Agent data already fetched during authorization - # Verify agent access - await verify_agent_access(client, agent_id, user_id) + # No need for manual authorization - it's already done in the dependency! result = await client.rpc('get_agent_knowledge_base', { 'p_agent_id': agent_id, @@ -122,7 +123,7 @@ async def get_agent_knowledge_base( async def create_agent_knowledge_base_entry( agent_id: str, entry_data: CreateKnowledgeBaseEntryRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Create a new knowledge base entry for an agent""" @@ -130,7 +131,7 @@ async def create_agent_knowledge_base_entry( client = await db.client # Verify agent access and get agent data - agent_data = await verify_agent_access(client, agent_id, user_id) + agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id) account_id = agent_data['account_id'] insert_data = { @@ -172,7 +173,7 @@ async def upload_file_to_agent_kb( agent_id: str, background_tasks: BackgroundTasks, file: UploadFile = File(...), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Upload and process a file for agent knowledge base""" @@ -180,7 +181,7 @@ async def upload_file_to_agent_kb( client = await db.client # Verify agent access and get agent data - agent_data = await verify_agent_access(client, agent_id, user_id) + agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id) account_id = agent_data['account_id'] file_content = await file.read() @@ -226,7 +227,7 @@ async def upload_file_to_agent_kb( async def update_knowledge_base_entry( entry_id: str, entry_data: UpdateKnowledgeBaseEntryRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Update an agent knowledge base entry""" @@ -243,7 +244,7 @@ async def update_knowledge_base_entry( agent_id = entry['agent_id'] # Verify agent access - await verify_agent_access(client, agent_id, user_id) + await verify_and_get_agent_authorization(client, agent_id, user_id) update_data = {} if entry_data.name is not None: @@ -294,7 +295,7 @@ async def update_knowledge_base_entry( @router.delete("/{entry_id}") async def delete_knowledge_base_entry( entry_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Delete an agent knowledge base entry""" @@ -311,7 +312,7 @@ async def delete_knowledge_base_entry( agent_id = entry['agent_id'] # Verify agent access - await verify_agent_access(client, agent_id, user_id) + await verify_and_get_agent_authorization(client, agent_id, user_id) result = await client.table('agent_knowledge_base_entries').delete().eq('entry_id', entry_id).execute() @@ -329,7 +330,7 @@ async def delete_knowledge_base_entry( @router.get("/{entry_id}", response_model=KnowledgeBaseEntryResponse) async def get_knowledge_base_entry( entry_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get a specific agent knowledge base entry""" try: @@ -345,7 +346,7 @@ async def get_knowledge_base_entry( agent_id = entry['agent_id'] # Verify agent access - await verify_agent_access(client, agent_id, user_id) + await verify_and_get_agent_authorization(client, agent_id, user_id) logger.debug(f"Retrieved agent knowledge base entry {entry_id} for agent {agent_id}") @@ -376,7 +377,7 @@ async def get_knowledge_base_entry( async def get_agent_processing_jobs( agent_id: str, limit: int = 10, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get processing jobs for an agent""" @@ -384,7 +385,7 @@ async def get_agent_processing_jobs( client = await db.client # Verify agent access - await verify_agent_access(client, agent_id, user_id) + await verify_and_get_agent_authorization(client, agent_id, user_id) result = await client.rpc('get_agent_kb_processing_jobs', { 'p_agent_id': agent_id, @@ -468,7 +469,7 @@ async def process_file_background( async def get_agent_knowledge_base_context( agent_id: str, max_tokens: int = 4000, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get knowledge base context for agent prompts""" @@ -476,7 +477,7 @@ async def get_agent_knowledge_base_context( client = await db.client # Verify agent access - await verify_agent_access(client, agent_id, user_id) + await verify_and_get_agent_authorization(client, agent_id, user_id) result = await client.rpc('get_agent_knowledge_base_context', { 'p_agent_id': agent_id, diff --git a/backend/mcp_module/api.py b/backend/mcp_module/api.py index 15bf2b51..7d0ce007 100644 --- a/backend/mcp_module/api.py +++ b/backend/mcp_module/api.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends from typing import List, Optional, Dict, Any from pydantic import BaseModel -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger from .mcp_service import mcp_service, MCPException diff --git a/backend/pipedream/api.py b/backend/pipedream/api.py index 40aa5e7d..75cb33ad 100644 --- a/backend/pipedream/api.py +++ b/backend/pipedream/api.py @@ -5,7 +5,7 @@ from uuid import UUID from datetime import datetime from utils.logger import logger -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from .profile_service import ProfileService, Profile, ProfileServiceError, ProfileNotFoundError, ProfileAlreadyExistsError, InvalidConfigError, EncryptionError from .connection_service import ConnectionService from .app_service import get_app_service @@ -162,7 +162,7 @@ def _handle_pipedream_exception(e: Exception) -> HTTPException: @router.post("/connection-token", response_model=ConnectionTokenResponse) async def create_connection_token( request: CreateConnectionTokenRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Creating Pipedream connection token for user: {user_id}, app: {request.app}") @@ -190,7 +190,7 @@ async def create_connection_token( @router.get("/connections", response_model=ConnectionResponse) async def get_user_connections( - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Getting connections for user: {user_id}") @@ -228,7 +228,7 @@ async def get_user_connections( @router.post("/mcp/discover", response_model=MCPDiscoveryResponse) async def discover_mcp_servers( request: MCPDiscoveryRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Discovering MCP servers for user: {user_id}, app: {request.app_slug}") @@ -277,7 +277,7 @@ async def discover_mcp_servers( @router.post("/mcp/discover-profile", response_model=MCPDiscoveryResponse) async def discover_mcp_servers_for_profile( request: MCPProfileDiscoveryRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Discovering MCP servers for profile: {request.external_user_id}") @@ -326,7 +326,7 @@ async def discover_mcp_servers_for_profile( @router.post("/mcp/connect", response_model=MCPConnectionResponse) async def create_mcp_connection( request: MCPConnectionRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Creating MCP connection for user: {user_id}, app: {request.app_slug}") @@ -535,7 +535,7 @@ async def get_app_tools(app_slug: str): @router.post("/profiles", response_model=ProfileResponse) async def create_credential_profile( request: ProfileRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Creating credential profile for user: {user_id}, app: {request.app_slug}") @@ -563,7 +563,7 @@ async def create_credential_profile( async def get_credential_profiles( app_slug: Optional[str] = Query(None), is_active: Optional[bool] = Query(None), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Getting credential profiles for user: {user_id}, app: {app_slug}") @@ -581,7 +581,7 @@ async def get_credential_profiles( @router.get("/profiles/{profile_id}", response_model=ProfileResponse) async def get_credential_profile( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Getting credential profile: {profile_id} for user: {user_id}") @@ -603,7 +603,7 @@ async def get_credential_profile( async def update_credential_profile( profile_id: str, request: UpdateProfileRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Updating credential profile: {profile_id} for user: {user_id}") @@ -628,7 +628,7 @@ async def update_credential_profile( @router.delete("/profiles/{profile_id}") async def delete_credential_profile( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Deleting credential profile: {profile_id} for user: {user_id}") @@ -649,7 +649,7 @@ async def delete_credential_profile( async def connect_credential_profile( profile_id: str, app: Optional[str] = Query(None), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Connecting credential profile: {profile_id} for user: {user_id}") @@ -686,7 +686,7 @@ async def connect_credential_profile( @router.get("/profiles/{profile_id}/connections") async def get_profile_connections( profile_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): logger.debug(f"Getting connections for profile: {profile_id}, user: {user_id}") diff --git a/backend/services/api_keys_api.py b/backend/services/api_keys_api.py index 71b6aa29..b20d1114 100644 --- a/backend/services/api_keys_api.py +++ b/backend/services/api_keys_api.py @@ -18,7 +18,7 @@ from services.api_keys import ( APIKeyCreateResponse, ) from services.supabase import DBConnection -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger router = APIRouter() @@ -61,7 +61,7 @@ async def get_account_id_from_user_id(user_id: str) -> UUID: @router.post("/api-keys", response_model=APIKeyCreateResponse) async def create_api_key( request: APIKeyCreateRequest, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), api_key_service: APIKeyService = Depends(get_api_key_service), ): """ @@ -105,7 +105,7 @@ async def create_api_key( @router.get("/api-keys", response_model=List[APIKeyResponse]) async def list_api_keys( - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), api_key_service: APIKeyService = Depends(get_api_key_service), ): """ @@ -141,7 +141,7 @@ async def list_api_keys( @router.patch("/api-keys/{key_id}/revoke") async def revoke_api_key( key_id: UUID, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), api_key_service: APIKeyService = Depends(get_api_key_service), ): """ @@ -185,7 +185,7 @@ async def revoke_api_key( @router.delete("/api-keys/{key_id}") async def delete_api_key( key_id: UUID, - user_id: str = Depends(get_current_user_id_from_jwt), + user_id: str = Depends(verify_and_get_user_id_from_jwt), api_key_service: APIKeyService = Depends(get_api_key_service), ): """ diff --git a/backend/services/billing.py b/backend/services/billing.py index dc53da1d..1f45d427 100644 --- a/backend/services/billing.py +++ b/backend/services/billing.py @@ -15,7 +15,7 @@ from utils.cache import Cache from utils.logger import logger from utils.config import config, EnvMode from services.supabase import DBConnection -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from pydantic import BaseModel from models import model_manager from litellm.cost_calculator import cost_per_token @@ -1234,7 +1234,7 @@ async def handle_usage_with_credits( @router.post("/create-checkout-session") async def create_checkout_session( request: CreateCheckoutSessionRequest, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Create a Stripe Checkout session or modify an existing subscription.""" try: @@ -1555,7 +1555,7 @@ async def create_checkout_session( @router.post("/create-portal-session") async def create_portal_session( request: CreatePortalSessionRequest, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Create a Stripe Customer Portal session for subscription management.""" try: @@ -1655,7 +1655,7 @@ async def create_portal_session( @router.get("/subscription") async def get_subscription( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get the current subscription status for the current user, including scheduled changes and credit balance.""" try: @@ -1891,7 +1891,7 @@ async def get_subscription( @router.get("/check-status") async def check_status( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Check if the user can run agents based on their subscription, usage, and credit balance.""" try: @@ -2076,7 +2076,7 @@ async def stripe_webhook(request: Request): @router.get("/available-models") async def get_available_models( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get the list of models available to the user based on their subscription tier.""" try: @@ -2200,7 +2200,7 @@ async def get_available_models( async def get_usage_logs_endpoint( page: int = 0, items_per_page: int = 1000, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get detailed usage logs for a user with pagination.""" logger.debug(f"[USAGE_LOGS_ENDPOINT] Starting get_usage_logs_endpoint for user_id={current_user_id}, page={page}, items_per_page={items_per_page}") @@ -2254,7 +2254,7 @@ async def get_usage_logs_endpoint( @router.get("/subscription-commitment/{subscription_id}") async def get_subscription_commitment( subscription_id: str, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get commitment status for a subscription.""" try: @@ -2278,7 +2278,7 @@ async def get_subscription_commitment( @router.get("/subscription-details") async def get_subscription_details( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get detailed subscription information including commitment status.""" try: @@ -2314,7 +2314,7 @@ async def get_subscription_details( @router.post("/cancel-subscription") async def cancel_subscription( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Cancel subscription with yearly commitment handling.""" try: @@ -2412,7 +2412,7 @@ async def cancel_subscription( @router.post("/reactivate-subscription") async def reactivate_subscription( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Reactivate a subscription that was marked for cancellation.""" try: @@ -2492,7 +2492,7 @@ async def reactivate_subscription( @router.post("/purchase-credits") async def purchase_credits( request: PurchaseCreditsRequest, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """ Create a Stripe checkout session for purchasing credits. @@ -2609,7 +2609,7 @@ async def purchase_credits( @router.get("/credit-balance") async def get_credit_balance_endpoint( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get the current credit balance for the user.""" try: @@ -2628,7 +2628,7 @@ async def get_credit_balance_endpoint( async def get_credit_history( page: int = 0, items_per_page: int = 50, - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get credit purchase and usage history for the user.""" try: @@ -2690,7 +2690,7 @@ async def get_credit_history( @router.get("/can-purchase-credits") async def can_purchase_credits( - current_user_id: str = Depends(get_current_user_id_from_jwt) + current_user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Check if the current user can purchase credits (must be on highest tier).""" try: diff --git a/backend/services/transcription.py b/backend/services/transcription.py index 1408f88e..75f222a0 100644 --- a/backend/services/transcription.py +++ b/backend/services/transcription.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, UploadFile, File, HTTPException, Depends from pydantic import BaseModel from typing import Optional from utils.logger import logger -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt router = APIRouter(tags=["transcription"]) @@ -15,7 +15,7 @@ class TranscriptionResponse(BaseModel): @router.post("/transcription", response_model=TranscriptionResponse) async def transcribe_audio( audio_file: UploadFile = File(...), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Transcribe audio file to text using OpenAI Whisper.""" try: diff --git a/backend/supabase/migrations/20250901030240_security.sql b/backend/supabase/migrations/20250901030240_security.sql new file mode 100644 index 00000000..6efaa318 --- /dev/null +++ b/backend/supabase/migrations/20250901030240_security.sql @@ -0,0 +1,4 @@ +-- Remove the problematic GRANT that bypasses RLS +REVOKE SELECT ON TABLE projects FROM anon; +REVOKE SELECT ON TABLE threads FROM anon; +REVOKE SELECT ON TABLE messages FROM anon; \ No newline at end of file diff --git a/backend/templates/api.py b/backend/templates/api.py index ca329dc8..794e81f9 100644 --- a/backend/templates/api.py +++ b/backend/templates/api.py @@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any from pydantic import BaseModel from utils.logger import logger -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from services.supabase import DBConnection from utils.pagination import PaginationParams @@ -133,7 +133,7 @@ async def validate_agent_ownership(agent_id: str, user_id: str) -> Dict[str, Any @router.post("", response_model=Dict[str, str]) async def create_template_from_agent( request: CreateTemplateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: await validate_agent_ownership(request.agent_id, user_id) @@ -172,7 +172,7 @@ async def create_template_from_agent( async def publish_template( template_id: str, request: PublishTemplateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: template = await validate_template_ownership_and_get(template_id, user_id) @@ -200,7 +200,7 @@ async def publish_template( @router.post("/{template_id}/unpublish") async def unpublish_template( template_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: template = await validate_template_ownership_and_get(template_id, user_id) @@ -228,7 +228,7 @@ async def unpublish_template( @router.delete("/{template_id}") async def delete_template( template_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: template = await validate_template_ownership_and_get(template_id, user_id) @@ -256,7 +256,7 @@ async def delete_template( @router.post("/install", response_model=InstallationResponse) async def install_template( request: InstallTemplateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: await validate_template_access_and_get(request.template_id, user_id) @@ -345,8 +345,8 @@ async def get_marketplace_templates( creator_id_filter = None if mine: try: - from utils.auth_utils import get_current_user_id_from_jwt - user_id = await get_current_user_id_from_jwt(request) + from utils.auth_utils import verify_and_get_user_id_from_jwt + user_id = await verify_and_get_user_id_from_jwt(request) creator_id_filter = user_id except Exception as e: raise HTTPException(status_code=401, detail="Authentication required for 'mine' filter") @@ -406,7 +406,7 @@ async def get_my_templates( search: Optional[str] = Query(None, description="Search term for name and description"), sort_by: Optional[str] = Query("created_at", description="Sort field: created_at, name, download_count"), sort_order: Optional[str] = Query("desc", description="Sort order: asc, desc"), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: from templates.services.template_service import TemplateService, MarketplaceFilters @@ -497,7 +497,7 @@ async def get_public_template(template_id: str): @router.get("/{template_id}", response_model=TemplateResponse) async def get_template( template_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: template = await validate_template_access_and_get(template_id, user_id) diff --git a/backend/triggers/api.py b/backend/triggers/api.py index 309db5c9..f85db063 100644 --- a/backend/triggers/api.py +++ b/backend/triggers/api.py @@ -9,7 +9,7 @@ import json import hmac from services.supabase import DBConnection -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import verify_and_get_user_id_from_jwt from utils.logger import logger from utils.config import config from services.billing import check_billing_status, can_use_model @@ -131,7 +131,7 @@ def initialize(database: DBConnection): db = database -async def verify_agent_access(agent_id: str, user_id: str): +async def verify_and_authorize_trigger_agent_access(agent_id: str, user_id: str): client = await db.client result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', user_id).execute() @@ -228,10 +228,10 @@ async def get_providers(): @router.get("/agents/{agent_id}/triggers", response_model=List[TriggerResponse]) async def get_agent_triggers( agent_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) try: trigger_service = get_trigger_service(db) @@ -266,7 +266,7 @@ async def get_agent_triggers( @router.get("/all", response_model=List[Dict[str, Any]]) async def get_all_user_triggers( - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): try: client = await db.client @@ -347,11 +347,11 @@ async def get_all_user_triggers( async def get_agent_upcoming_runs( agent_id: str, limit: int = Query(10, ge=1, le=50), - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get upcoming scheduled runs for agent triggers""" - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) try: trigger_service = get_trigger_service(db) @@ -419,11 +419,11 @@ async def get_agent_upcoming_runs( async def create_agent_trigger( agent_id: str, request: TriggerCreateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Create a new trigger for an agent""" - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) try: trigger_service = get_trigger_service(db) @@ -466,7 +466,7 @@ async def create_agent_trigger( @router.get("/{trigger_id}", response_model=TriggerResponse) async def get_trigger( trigger_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get a trigger by ID""" @@ -477,7 +477,7 @@ async def get_trigger( if not trigger: raise HTTPException(status_code=404, detail="Trigger not found") - await verify_agent_access(trigger.agent_id, user_id) + await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id) base_url = os.getenv("WEBHOOK_BASE_URL", "http://localhost:8000") webhook_url = f"{base_url}/api/triggers/{trigger_id}/webhook" @@ -505,7 +505,7 @@ async def get_trigger( async def update_trigger( trigger_id: str, request: TriggerUpdateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Update a trigger""" @@ -516,7 +516,7 @@ async def update_trigger( if not trigger: raise HTTPException(status_code=404, detail="Trigger not found") - await verify_agent_access(trigger.agent_id, user_id) + await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id) updated_trigger = await trigger_service.update_trigger( trigger_id=trigger_id, @@ -556,7 +556,7 @@ async def update_trigger( @router.delete("/{trigger_id}") async def delete_trigger( trigger_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Delete a trigger""" @@ -566,7 +566,7 @@ async def delete_trigger( if not trigger: raise HTTPException(status_code=404, detail="Trigger not found") - await verify_agent_access(trigger.agent_id, user_id) + await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id) # Store agent_id before deletion agent_id = trigger.agent_id @@ -709,10 +709,10 @@ def convert_steps_to_json(steps: List[WorkflowStepRequest]) -> List[Dict[str, An @workflows_router.get("/agents/{agent_id}/workflows") async def get_agent_workflows( agent_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Get workflows for an agent""" - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) client = await db.client result = await client.table('agent_workflows').select('*').eq('agent_id', agent_id).order('created_at', desc=True).execute() @@ -724,10 +724,10 @@ async def get_agent_workflows( async def create_agent_workflow( agent_id: str, workflow_data: WorkflowCreateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Create a new workflow for an agent""" - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) try: client = await db.client @@ -758,10 +758,10 @@ async def update_agent_workflow( agent_id: str, workflow_id: str, workflow_data: WorkflowUpdateRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): """Update a workflow""" - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) client = await db.client @@ -800,9 +800,9 @@ async def update_agent_workflow( async def delete_agent_workflow( agent_id: str, workflow_id: str, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) client = await db.client @@ -822,10 +822,10 @@ async def execute_agent_workflow( agent_id: str, workflow_id: str, execution_data: WorkflowExecuteRequest, - user_id: str = Depends(get_current_user_id_from_jwt) + user_id: str = Depends(verify_and_get_user_id_from_jwt) ): print("DEBUG: Executing workflow", workflow_id, "for agent", agent_id) - await verify_agent_access(agent_id, user_id) + await verify_and_authorize_trigger_agent_access(agent_id, user_id) client = await db.client diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index 80ddcf8c..9e47f3e6 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -6,9 +6,93 @@ from jwt.exceptions import PyJWTError from utils.logger import structlog from utils.config import config import os +import base64 +import hashlib +import hmac from services.supabase import DBConnection from services import redis +async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)): + if not config.KORTIX_ADMIN_API_KEY: + raise HTTPException( + status_code=500, + detail="Admin API key not configured on server" + ) + + if not x_admin_api_key: + raise HTTPException( + status_code=401, + detail="Admin API key required. Include X-Admin-Api-Key header." + ) + + if x_admin_api_key != config.KORTIX_ADMIN_API_KEY: + raise HTTPException( + status_code=403, + detail="Invalid admin API key" + ) + + return True + +def _verify_jwt_signature(token: str, secret: str) -> dict: + """ + Securely verify JWT signature using the Supabase JWT secret. + + Args: + token: The JWT token to verify + secret: The Supabase JWT secret for signature verification + + Returns: + dict: The decoded JWT payload if signature is valid + + Raises: + PyJWTError: If signature verification fails or token is invalid + """ + try: + # Decode and verify the JWT signature using HS256 algorithm (Supabase default) + payload = jwt.decode( + token, + secret, + algorithms=["HS256"], + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": False, # Supabase JWTs may not always have audience + "verify_iss": False # Supabase JWTs may not always have issuer + } + ) + return payload + except PyJWTError as e: + structlog.get_logger().warning(f"JWT signature verification failed: {e}") + raise + +def _decode_jwt_safely(token: str) -> dict: + """ + Safely decode and verify a JWT token with proper signature verification. + Falls back to no verification only if JWT secret is not configured. + + Args: + token: The JWT token to decode + + Returns: + dict: The decoded JWT payload + + Raises: + PyJWTError: If token is invalid or signature verification fails + """ + jwt_secret = config.SUPABASE_JWT_SECRET + + if jwt_secret: + # Production mode: Verify signature + structlog.get_logger().debug("Verifying JWT signature") + return _verify_jwt_signature(token, jwt_secret) + else: + # Development mode: Log warning and decode without verification + structlog.get_logger().warning( + "JWT_SECRET not configured - using insecure JWT decoding. " + "This should only be used in development!" + ) + return jwt.decode(token, options={"verify_signature": False}) + async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]: """ Get user_id from account_id with Redis caching for performance @@ -58,7 +142,7 @@ async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]: return None # This function extracts the user ID from Supabase JWT -async def get_current_user_id_from_jwt(request: Request) -> str: +async def verify_and_get_user_id_from_jwt(request: Request) -> str: """ Extract and verify the user ID from the JWT in the Authorization header or API key. @@ -149,7 +233,7 @@ async def get_current_user_id_from_jwt(request: Request) -> str: token = auth_header.split(' ')[1] try: - payload = jwt.decode(token, options={"verify_signature": False}) + payload = _decode_jwt_safely(token) user_id = payload.get('sub') if not user_id: @@ -173,51 +257,6 @@ async def get_current_user_id_from_jwt(request: Request) -> str: headers={"WWW-Authenticate": "Bearer"} ) -async def get_account_id_from_thread(client, thread_id: str) -> str: - """ - Extract and verify the account ID from the thread. - - Args: - client: The Supabase client - thread_id: The ID of the thread - - Returns: - str: The account ID associated with the thread - - Raises: - HTTPException: If the thread is not found or if there's an error - """ - try: - response = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute() - - if not response.data or len(response.data) == 0: - raise HTTPException( - status_code=404, - detail="Thread not found" - ) - - account_id = response.data[0].get('account_id') - - if not account_id: - raise HTTPException( - status_code=500, - detail="Thread has no associated account" - ) - - return account_id - - except Exception as e: - error_msg = str(e) - if "cannot schedule new futures after shutdown" in error_msg or "connection is closed" in error_msg: - raise HTTPException( - status_code=503, - detail="Server is shutting down" - ) - else: - raise HTTPException( - status_code=500, - detail=f"Error retrieving thread information: {str(e)}" - ) async def get_user_id_from_stream_auth( request: Request, @@ -246,7 +285,7 @@ async def get_user_id_from_stream_auth( try: # First, try the standard authentication (handles both API keys and Authorization header) try: - return await get_current_user_id_from_jwt(request) + return await verify_and_get_user_id_from_jwt(request) except HTTPException: # If standard auth fails, try query parameter JWT for EventSource compatibility pass @@ -254,8 +293,8 @@ async def get_user_id_from_stream_auth( # Try to get user_id from token in query param (for EventSource which can't set headers) if token: try: - # For Supabase JWT, we just need to decode and extract the user ID - payload = jwt.decode(token, options={"verify_signature": False}) + # For Supabase JWT, verify signature and extract the user ID + payload = _decode_jwt_safely(token) user_id = payload.get('sub') if user_id: sentry.sentry.set_user({ "id": user_id }) @@ -289,7 +328,76 @@ async def get_user_id_from_stream_auth( detail=f"Error during authentication: {str(e)}" ) -async def verify_thread_access(client, thread_id: str, user_id: str): +async def get_optional_user_id(request: Request) -> Optional[str]: + """ + Extract the user ID from the JWT in the Authorization header if present, + but don't require authentication. Returns None if no valid token is found. + + This function is used for endpoints that support both authenticated and + unauthenticated access (like public projects). + + Args: + request: The FastAPI request object + + Returns: + Optional[str]: The user ID extracted from the JWT, or None if no valid token + """ + auth_header = request.headers.get('Authorization') + + if not auth_header or not auth_header.startswith('Bearer '): + return None + + token = auth_header.split(' ')[1] + + try: + # For Supabase JWT, verify signature and extract the user ID + payload = _decode_jwt_safely(token) + + # Supabase stores the user ID in the 'sub' claim + user_id = payload.get('sub') + if user_id: + sentry.sentry.set_user({ "id": user_id }) + structlog.contextvars.bind_contextvars( + user_id=user_id + ) + + return user_id + except PyJWTError: + return None + +# Alias for consistency with other auth functions +get_optional_current_user_id_from_jwt = get_optional_user_id + +async def verify_and_get_agent_authorization(client, agent_id: str, user_id: str) -> dict: + """ + Verify that a user has access to a specific agent based on ownership. + + Args: + client: The Supabase client + agent_id: The agent ID to check access for + user_id: The user ID to check permissions for + + Returns: + dict: Agent data if access is granted + + Raises: + HTTPException: If the user doesn't have access to the agent or agent doesn't exist + """ + try: + agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute() + + if not agent_result.data: + raise HTTPException(status_code=404, detail="Agent not found or access denied") + + return agent_result.data[0] + + except HTTPException: + raise + except Exception as e: + structlog.error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to verify agent access") + +async def verify_and_authorize_thread_access(client, thread_id: str, user_id: str): """ Verify that a user has access to a specific thread based on account membership. @@ -347,92 +455,137 @@ async def verify_thread_access(client, thread_id: str, user_id: str): detail=f"Error verifying thread access: {str(e)}" ) -async def get_optional_user_id(request: Request) -> Optional[str]: +# ============================================================================ +# FastAPI Dependency Functions for Combined Authentication + Authorization +# ============================================================================ + +async def get_authorized_user_for_thread( + thread_id: str, + request: Request +) -> str: """ - Extract the user ID from the JWT in the Authorization header if present, - but don't require authentication. Returns None if no valid token is found. - - This function is used for endpoints that support both authenticated and - unauthenticated access (like public projects). + FastAPI dependency that verifies JWT and authorizes thread access. Args: + thread_id: The thread ID to authorize access for request: The FastAPI request object Returns: - Optional[str]: The user ID extracted from the JWT, or None if no valid token - """ - auth_header = request.headers.get('Authorization') - - if not auth_header or not auth_header.startswith('Bearer '): - return None - - token = auth_header.split(' ')[1] - - try: - # For Supabase JWT, we just need to decode and extract the user ID - payload = jwt.decode(token, options={"verify_signature": False}) - - # Supabase stores the user ID in the 'sub' claim - user_id = payload.get('sub') - if user_id: - sentry.sentry.set_user({ "id": user_id }) - structlog.contextvars.bind_contextvars( - user_id=user_id - ) - - return user_id - except PyJWTError: - return None - -# Alias for consistency with other auth functions -get_optional_current_user_id_from_jwt = get_optional_user_id - -async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)): - if not config.KORTIX_ADMIN_API_KEY: - raise HTTPException( - status_code=500, - detail="Admin API key not configured on server" - ) - - if not x_admin_api_key: - raise HTTPException( - status_code=401, - detail="Admin API key required. Include X-Admin-Api-Key header." - ) - - if x_admin_api_key != config.KORTIX_ADMIN_API_KEY: - raise HTTPException( - status_code=403, - detail="Invalid admin API key" - ) - - return True - -async def verify_agent_access(client, agent_id: str, user_id: str) -> dict: - """ - Verify that a user has access to a specific agent based on ownership. - - Args: - client: The Supabase client - agent_id: The agent ID to check access for - user_id: The user ID to check permissions for - - Returns: - dict: Agent data if access is granted + str: The authenticated and authorized user ID Raises: - HTTPException: If the user doesn't have access to the agent or agent doesn't exist + HTTPException: If authentication fails or user lacks thread access """ - try: - agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute() + from services.supabase import DBConnection + + # First, authenticate the user + user_id = await verify_and_get_user_id_from_jwt(request) + + # Then, authorize thread access + db = DBConnection() + client = await db.client + await verify_and_authorize_thread_access(client, thread_id, user_id) + + return user_id + +async def get_authorized_user_for_agent( + agent_id: str, + request: Request +) -> tuple[str, dict]: + """ + FastAPI dependency that verifies JWT and authorizes agent access. + + Args: + agent_id: The agent ID to authorize access for + request: The FastAPI request object - if not agent_result.data: - raise HTTPException(status_code=404, detail="Agent not found or access denied") + Returns: + tuple[str, dict]: The authenticated user ID and agent data - return agent_result.data[0] + Raises: + HTTPException: If authentication fails or user lacks agent access + """ + from services.supabase import DBConnection + + # First, authenticate the user + user_id = await verify_and_get_user_id_from_jwt(request) + + # Then, authorize agent access and get agent data + db = DBConnection() + client = await db.client + agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id) + + return user_id, agent_data + +class AuthorizedThreadAccess: + """ + FastAPI dependency that combines authentication and thread authorization. + + Usage: + @router.get("/threads/{thread_id}/messages") + async def get_messages( + thread_id: str, + auth: AuthorizedThreadAccess = Depends() + ): + user_id = auth.user_id # Authenticated and authorized user + """ + def __init__(self, user_id: str): + self.user_id = user_id + +class AuthorizedAgentAccess: + """ + FastAPI dependency that combines authentication and agent authorization. + + Usage: + @router.get("/agents/{agent_id}/config") + async def get_agent_config( + agent_id: str, + auth: AuthorizedAgentAccess = Depends() + ): + user_id = auth.user_id # Authenticated and authorized user + agent_data = auth.agent_data # Agent data from authorization check + """ + def __init__(self, user_id: str, agent_data: dict): + self.user_id = user_id + self.agent_data = agent_data + +async def require_thread_access( + thread_id: str, + request: Request +) -> AuthorizedThreadAccess: + """ + FastAPI dependency that verifies JWT and authorizes thread access. + + Args: + thread_id: The thread ID from the path parameter + request: The FastAPI request object - except HTTPException: - raise - except Exception as e: - structlog.error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}") - raise HTTPException(status_code=500, detail="Failed to verify agent access") + Returns: + AuthorizedThreadAccess: Object containing authenticated user_id + + Raises: + HTTPException: If authentication fails or user lacks thread access + """ + user_id = await get_authorized_user_for_thread(thread_id, request) + return AuthorizedThreadAccess(user_id) + +async def require_agent_access( + agent_id: str, + request: Request +) -> AuthorizedAgentAccess: + """ + FastAPI dependency that verifies JWT and authorizes agent access. + + Args: + agent_id: The agent ID from the path parameter + request: The FastAPI request object + + Returns: + AuthorizedAgentAccess: Object containing user_id and agent_data + + Raises: + HTTPException: If authentication fails or user lacks agent access + """ + user_id, agent_data = await get_authorized_user_for_agent(agent_id, request) + return AuthorizedAgentAccess(user_id, agent_data) + diff --git a/backend/utils/config.py b/backend/utils/config.py index d782e2a9..e36f031b 100644 --- a/backend/utils/config.py +++ b/backend/utils/config.py @@ -16,6 +16,7 @@ Usage: import os from enum import Enum +from re import S from typing import Dict, Any, Optional, get_type_hints, Union from dotenv import load_dotenv import logging @@ -277,6 +278,7 @@ class Configuration: SUPABASE_URL: str SUPABASE_ANON_KEY: str SUPABASE_SERVICE_ROLE_KEY: str + SUPABASE_JWT_SECRET: str # Redis configuration REDIS_HOST: str diff --git a/setup.py b/setup.py index 6209fff6..6e6225b5 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def load_existing_env_vars(): "SUPABASE_SERVICE_ROLE_KEY": backend_env.get( "SUPABASE_SERVICE_ROLE_KEY", "" ), + "SUPABASE_JWT_SECRET": backend_env.get("SUPABASE_JWT_SECRET", ""), }, "daytona": { "DAYTONA_API_KEY": backend_env.get("DAYTONA_API_KEY", ""), @@ -285,8 +286,17 @@ class SetupWizard: config_items = [] # Check Supabase - if self.env_vars["supabase"]["SUPABASE_URL"]: - config_items.append(f"{Colors.GREEN}✓{Colors.ENDC} Supabase") + supabase_complete = ( + self.env_vars["supabase"]["SUPABASE_URL"] and + self.env_vars["supabase"]["SUPABASE_ANON_KEY"] and + self.env_vars["supabase"]["SUPABASE_SERVICE_ROLE_KEY"] + ) + supabase_secure = self.env_vars["supabase"]["SUPABASE_JWT_SECRET"] + + if supabase_complete and supabase_secure: + config_items.append(f"{Colors.GREEN}✓{Colors.ENDC} Supabase (secure)") + elif supabase_complete: + config_items.append(f"{Colors.YELLOW}⚠{Colors.ENDC} Supabase (missing JWT secret)") else: config_items.append(f"{Colors.YELLOW}○{Colors.ENDC} Supabase") @@ -600,8 +610,12 @@ class SetupWizard: "You'll need a Supabase project. Visit https://supabase.com/dashboard/projects to create one." ) print_info( - "In your project settings, go to 'API' to find the required information." + "In your project settings, go to 'API' to find the required information:" ) + print_info(" - Project URL (at the top)") + print_info(" - anon public key (under 'Project API keys')") + print_info(" - service_role secret key (under 'Project API keys')") + print_info(" - JWT Secret (under 'JWT Settings' - critical for security!)") input("Press Enter to continue once you have your project details...") self.env_vars["supabase"]["SUPABASE_URL"] = self._get_input( @@ -622,6 +636,12 @@ class SetupWizard: "This does not look like a valid key. It should be at least 10 characters.", default_value=self.env_vars["supabase"]["SUPABASE_SERVICE_ROLE_KEY"], ) + self.env_vars["supabase"]["SUPABASE_JWT_SECRET"] = self._get_input( + "Enter your Supabase JWT secret (for signature verification): ", + validate_api_key, + "This does not look like a valid JWT secret. It should be at least 10 characters.", + default_value=self.env_vars["supabase"]["SUPABASE_JWT_SECRET"], + ) print_success("Supabase information saved.") def collect_daytona_info(self):