security wip

This commit is contained in:
marko-kraemer 2025-08-31 20:59:45 -07:00
parent 2fea768be9
commit e01aa2e332
24 changed files with 501 additions and 313 deletions

View File

@ -1,7 +1,7 @@
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Query
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
from utils.config import config, EnvMode
from utils.pagination import PaginationParams
@ -21,7 +21,7 @@ router = APIRouter()
async def update_agent(
agent_id: str,
agent_data: AgentUpdateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Updating agent {agent_id} for user: {user_id}")
@ -436,7 +436,7 @@ async def update_agent(
raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}")
@router.delete("/agents/{agent_id}")
async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def delete_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
logger.debug(f"Deleting agent: {agent_id}")
client = await utils.db.client
@ -502,7 +502,7 @@ async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id
@router.get("/agents", response_model=AgentsResponse)
async def get_agents(
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
limit: Optional[int] = Query(20, ge=1, le=100, description="Number of items per page"),
search: Optional[str] = Query(None, description="Search in name and description"),
@ -570,7 +570,7 @@ async def get_agents(
raise HTTPException(status_code=500, detail=f"Failed to fetch agents: {str(e)}")
@router.get("/agents/{agent_id}", response_model=AgentResponse)
async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def get_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
logger.debug(f"Fetching agent {agent_id} for user: {user_id}")
@ -697,7 +697,7 @@ async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_fr
@router.post("/agents", response_model=AgentResponse)
async def create_agent(
agent_data: AgentCreateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Creating new agent for user: {user_id}")
client = await utils.db.client

View File

@ -3,7 +3,7 @@ from typing import Optional, Dict, List, Any
from fastapi import APIRouter, HTTPException, Depends
from uuid import uuid4
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
from templates.template_service import MCPRequirementValue, ConfigType, ProfileId, QualifiedName
@ -343,7 +343,7 @@ class JsonImportService:
raise # Re-raise the exception to ensure import fails if version creation fails
@router.get("/agents/{agent_id}/export")
async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def export_agent(agent_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
"""Export an agent configuration as JSON"""
logger.debug(f"Exporting agent {agent_id} for user: {user_id}")
@ -419,7 +419,7 @@ async def export_agent(agent_id: str, user_id: str = Depends(get_current_user_id
@router.post("/agents/json/analyze", response_model=JsonAnalysisResponse)
async def analyze_json_for_import(
request: JsonAnalysisRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Analyze imported JSON to determine required credentials and configurations"""
logger.debug(f"Analyzing JSON for import - user: {user_id}")
@ -439,7 +439,7 @@ async def analyze_json_for_import(
@router.post("/agents/json/import", response_model=JsonImportResponse)
async def import_agent_from_json(
request: JsonImportRequestModel,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Importing agent from JSON - user: {user_id}")

View File

@ -8,7 +8,7 @@ from typing import Optional, List
from fastapi import APIRouter, HTTPException, Depends, Request, Body, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
from utils.auth_utils import verify_and_get_user_id_from_jwt, get_user_id_from_stream_auth, verify_and_authorize_thread_access
from utils.logger import logger, structlog
from services.billing import check_billing_status, can_use_model
from utils.config import config
@ -32,7 +32,7 @@ router = APIRouter()
async def start_agent(
thread_id: str,
body: AgentStartRequest = Body(...),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Start an agent for a specific thread in the background"""
structlog.contextvars.bind_contextvars(
@ -67,7 +67,7 @@ async def start_agent(
thread_metadata = thread_data.get('metadata', {})
if account_id != user_id:
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
structlog.contextvars.bind_contextvars(
project_id=project_id,
@ -241,7 +241,7 @@ async def start_agent(
return {"agent_run_id": agent_run_id, "status": "running"}
@router.post("/agent-run/{agent_run_id}/stop")
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def stop_agent(agent_run_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
"""Stop a running agent."""
structlog.contextvars.bind_contextvars(
agent_run_id=agent_run_id,
@ -253,20 +253,20 @@ async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_
return {"status": "stopped"}
@router.get("/thread/{thread_id}/agent-runs")
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def get_agent_runs(thread_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
"""Get all agent runs for a thread."""
structlog.contextvars.bind_contextvars(
thread_id=thread_id,
)
logger.debug(f"Fetching agent runs for thread: {thread_id}")
client = await utils.db.client
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
agent_runs = await client.table('agent_runs').select('id, thread_id, status, started_at, completed_at, error, created_at, updated_at').eq("thread_id", thread_id).order('created_at', desc=True).execute()
logger.debug(f"Found {len(agent_runs.data)} agent runs for thread: {thread_id}")
return {"agent_runs": agent_runs.data}
@router.get("/agent-run/{agent_run_id}")
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def get_agent_run(agent_run_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
"""Get agent run status and responses."""
structlog.contextvars.bind_contextvars(
agent_run_id=agent_run_id,
@ -285,7 +285,7 @@ async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_us
}
@router.get("/thread/{thread_id}/agent", response_model=ThreadAgentResponse)
async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
async def get_thread_agent(thread_id: str, user_id: str = Depends(verify_and_get_user_id_from_jwt)):
"""Get the agent details for a specific thread. Since threads are fully agent-agnostic,
this returns the most recently used agent from agent_runs only."""
structlog.contextvars.bind_contextvars(
@ -296,7 +296,7 @@ async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_us
try:
# Verify thread access and get thread data
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
if not thread_result.data:
@ -621,7 +621,7 @@ async def initiate_agent_with_files(
enable_context_manager: Optional[bool] = Form(False),
agent_id: Optional[str] = Form(None), # Add agent_id parameter
files: List[UploadFile] = File(default=[]),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""
Initiate a new agent session with optional file attachments.

View File

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, Request, Body
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
from sandbox.sandbox import get_or_start_sandbox
from services.supabase import DBConnection
@ -19,7 +19,7 @@ router = APIRouter()
async def get_custom_mcp_tools_for_agent(
agent_id: str,
request: Request,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Getting custom MCP tools for agent {agent_id}, user {user_id}")
try:
@ -98,7 +98,7 @@ async def get_custom_mcp_tools_for_agent(
async def update_custom_mcp_tools_for_agent(
agent_id: str,
request: dict,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Updating custom MCP tools for agent {agent_id}, user {user_id}")
@ -209,7 +209,7 @@ async def update_custom_mcp_tools_for_agent(
async def update_agent_custom_mcps(
agent_id: str,
request: dict,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Updating agent {agent_id} custom MCPs for user {user_id}")
@ -316,7 +316,7 @@ async def update_agent_custom_mcps(
@router.get("/agents/{agent_id}/tools")
async def get_agent_tools(
agent_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Fetching enabled tools for agent: {agent_id} by user: {user_id}")

View File

@ -5,7 +5,7 @@ from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Form, Query
from utils.auth_utils import get_current_user_id_from_jwt, verify_thread_access
from utils.auth_utils import verify_and_get_user_id_from_jwt, verify_and_authorize_thread_access, require_thread_access, AuthorizedThreadAccess
from utils.logger import logger
from sandbox.sandbox import create_sandbox, delete_sandbox
@ -16,7 +16,7 @@ router = APIRouter()
@router.get("/threads")
async def get_user_threads(
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)")
):
@ -121,14 +121,15 @@ async def get_user_threads(
@router.get("/threads/{thread_id}")
async def get_thread(
thread_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
auth: AuthorizedThreadAccess = Depends(require_thread_access)
):
"""Get a specific thread by ID with complete related data."""
logger.debug(f"Fetching thread: {thread_id}")
client = await utils.db.client
user_id = auth.user_id # Already authenticated and authorized!
try:
await verify_thread_access(client, thread_id, user_id)
# No need for manual authorization - it's already done in the dependency!
# Get the thread data
thread_result = await client.table('threads').select('*').eq('thread_id', thread_id).execute()
@ -200,7 +201,7 @@ async def get_thread(
@router.post("/threads", response_model=CreateThreadResponse)
async def create_thread(
name: Optional[str] = Form(None),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""
Create a new thread without starting an agent run.
@ -303,13 +304,13 @@ async def create_thread(
@router.get("/threads/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'")
):
"""Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries."""
logger.debug(f"Fetching all messages for thread: {thread_id}, order={order}")
client = await utils.db.client
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
try:
batch_size = 1000
offset = 0
@ -333,7 +334,7 @@ async def get_thread_messages(
@router.get("/agent-runs/{agent_run_id}")
async def get_agent_run(
agent_run_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
):
"""
[DEPRECATED] Get an agent run by ID.
@ -355,12 +356,12 @@ async def get_agent_run(
async def add_message_to_thread(
thread_id: str,
message: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
):
"""Add a message to a thread"""
logger.debug(f"Adding message to thread: {thread_id}")
client = await utils.db.client
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
try:
message_result = await client.table('messages').insert({
'thread_id': thread_id,
@ -380,14 +381,14 @@ async def add_message_to_thread(
async def create_message(
thread_id: str,
message_data: MessageCreateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Create a new message in a thread."""
logger.debug(f"Creating message in thread: {thread_id}")
client = await utils.db.client
try:
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
message_payload = {
"role": "user" if message_data.type == "user" else "assistant",
@ -421,12 +422,12 @@ async def create_message(
async def delete_message(
thread_id: str,
message_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Delete a message from a thread."""
logger.debug(f"Deleting message from thread: {thread_id}")
client = await utils.db.client
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
try:
# Don't allow users to delete the "status" messages
await client.table('messages').delete().eq('message_id', message_id).eq('is_llm_message', True).eq('thread_id', thread_id).execute()

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from utils.logger import logger
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from .version_service import (
get_version_service,
@ -61,7 +61,7 @@ class VersionComparisonResponse(BaseModel):
@router.get("/agents/{agent_id}/versions", response_model=List[VersionResponse])
async def get_versions(
agent_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:
@ -80,7 +80,7 @@ async def get_versions(
async def create_version(
agent_id: str,
request: CreateVersionRequest,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:
@ -112,7 +112,7 @@ async def create_version(
async def get_version(
agent_id: str,
version_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:
@ -131,7 +131,7 @@ async def get_version(
async def activate_version(
agent_id: str,
version_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:
@ -153,7 +153,7 @@ async def compare_versions(
agent_id: str,
version1_id: str,
version2_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:
@ -179,7 +179,7 @@ async def compare_versions(
async def rollback_to_version(
agent_id: str,
version_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:
@ -202,7 +202,7 @@ async def update_version_details(
agent_id: str,
version_id: str,
request: UpdateVersionDetailsRequest,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
version_service: VersionService = Depends(get_version_service)
):
try:

View File

@ -80,7 +80,7 @@ class VersionService:
async def _get_client(self):
return await self.db.client
async def _verify_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]:
async def _verify_and_authorize_agent_access(self, agent_id: str, user_id: str) -> tuple[bool, bool]:
if user_id == "system":
return True, True
@ -205,7 +205,7 @@ class VersionService:
logger.debug(f"Creating version for agent {agent_id}")
client = await self.db.client
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("Unauthorized to create version for this agent")
@ -298,7 +298,7 @@ class VersionService:
return version
async def get_version(self, agent_id: str, version_id: str, user_id: str) -> AgentVersion:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner and not is_public:
raise UnauthorizedError("You don't have permission to view this version")
@ -314,7 +314,7 @@ class VersionService:
return self._version_from_db_row(result.data[0])
async def get_active_version(self, agent_id: str, user_id: str = "system") -> Optional[AgentVersion]:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner and not is_public:
raise UnauthorizedError("You don't have permission to view this agent")
@ -341,7 +341,7 @@ class VersionService:
return version
async def get_all_versions(self, agent_id: str, user_id: str) -> List[AgentVersion]:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
is_owner, is_public = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner and not is_public:
raise UnauthorizedError("You don't have permission to view versions")
@ -355,7 +355,7 @@ class VersionService:
return versions
async def activate_version(self, agent_id: str, version_id: str, user_id: str) -> None:
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to activate versions")
@ -458,7 +458,7 @@ class VersionService:
) -> AgentVersion:
version_to_restore = await self.get_version(agent_id, version_id, user_id)
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to rollback versions")
@ -483,7 +483,7 @@ class VersionService:
version_name: Optional[str] = None,
change_description: Optional[str] = None
) -> AgentVersion:
is_owner, _ = await self._verify_agent_access(agent_id, user_id)
is_owner, _ = await self._verify_and_authorize_agent_access(agent_id, user_id)
if not is_owner:
raise UnauthorizedError("You don't have permission to update this version")

View File

@ -21,7 +21,7 @@ from agent.tools.expand_msg_tool import ExpandMessageTool
from agent.prompts.prompt import get_system_prompt
from utils.logger import logger
from utils.auth_utils import get_account_id_from_thread
from services.billing import check_billing_status
from agent.tools.sb_vision_tool import SandboxVisionTool
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
@ -448,9 +448,16 @@ class AgentRunner:
)
self.client = await self.thread_manager.db.client
self.account_id = await get_account_id_from_thread(self.client, self.config.thread_id)
response = await self.client.table('threads').select('account_id').eq('thread_id', self.config.thread_id).execute()
if not response.data or len(response.data) == 0:
raise ValueError(f"Thread {self.config.thread_id} not found")
self.account_id = response.data[0].get('account_id')
if not self.account_id:
raise ValueError("Could not determine account ID for thread")
raise ValueError(f"Thread {self.config.thread_id} has no associated account")
project = await self.client.table('projects').select('*').eq('project_id', self.config.project_id).execute()
if not project.data or len(project.data) == 0:

View File

@ -7,7 +7,7 @@ from fastapi import HTTPException
from utils.cache import Cache
from utils.logger import logger
from utils.config import config
from utils.auth_utils import verify_thread_access
from utils.auth_utils import verify_and_authorize_thread_access
from services import redis
from services.supabase import DBConnection
from services.llm import make_llm_api_call
@ -121,7 +121,7 @@ async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: st
account_id = agent_run_data['threads']['account_id']
if account_id == user_id:
return agent_run_data
await verify_thread_access(client, thread_id, user_id)
await verify_and_authorize_thread_access(client, thread_id, user_id)
return agent_run_data
async def generate_and_update_project_name(project_id: str, prompt: str):

View File

@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
from typing import Dict, Any, Optional
from pydantic import BaseModel
from uuid import uuid4
from utils.auth_utils import get_current_user_id_from_jwt, get_optional_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt, get_optional_current_user_id_from_jwt
from utils.logger import logger
from services.supabase import DBConnection
from datetime import datetime
@ -215,7 +215,7 @@ class ProfileResponse(BaseModel):
@router.get("/categories")
async def list_categories(
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
logger.debug("Fetching Composio categories")
@ -240,7 +240,7 @@ async def list_toolkits(
cursor: Optional[str] = Query(None),
search: Optional[str] = Query(None),
category: Optional[str] = Query(None),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
logger.debug(f"Fetching Composio toolkits with limit: {limit}, cursor: {cursor}, search: {search}, category: {category}")
@ -270,7 +270,7 @@ async def list_toolkits(
@router.get("/toolkits/{toolkit_slug}/details")
async def get_toolkit_details(
toolkit_slug: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
logger.debug(f"Fetching detailed toolkit info for: {toolkit_slug}")
@ -296,7 +296,7 @@ async def get_toolkit_details(
@router.post("/integrate", response_model=IntegrationStatusResponse)
async def integrate_toolkit(
request: IntegrateToolkitRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> IntegrationStatusResponse:
try:
integration_user_id = str(uuid4())
@ -333,7 +333,7 @@ async def integrate_toolkit(
@router.post("/profiles", response_model=ProfileResponse)
async def create_profile(
request: CreateProfileRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> ProfileResponse:
try:
integration_user_id = str(uuid4())
@ -378,7 +378,7 @@ async def create_profile(
@router.get("/profiles")
async def get_profiles(
toolkit_slug: Optional[str] = Query(None),
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
profile_service = ComposioProfileService(db)
@ -403,7 +403,7 @@ async def get_profiles(
@router.get("/profiles/{profile_id}/mcp-config")
async def get_profile_mcp_config(
profile_id: str,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
profile_service = ComposioProfileService(db)
@ -423,7 +423,7 @@ async def get_profile_mcp_config(
@router.get("/profiles/{profile_id}")
async def get_profile_info(
profile_id: str,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
profile_service = ComposioProfileService(db)
@ -452,7 +452,7 @@ async def get_profile_info(
@router.get("/integration/{connected_account_id}/status")
async def get_integration_status(
connected_account_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
service = get_integration_service()
@ -466,7 +466,7 @@ async def get_integration_status(
@router.post("/profiles/{profile_id}/discover-tools")
async def discover_composio_tools(
profile_id: str,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
try:
profile_service = ComposioProfileService(db)
@ -509,7 +509,7 @@ async def discover_composio_tools(
@router.post("/discover-tools/{profile_id}")
async def discover_tools_post(
profile_id: str,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict[str, Any]:
return await discover_composio_tools(profile_id, current_user_id)
@ -544,7 +544,7 @@ async def get_toolkit_icon(
@router.post("/tools/list")
async def list_toolkit_tools(
request: ToolsListRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
logger.debug(f"User {current_user_id} requesting tools for toolkit: {request.toolkit_slug}")
@ -572,7 +572,7 @@ async def list_toolkit_tools(
@router.get("/triggers/apps")
async def list_apps_with_triggers(
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
) -> Dict[str, Any]:
try:
trigger_service = ComposioTriggerService()
@ -588,7 +588,7 @@ async def list_apps_with_triggers(
@router.get("/triggers/apps/{toolkit_slug}")
async def list_triggers_for_app(
toolkit_slug: str,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
) -> Dict[str, Any]:
try:
trigger_service = ComposioTriggerService()
@ -617,7 +617,7 @@ class CreateComposioTriggerRequest(BaseModel):
@router.post("/triggers/create")
async def create_composio_trigger(req: CreateComposioTriggerRequest, current_user_id: str = Depends(get_current_user_id_from_jwt)) -> Dict[str, Any]:
async def create_composio_trigger(req: CreateComposioTriggerRequest, current_user_id: str = Depends(verify_and_get_user_id_from_jwt)) -> Dict[str, Any]:
try:
client_db = await db.client
agent_check = await client_db.table('agents').select('agent_id').eq('agent_id', req.agent_id).eq('account_id', current_user_id).execute()

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, validator
import urllib.parse
from utils.logger import logger
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from services.supabase import DBConnection
from .credential_service import (
@ -116,7 +116,7 @@ def initialize(database: DBConnection):
@router.post("/credentials", response_model=CredentialResponse)
async def store_credential(
request: StoreCredentialRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
credential_service = get_credential_service(db)
@ -151,7 +151,7 @@ async def store_credential(
@router.get("/credentials", response_model=List[CredentialResponse])
async def get_user_credentials(
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
credential_service = get_credential_service(db)
@ -178,7 +178,7 @@ async def get_user_credentials(
@router.delete("/credentials/{mcp_qualified_name:path}")
async def delete_credential(
mcp_qualified_name: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
decoded_name = decode_mcp_qualified_name(mcp_qualified_name)
@ -201,7 +201,7 @@ async def delete_credential(
@router.post("/credential-profiles", response_model=CredentialProfileResponse)
async def store_credential_profile(
request: StoreCredentialProfileRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -240,7 +240,7 @@ async def store_credential_profile(
@router.get("/credential-profiles", response_model=List[CredentialProfileResponse])
async def get_user_credential_profiles(
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -269,7 +269,7 @@ async def get_user_credential_profiles(
@router.get("/credential-profiles/{mcp_qualified_name:path}", response_model=List[CredentialProfileResponse])
async def get_credential_profiles_for_mcp(
mcp_qualified_name: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
decoded_name = decode_mcp_qualified_name(mcp_qualified_name)
@ -300,7 +300,7 @@ async def get_credential_profiles_for_mcp(
@router.get("/credential-profiles/profile/{profile_id}", response_model=CredentialProfileResponse)
async def get_credential_profile(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -331,7 +331,7 @@ async def get_credential_profile(
@router.put("/credential-profiles/{profile_id}/set-default")
async def set_default_credential_profile(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -350,7 +350,7 @@ async def set_default_credential_profile(
@router.delete("/credential-profiles/{profile_id}")
async def delete_credential_profile(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -369,7 +369,7 @@ async def delete_credential_profile(
@router.post("/credential-profiles/bulk-delete", response_model=BulkDeleteProfilesResponse)
async def bulk_delete_credential_profiles(
request: BulkDeleteProfilesRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -399,7 +399,7 @@ async def bulk_delete_credential_profiles(
@router.get("/composio-profiles", response_model=ComposioCredentialsResponse)
async def get_composio_profiles(
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
profile_service = get_profile_service(db)
@ -482,7 +482,7 @@ async def get_composio_profiles(
@router.get("/composio-profiles/{profile_id}/mcp-url", response_model=ComposioMcpUrlResponse)
async def get_composio_mcp_url(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
from composio_integration.composio_profile_service import ComposioProfileService

View File

@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Depends
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
from services.supabase import DBConnection
from .google_slides_service import GoogleSlidesService, OAuthTokenService
@ -92,7 +92,7 @@ presentation_router = APIRouter(prefix="/presentation-tools", tags=["presentatio
@oauth_router.get("/auth-url", response_model=AuthURLResponse)
async def get_google_auth_url(
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
return_url: Optional[str] = Query(None, description="URL to redirect to after OAuth"),
google_service: GoogleSlidesService = Depends(get_google_slides_service)
):
@ -172,7 +172,7 @@ async def google_oauth_callback(
# UNUSED: Disconnect endpoint - frontend never calls this
# @oauth_router.post("/disconnect")
# async def disconnect_google(
# user_id: str = Depends(get_current_user_id_from_jwt),
# user_id: str = Depends(verify_and_get_user_id_from_jwt),
# google_service: GoogleSlidesService = Depends(get_google_slides_service)
# ):
# """
@ -196,7 +196,7 @@ async def google_oauth_callback(
@presentation_router.post("/convert-and-upload-to-slides", response_model=ConvertToSlidesResponse)
async def convert_and_upload_to_google_slides(
request: ConvertToSlidesRequest,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
google_service: GoogleSlidesService = Depends(get_google_slides_service)
):
"""

View File

@ -2,7 +2,7 @@ import json
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, BackgroundTasks
from pydantic import BaseModel, Field, HttpUrl
from utils.auth_utils import get_current_user_id_from_jwt, verify_agent_access
from utils.auth_utils import verify_and_get_user_id_from_jwt, verify_and_get_agent_authorization, require_agent_access, AuthorizedAgentAccess
from services.supabase import DBConnection
from knowledge_base.file_processor import FileProcessor
from utils.logger import logger
@ -69,15 +69,16 @@ db = DBConnection()
async def get_agent_knowledge_base(
agent_id: str,
include_inactive: bool = False,
user_id: str = Depends(get_current_user_id_from_jwt)
auth: AuthorizedAgentAccess = Depends(require_agent_access)
):
"""Get all knowledge base entries for an agent"""
try:
client = await db.client
user_id = auth.user_id # Already authenticated and authorized!
agent_data = auth.agent_data # Agent data already fetched during authorization
# Verify agent access
await verify_agent_access(client, agent_id, user_id)
# No need for manual authorization - it's already done in the dependency!
result = await client.rpc('get_agent_knowledge_base', {
'p_agent_id': agent_id,
@ -122,7 +123,7 @@ async def get_agent_knowledge_base(
async def create_agent_knowledge_base_entry(
agent_id: str,
entry_data: CreateKnowledgeBaseEntryRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Create a new knowledge base entry for an agent"""
@ -130,7 +131,7 @@ async def create_agent_knowledge_base_entry(
client = await db.client
# Verify agent access and get agent data
agent_data = await verify_agent_access(client, agent_id, user_id)
agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
account_id = agent_data['account_id']
insert_data = {
@ -172,7 +173,7 @@ async def upload_file_to_agent_kb(
agent_id: str,
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Upload and process a file for agent knowledge base"""
@ -180,7 +181,7 @@ async def upload_file_to_agent_kb(
client = await db.client
# Verify agent access and get agent data
agent_data = await verify_agent_access(client, agent_id, user_id)
agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
account_id = agent_data['account_id']
file_content = await file.read()
@ -226,7 +227,7 @@ async def upload_file_to_agent_kb(
async def update_knowledge_base_entry(
entry_id: str,
entry_data: UpdateKnowledgeBaseEntryRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Update an agent knowledge base entry"""
@ -243,7 +244,7 @@ async def update_knowledge_base_entry(
agent_id = entry['agent_id']
# Verify agent access
await verify_agent_access(client, agent_id, user_id)
await verify_and_get_agent_authorization(client, agent_id, user_id)
update_data = {}
if entry_data.name is not None:
@ -294,7 +295,7 @@ async def update_knowledge_base_entry(
@router.delete("/{entry_id}")
async def delete_knowledge_base_entry(
entry_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Delete an agent knowledge base entry"""
@ -311,7 +312,7 @@ async def delete_knowledge_base_entry(
agent_id = entry['agent_id']
# Verify agent access
await verify_agent_access(client, agent_id, user_id)
await verify_and_get_agent_authorization(client, agent_id, user_id)
result = await client.table('agent_knowledge_base_entries').delete().eq('entry_id', entry_id).execute()
@ -329,7 +330,7 @@ async def delete_knowledge_base_entry(
@router.get("/{entry_id}", response_model=KnowledgeBaseEntryResponse)
async def get_knowledge_base_entry(
entry_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get a specific agent knowledge base entry"""
try:
@ -345,7 +346,7 @@ async def get_knowledge_base_entry(
agent_id = entry['agent_id']
# Verify agent access
await verify_agent_access(client, agent_id, user_id)
await verify_and_get_agent_authorization(client, agent_id, user_id)
logger.debug(f"Retrieved agent knowledge base entry {entry_id} for agent {agent_id}")
@ -376,7 +377,7 @@ async def get_knowledge_base_entry(
async def get_agent_processing_jobs(
agent_id: str,
limit: int = 10,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get processing jobs for an agent"""
@ -384,7 +385,7 @@ async def get_agent_processing_jobs(
client = await db.client
# Verify agent access
await verify_agent_access(client, agent_id, user_id)
await verify_and_get_agent_authorization(client, agent_id, user_id)
result = await client.rpc('get_agent_kb_processing_jobs', {
'p_agent_id': agent_id,
@ -468,7 +469,7 @@ async def process_file_background(
async def get_agent_knowledge_base_context(
agent_id: str,
max_tokens: int = 4000,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get knowledge base context for agent prompts"""
@ -476,7 +477,7 @@ async def get_agent_knowledge_base_context(
client = await db.client
# Verify agent access
await verify_agent_access(client, agent_id, user_id)
await verify_and_get_agent_authorization(client, agent_id, user_id)
result = await client.rpc('get_agent_knowledge_base_context', {
'p_agent_id': agent_id,

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
from .mcp_service import mcp_service, MCPException

View File

@ -5,7 +5,7 @@ from uuid import UUID
from datetime import datetime
from utils.logger import logger
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from .profile_service import ProfileService, Profile, ProfileServiceError, ProfileNotFoundError, ProfileAlreadyExistsError, InvalidConfigError, EncryptionError
from .connection_service import ConnectionService
from .app_service import get_app_service
@ -162,7 +162,7 @@ def _handle_pipedream_exception(e: Exception) -> HTTPException:
@router.post("/connection-token", response_model=ConnectionTokenResponse)
async def create_connection_token(
request: CreateConnectionTokenRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Creating Pipedream connection token for user: {user_id}, app: {request.app}")
@ -190,7 +190,7 @@ async def create_connection_token(
@router.get("/connections", response_model=ConnectionResponse)
async def get_user_connections(
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Getting connections for user: {user_id}")
@ -228,7 +228,7 @@ async def get_user_connections(
@router.post("/mcp/discover", response_model=MCPDiscoveryResponse)
async def discover_mcp_servers(
request: MCPDiscoveryRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Discovering MCP servers for user: {user_id}, app: {request.app_slug}")
@ -277,7 +277,7 @@ async def discover_mcp_servers(
@router.post("/mcp/discover-profile", response_model=MCPDiscoveryResponse)
async def discover_mcp_servers_for_profile(
request: MCPProfileDiscoveryRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Discovering MCP servers for profile: {request.external_user_id}")
@ -326,7 +326,7 @@ async def discover_mcp_servers_for_profile(
@router.post("/mcp/connect", response_model=MCPConnectionResponse)
async def create_mcp_connection(
request: MCPConnectionRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Creating MCP connection for user: {user_id}, app: {request.app_slug}")
@ -535,7 +535,7 @@ async def get_app_tools(app_slug: str):
@router.post("/profiles", response_model=ProfileResponse)
async def create_credential_profile(
request: ProfileRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Creating credential profile for user: {user_id}, app: {request.app_slug}")
@ -563,7 +563,7 @@ async def create_credential_profile(
async def get_credential_profiles(
app_slug: Optional[str] = Query(None),
is_active: Optional[bool] = Query(None),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Getting credential profiles for user: {user_id}, app: {app_slug}")
@ -581,7 +581,7 @@ async def get_credential_profiles(
@router.get("/profiles/{profile_id}", response_model=ProfileResponse)
async def get_credential_profile(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Getting credential profile: {profile_id} for user: {user_id}")
@ -603,7 +603,7 @@ async def get_credential_profile(
async def update_credential_profile(
profile_id: str,
request: UpdateProfileRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Updating credential profile: {profile_id} for user: {user_id}")
@ -628,7 +628,7 @@ async def update_credential_profile(
@router.delete("/profiles/{profile_id}")
async def delete_credential_profile(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Deleting credential profile: {profile_id} for user: {user_id}")
@ -649,7 +649,7 @@ async def delete_credential_profile(
async def connect_credential_profile(
profile_id: str,
app: Optional[str] = Query(None),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Connecting credential profile: {profile_id} for user: {user_id}")
@ -686,7 +686,7 @@ async def connect_credential_profile(
@router.get("/profiles/{profile_id}/connections")
async def get_profile_connections(
profile_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
logger.debug(f"Getting connections for profile: {profile_id}, user: {user_id}")

View File

@ -18,7 +18,7 @@ from services.api_keys import (
APIKeyCreateResponse,
)
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
router = APIRouter()
@ -61,7 +61,7 @@ async def get_account_id_from_user_id(user_id: str) -> UUID:
@router.post("/api-keys", response_model=APIKeyCreateResponse)
async def create_api_key(
request: APIKeyCreateRequest,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
api_key_service: APIKeyService = Depends(get_api_key_service),
):
"""
@ -105,7 +105,7 @@ async def create_api_key(
@router.get("/api-keys", response_model=List[APIKeyResponse])
async def list_api_keys(
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
api_key_service: APIKeyService = Depends(get_api_key_service),
):
"""
@ -141,7 +141,7 @@ async def list_api_keys(
@router.patch("/api-keys/{key_id}/revoke")
async def revoke_api_key(
key_id: UUID,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
api_key_service: APIKeyService = Depends(get_api_key_service),
):
"""
@ -185,7 +185,7 @@ async def revoke_api_key(
@router.delete("/api-keys/{key_id}")
async def delete_api_key(
key_id: UUID,
user_id: str = Depends(get_current_user_id_from_jwt),
user_id: str = Depends(verify_and_get_user_id_from_jwt),
api_key_service: APIKeyService = Depends(get_api_key_service),
):
"""

View File

@ -15,7 +15,7 @@ from utils.cache import Cache
from utils.logger import logger
from utils.config import config, EnvMode
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from pydantic import BaseModel
from models import model_manager
from litellm.cost_calculator import cost_per_token
@ -1234,7 +1234,7 @@ async def handle_usage_with_credits(
@router.post("/create-checkout-session")
async def create_checkout_session(
request: CreateCheckoutSessionRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Create a Stripe Checkout session or modify an existing subscription."""
try:
@ -1555,7 +1555,7 @@ async def create_checkout_session(
@router.post("/create-portal-session")
async def create_portal_session(
request: CreatePortalSessionRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Create a Stripe Customer Portal session for subscription management."""
try:
@ -1655,7 +1655,7 @@ async def create_portal_session(
@router.get("/subscription")
async def get_subscription(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get the current subscription status for the current user, including scheduled changes and credit balance."""
try:
@ -1891,7 +1891,7 @@ async def get_subscription(
@router.get("/check-status")
async def check_status(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Check if the user can run agents based on their subscription, usage, and credit balance."""
try:
@ -2076,7 +2076,7 @@ async def stripe_webhook(request: Request):
@router.get("/available-models")
async def get_available_models(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get the list of models available to the user based on their subscription tier."""
try:
@ -2200,7 +2200,7 @@ async def get_available_models(
async def get_usage_logs_endpoint(
page: int = 0,
items_per_page: int = 1000,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get detailed usage logs for a user with pagination."""
logger.debug(f"[USAGE_LOGS_ENDPOINT] Starting get_usage_logs_endpoint for user_id={current_user_id}, page={page}, items_per_page={items_per_page}")
@ -2254,7 +2254,7 @@ async def get_usage_logs_endpoint(
@router.get("/subscription-commitment/{subscription_id}")
async def get_subscription_commitment(
subscription_id: str,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get commitment status for a subscription."""
try:
@ -2278,7 +2278,7 @@ async def get_subscription_commitment(
@router.get("/subscription-details")
async def get_subscription_details(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get detailed subscription information including commitment status."""
try:
@ -2314,7 +2314,7 @@ async def get_subscription_details(
@router.post("/cancel-subscription")
async def cancel_subscription(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Cancel subscription with yearly commitment handling."""
try:
@ -2412,7 +2412,7 @@ async def cancel_subscription(
@router.post("/reactivate-subscription")
async def reactivate_subscription(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Reactivate a subscription that was marked for cancellation."""
try:
@ -2492,7 +2492,7 @@ async def reactivate_subscription(
@router.post("/purchase-credits")
async def purchase_credits(
request: PurchaseCreditsRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""
Create a Stripe checkout session for purchasing credits.
@ -2609,7 +2609,7 @@ async def purchase_credits(
@router.get("/credit-balance")
async def get_credit_balance_endpoint(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get the current credit balance for the user."""
try:
@ -2628,7 +2628,7 @@ async def get_credit_balance_endpoint(
async def get_credit_history(
page: int = 0,
items_per_page: int = 50,
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get credit purchase and usage history for the user."""
try:
@ -2690,7 +2690,7 @@ async def get_credit_history(
@router.get("/can-purchase-credits")
async def can_purchase_credits(
current_user_id: str = Depends(get_current_user_id_from_jwt)
current_user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Check if the current user can purchase credits (must be on highest tier)."""
try:

View File

@ -5,7 +5,7 @@ from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
from pydantic import BaseModel
from typing import Optional
from utils.logger import logger
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
router = APIRouter(tags=["transcription"])
@ -15,7 +15,7 @@ class TranscriptionResponse(BaseModel):
@router.post("/transcription", response_model=TranscriptionResponse)
async def transcribe_audio(
audio_file: UploadFile = File(...),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Transcribe audio file to text using OpenAI Whisper."""
try:

View File

@ -0,0 +1,4 @@
-- Remove the problematic GRANT that bypasses RLS
REVOKE SELECT ON TABLE projects FROM anon;
REVOKE SELECT ON TABLE threads FROM anon;
REVOKE SELECT ON TABLE messages FROM anon;

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from utils.logger import logger
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from services.supabase import DBConnection
from utils.pagination import PaginationParams
@ -133,7 +133,7 @@ async def validate_agent_ownership(agent_id: str, user_id: str) -> Dict[str, Any
@router.post("", response_model=Dict[str, str])
async def create_template_from_agent(
request: CreateTemplateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
await validate_agent_ownership(request.agent_id, user_id)
@ -172,7 +172,7 @@ async def create_template_from_agent(
async def publish_template(
template_id: str,
request: PublishTemplateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
template = await validate_template_ownership_and_get(template_id, user_id)
@ -200,7 +200,7 @@ async def publish_template(
@router.post("/{template_id}/unpublish")
async def unpublish_template(
template_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
template = await validate_template_ownership_and_get(template_id, user_id)
@ -228,7 +228,7 @@ async def unpublish_template(
@router.delete("/{template_id}")
async def delete_template(
template_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
template = await validate_template_ownership_and_get(template_id, user_id)
@ -256,7 +256,7 @@ async def delete_template(
@router.post("/install", response_model=InstallationResponse)
async def install_template(
request: InstallTemplateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
await validate_template_access_and_get(request.template_id, user_id)
@ -345,8 +345,8 @@ async def get_marketplace_templates(
creator_id_filter = None
if mine:
try:
from utils.auth_utils import get_current_user_id_from_jwt
user_id = await get_current_user_id_from_jwt(request)
from utils.auth_utils import verify_and_get_user_id_from_jwt
user_id = await verify_and_get_user_id_from_jwt(request)
creator_id_filter = user_id
except Exception as e:
raise HTTPException(status_code=401, detail="Authentication required for 'mine' filter")
@ -406,7 +406,7 @@ async def get_my_templates(
search: Optional[str] = Query(None, description="Search term for name and description"),
sort_by: Optional[str] = Query("created_at", description="Sort field: created_at, name, download_count"),
sort_order: Optional[str] = Query("desc", description="Sort order: asc, desc"),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
from templates.services.template_service import TemplateService, MarketplaceFilters
@ -497,7 +497,7 @@ async def get_public_template(template_id: str):
@router.get("/{template_id}", response_model=TemplateResponse)
async def get_template(
template_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
template = await validate_template_access_and_get(template_id, user_id)

View File

@ -9,7 +9,7 @@ import json
import hmac
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import verify_and_get_user_id_from_jwt
from utils.logger import logger
from utils.config import config
from services.billing import check_billing_status, can_use_model
@ -131,7 +131,7 @@ def initialize(database: DBConnection):
db = database
async def verify_agent_access(agent_id: str, user_id: str):
async def verify_and_authorize_trigger_agent_access(agent_id: str, user_id: str):
client = await db.client
result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', user_id).execute()
@ -228,10 +228,10 @@ async def get_providers():
@router.get("/agents/{agent_id}/triggers", response_model=List[TriggerResponse])
async def get_agent_triggers(
agent_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
try:
trigger_service = get_trigger_service(db)
@ -266,7 +266,7 @@ async def get_agent_triggers(
@router.get("/all", response_model=List[Dict[str, Any]])
async def get_all_user_triggers(
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
try:
client = await db.client
@ -347,11 +347,11 @@ async def get_all_user_triggers(
async def get_agent_upcoming_runs(
agent_id: str,
limit: int = Query(10, ge=1, le=50),
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get upcoming scheduled runs for agent triggers"""
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
try:
trigger_service = get_trigger_service(db)
@ -419,11 +419,11 @@ async def get_agent_upcoming_runs(
async def create_agent_trigger(
agent_id: str,
request: TriggerCreateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Create a new trigger for an agent"""
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
try:
trigger_service = get_trigger_service(db)
@ -466,7 +466,7 @@ async def create_agent_trigger(
@router.get("/{trigger_id}", response_model=TriggerResponse)
async def get_trigger(
trigger_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get a trigger by ID"""
@ -477,7 +477,7 @@ async def get_trigger(
if not trigger:
raise HTTPException(status_code=404, detail="Trigger not found")
await verify_agent_access(trigger.agent_id, user_id)
await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id)
base_url = os.getenv("WEBHOOK_BASE_URL", "http://localhost:8000")
webhook_url = f"{base_url}/api/triggers/{trigger_id}/webhook"
@ -505,7 +505,7 @@ async def get_trigger(
async def update_trigger(
trigger_id: str,
request: TriggerUpdateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Update a trigger"""
@ -516,7 +516,7 @@ async def update_trigger(
if not trigger:
raise HTTPException(status_code=404, detail="Trigger not found")
await verify_agent_access(trigger.agent_id, user_id)
await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id)
updated_trigger = await trigger_service.update_trigger(
trigger_id=trigger_id,
@ -556,7 +556,7 @@ async def update_trigger(
@router.delete("/{trigger_id}")
async def delete_trigger(
trigger_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Delete a trigger"""
@ -566,7 +566,7 @@ async def delete_trigger(
if not trigger:
raise HTTPException(status_code=404, detail="Trigger not found")
await verify_agent_access(trigger.agent_id, user_id)
await verify_and_authorize_trigger_agent_access(trigger.agent_id, user_id)
# Store agent_id before deletion
agent_id = trigger.agent_id
@ -709,10 +709,10 @@ def convert_steps_to_json(steps: List[WorkflowStepRequest]) -> List[Dict[str, An
@workflows_router.get("/agents/{agent_id}/workflows")
async def get_agent_workflows(
agent_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Get workflows for an agent"""
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
client = await db.client
result = await client.table('agent_workflows').select('*').eq('agent_id', agent_id).order('created_at', desc=True).execute()
@ -724,10 +724,10 @@ async def get_agent_workflows(
async def create_agent_workflow(
agent_id: str,
workflow_data: WorkflowCreateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Create a new workflow for an agent"""
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
try:
client = await db.client
@ -758,10 +758,10 @@ async def update_agent_workflow(
agent_id: str,
workflow_id: str,
workflow_data: WorkflowUpdateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
"""Update a workflow"""
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
client = await db.client
@ -800,9 +800,9 @@ async def update_agent_workflow(
async def delete_agent_workflow(
agent_id: str,
workflow_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
client = await db.client
@ -822,10 +822,10 @@ async def execute_agent_workflow(
agent_id: str,
workflow_id: str,
execution_data: WorkflowExecuteRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
user_id: str = Depends(verify_and_get_user_id_from_jwt)
):
print("DEBUG: Executing workflow", workflow_id, "for agent", agent_id)
await verify_agent_access(agent_id, user_id)
await verify_and_authorize_trigger_agent_access(agent_id, user_id)
client = await db.client

View File

@ -6,9 +6,93 @@ from jwt.exceptions import PyJWTError
from utils.logger import structlog
from utils.config import config
import os
import base64
import hashlib
import hmac
from services.supabase import DBConnection
from services import redis
async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)):
if not config.KORTIX_ADMIN_API_KEY:
raise HTTPException(
status_code=500,
detail="Admin API key not configured on server"
)
if not x_admin_api_key:
raise HTTPException(
status_code=401,
detail="Admin API key required. Include X-Admin-Api-Key header."
)
if x_admin_api_key != config.KORTIX_ADMIN_API_KEY:
raise HTTPException(
status_code=403,
detail="Invalid admin API key"
)
return True
def _verify_jwt_signature(token: str, secret: str) -> dict:
"""
Securely verify JWT signature using the Supabase JWT secret.
Args:
token: The JWT token to verify
secret: The Supabase JWT secret for signature verification
Returns:
dict: The decoded JWT payload if signature is valid
Raises:
PyJWTError: If signature verification fails or token is invalid
"""
try:
# Decode and verify the JWT signature using HS256 algorithm (Supabase default)
payload = jwt.decode(
token,
secret,
algorithms=["HS256"],
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Supabase JWTs may not always have audience
"verify_iss": False # Supabase JWTs may not always have issuer
}
)
return payload
except PyJWTError as e:
structlog.get_logger().warning(f"JWT signature verification failed: {e}")
raise
def _decode_jwt_safely(token: str) -> dict:
"""
Safely decode and verify a JWT token with proper signature verification.
Falls back to no verification only if JWT secret is not configured.
Args:
token: The JWT token to decode
Returns:
dict: The decoded JWT payload
Raises:
PyJWTError: If token is invalid or signature verification fails
"""
jwt_secret = config.SUPABASE_JWT_SECRET
if jwt_secret:
# Production mode: Verify signature
structlog.get_logger().debug("Verifying JWT signature")
return _verify_jwt_signature(token, jwt_secret)
else:
# Development mode: Log warning and decode without verification
structlog.get_logger().warning(
"JWT_SECRET not configured - using insecure JWT decoding. "
"This should only be used in development!"
)
return jwt.decode(token, options={"verify_signature": False})
async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
"""
Get user_id from account_id with Redis caching for performance
@ -58,7 +142,7 @@ async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
return None
# This function extracts the user ID from Supabase JWT
async def get_current_user_id_from_jwt(request: Request) -> str:
async def verify_and_get_user_id_from_jwt(request: Request) -> str:
"""
Extract and verify the user ID from the JWT in the Authorization header or API key.
@ -149,7 +233,7 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
token = auth_header.split(' ')[1]
try:
payload = jwt.decode(token, options={"verify_signature": False})
payload = _decode_jwt_safely(token)
user_id = payload.get('sub')
if not user_id:
@ -173,51 +257,6 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
headers={"WWW-Authenticate": "Bearer"}
)
async def get_account_id_from_thread(client, thread_id: str) -> str:
"""
Extract and verify the account ID from the thread.
Args:
client: The Supabase client
thread_id: The ID of the thread
Returns:
str: The account ID associated with the thread
Raises:
HTTPException: If the thread is not found or if there's an error
"""
try:
response = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
if not response.data or len(response.data) == 0:
raise HTTPException(
status_code=404,
detail="Thread not found"
)
account_id = response.data[0].get('account_id')
if not account_id:
raise HTTPException(
status_code=500,
detail="Thread has no associated account"
)
return account_id
except Exception as e:
error_msg = str(e)
if "cannot schedule new futures after shutdown" in error_msg or "connection is closed" in error_msg:
raise HTTPException(
status_code=503,
detail="Server is shutting down"
)
else:
raise HTTPException(
status_code=500,
detail=f"Error retrieving thread information: {str(e)}"
)
async def get_user_id_from_stream_auth(
request: Request,
@ -246,7 +285,7 @@ async def get_user_id_from_stream_auth(
try:
# First, try the standard authentication (handles both API keys and Authorization header)
try:
return await get_current_user_id_from_jwt(request)
return await verify_and_get_user_id_from_jwt(request)
except HTTPException:
# If standard auth fails, try query parameter JWT for EventSource compatibility
pass
@ -254,8 +293,8 @@ async def get_user_id_from_stream_auth(
# Try to get user_id from token in query param (for EventSource which can't set headers)
if token:
try:
# For Supabase JWT, we just need to decode and extract the user ID
payload = jwt.decode(token, options={"verify_signature": False})
# For Supabase JWT, verify signature and extract the user ID
payload = _decode_jwt_safely(token)
user_id = payload.get('sub')
if user_id:
sentry.sentry.set_user({ "id": user_id })
@ -289,7 +328,76 @@ async def get_user_id_from_stream_auth(
detail=f"Error during authentication: {str(e)}"
)
async def verify_thread_access(client, thread_id: str, user_id: str):
async def get_optional_user_id(request: Request) -> Optional[str]:
"""
Extract the user ID from the JWT in the Authorization header if present,
but don't require authentication. Returns None if no valid token is found.
This function is used for endpoints that support both authenticated and
unauthenticated access (like public projects).
Args:
request: The FastAPI request object
Returns:
Optional[str]: The user ID extracted from the JWT, or None if no valid token
"""
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return None
token = auth_header.split(' ')[1]
try:
# For Supabase JWT, verify signature and extract the user ID
payload = _decode_jwt_safely(token)
# Supabase stores the user ID in the 'sub' claim
user_id = payload.get('sub')
if user_id:
sentry.sentry.set_user({ "id": user_id })
structlog.contextvars.bind_contextvars(
user_id=user_id
)
return user_id
except PyJWTError:
return None
# Alias for consistency with other auth functions
get_optional_current_user_id_from_jwt = get_optional_user_id
async def verify_and_get_agent_authorization(client, agent_id: str, user_id: str) -> dict:
"""
Verify that a user has access to a specific agent based on ownership.
Args:
client: The Supabase client
agent_id: The agent ID to check access for
user_id: The user ID to check permissions for
Returns:
dict: Agent data if access is granted
Raises:
HTTPException: If the user doesn't have access to the agent or agent doesn't exist
"""
try:
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
if not agent_result.data:
raise HTTPException(status_code=404, detail="Agent not found or access denied")
return agent_result.data[0]
except HTTPException:
raise
except Exception as e:
structlog.error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to verify agent access")
async def verify_and_authorize_thread_access(client, thread_id: str, user_id: str):
"""
Verify that a user has access to a specific thread based on account membership.
@ -347,92 +455,137 @@ async def verify_thread_access(client, thread_id: str, user_id: str):
detail=f"Error verifying thread access: {str(e)}"
)
async def get_optional_user_id(request: Request) -> Optional[str]:
# ============================================================================
# FastAPI Dependency Functions for Combined Authentication + Authorization
# ============================================================================
async def get_authorized_user_for_thread(
thread_id: str,
request: Request
) -> str:
"""
Extract the user ID from the JWT in the Authorization header if present,
but don't require authentication. Returns None if no valid token is found.
This function is used for endpoints that support both authenticated and
unauthenticated access (like public projects).
FastAPI dependency that verifies JWT and authorizes thread access.
Args:
thread_id: The thread ID to authorize access for
request: The FastAPI request object
Returns:
Optional[str]: The user ID extracted from the JWT, or None if no valid token
"""
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
return None
token = auth_header.split(' ')[1]
try:
# For Supabase JWT, we just need to decode and extract the user ID
payload = jwt.decode(token, options={"verify_signature": False})
# Supabase stores the user ID in the 'sub' claim
user_id = payload.get('sub')
if user_id:
sentry.sentry.set_user({ "id": user_id })
structlog.contextvars.bind_contextvars(
user_id=user_id
)
return user_id
except PyJWTError:
return None
# Alias for consistency with other auth functions
get_optional_current_user_id_from_jwt = get_optional_user_id
async def verify_admin_api_key(x_admin_api_key: Optional[str] = Header(None)):
if not config.KORTIX_ADMIN_API_KEY:
raise HTTPException(
status_code=500,
detail="Admin API key not configured on server"
)
if not x_admin_api_key:
raise HTTPException(
status_code=401,
detail="Admin API key required. Include X-Admin-Api-Key header."
)
if x_admin_api_key != config.KORTIX_ADMIN_API_KEY:
raise HTTPException(
status_code=403,
detail="Invalid admin API key"
)
return True
async def verify_agent_access(client, agent_id: str, user_id: str) -> dict:
"""
Verify that a user has access to a specific agent based on ownership.
Args:
client: The Supabase client
agent_id: The agent ID to check access for
user_id: The user ID to check permissions for
Returns:
dict: Agent data if access is granted
str: The authenticated and authorized user ID
Raises:
HTTPException: If the user doesn't have access to the agent or agent doesn't exist
HTTPException: If authentication fails or user lacks thread access
"""
try:
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', user_id).execute()
from services.supabase import DBConnection
# First, authenticate the user
user_id = await verify_and_get_user_id_from_jwt(request)
# Then, authorize thread access
db = DBConnection()
client = await db.client
await verify_and_authorize_thread_access(client, thread_id, user_id)
return user_id
async def get_authorized_user_for_agent(
agent_id: str,
request: Request
) -> tuple[str, dict]:
"""
FastAPI dependency that verifies JWT and authorizes agent access.
Args:
agent_id: The agent ID to authorize access for
request: The FastAPI request object
if not agent_result.data:
raise HTTPException(status_code=404, detail="Agent not found or access denied")
Returns:
tuple[str, dict]: The authenticated user ID and agent data
return agent_result.data[0]
Raises:
HTTPException: If authentication fails or user lacks agent access
"""
from services.supabase import DBConnection
# First, authenticate the user
user_id = await verify_and_get_user_id_from_jwt(request)
# Then, authorize agent access and get agent data
db = DBConnection()
client = await db.client
agent_data = await verify_and_get_agent_authorization(client, agent_id, user_id)
return user_id, agent_data
class AuthorizedThreadAccess:
"""
FastAPI dependency that combines authentication and thread authorization.
Usage:
@router.get("/threads/{thread_id}/messages")
async def get_messages(
thread_id: str,
auth: AuthorizedThreadAccess = Depends()
):
user_id = auth.user_id # Authenticated and authorized user
"""
def __init__(self, user_id: str):
self.user_id = user_id
class AuthorizedAgentAccess:
"""
FastAPI dependency that combines authentication and agent authorization.
Usage:
@router.get("/agents/{agent_id}/config")
async def get_agent_config(
agent_id: str,
auth: AuthorizedAgentAccess = Depends()
):
user_id = auth.user_id # Authenticated and authorized user
agent_data = auth.agent_data # Agent data from authorization check
"""
def __init__(self, user_id: str, agent_data: dict):
self.user_id = user_id
self.agent_data = agent_data
async def require_thread_access(
thread_id: str,
request: Request
) -> AuthorizedThreadAccess:
"""
FastAPI dependency that verifies JWT and authorizes thread access.
Args:
thread_id: The thread ID from the path parameter
request: The FastAPI request object
except HTTPException:
raise
except Exception as e:
structlog.error(f"Error verifying agent access for agent {agent_id}, user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to verify agent access")
Returns:
AuthorizedThreadAccess: Object containing authenticated user_id
Raises:
HTTPException: If authentication fails or user lacks thread access
"""
user_id = await get_authorized_user_for_thread(thread_id, request)
return AuthorizedThreadAccess(user_id)
async def require_agent_access(
agent_id: str,
request: Request
) -> AuthorizedAgentAccess:
"""
FastAPI dependency that verifies JWT and authorizes agent access.
Args:
agent_id: The agent ID from the path parameter
request: The FastAPI request object
Returns:
AuthorizedAgentAccess: Object containing user_id and agent_data
Raises:
HTTPException: If authentication fails or user lacks agent access
"""
user_id, agent_data = await get_authorized_user_for_agent(agent_id, request)
return AuthorizedAgentAccess(user_id, agent_data)

View File

@ -16,6 +16,7 @@ Usage:
import os
from enum import Enum
from re import S
from typing import Dict, Any, Optional, get_type_hints, Union
from dotenv import load_dotenv
import logging
@ -277,6 +278,7 @@ class Configuration:
SUPABASE_URL: str
SUPABASE_ANON_KEY: str
SUPABASE_SERVICE_ROLE_KEY: str
SUPABASE_JWT_SECRET: str
# Redis configuration
REDIS_HOST: str

View File

@ -119,6 +119,7 @@ def load_existing_env_vars():
"SUPABASE_SERVICE_ROLE_KEY": backend_env.get(
"SUPABASE_SERVICE_ROLE_KEY", ""
),
"SUPABASE_JWT_SECRET": backend_env.get("SUPABASE_JWT_SECRET", ""),
},
"daytona": {
"DAYTONA_API_KEY": backend_env.get("DAYTONA_API_KEY", ""),
@ -285,8 +286,17 @@ class SetupWizard:
config_items = []
# Check Supabase
if self.env_vars["supabase"]["SUPABASE_URL"]:
config_items.append(f"{Colors.GREEN}{Colors.ENDC} Supabase")
supabase_complete = (
self.env_vars["supabase"]["SUPABASE_URL"] and
self.env_vars["supabase"]["SUPABASE_ANON_KEY"] and
self.env_vars["supabase"]["SUPABASE_SERVICE_ROLE_KEY"]
)
supabase_secure = self.env_vars["supabase"]["SUPABASE_JWT_SECRET"]
if supabase_complete and supabase_secure:
config_items.append(f"{Colors.GREEN}{Colors.ENDC} Supabase (secure)")
elif supabase_complete:
config_items.append(f"{Colors.YELLOW}{Colors.ENDC} Supabase (missing JWT secret)")
else:
config_items.append(f"{Colors.YELLOW}{Colors.ENDC} Supabase")
@ -600,8 +610,12 @@ class SetupWizard:
"You'll need a Supabase project. Visit https://supabase.com/dashboard/projects to create one."
)
print_info(
"In your project settings, go to 'API' to find the required information."
"In your project settings, go to 'API' to find the required information:"
)
print_info(" - Project URL (at the top)")
print_info(" - anon public key (under 'Project API keys')")
print_info(" - service_role secret key (under 'Project API keys')")
print_info(" - JWT Secret (under 'JWT Settings' - critical for security!)")
input("Press Enter to continue once you have your project details...")
self.env_vars["supabase"]["SUPABASE_URL"] = self._get_input(
@ -622,6 +636,12 @@ class SetupWizard:
"This does not look like a valid key. It should be at least 10 characters.",
default_value=self.env_vars["supabase"]["SUPABASE_SERVICE_ROLE_KEY"],
)
self.env_vars["supabase"]["SUPABASE_JWT_SECRET"] = self._get_input(
"Enter your Supabase JWT secret (for signature verification): ",
validate_api_key,
"This does not look like a valid JWT secret. It should be at least 10 characters.",
default_value=self.env_vars["supabase"]["SUPABASE_JWT_SECRET"],
)
print_success("Supabase information saved.")
def collect_daytona_info(self):